import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from typing import List, Dict
class Ad:
    def __init__(self,
                 ad_id: str,
                 true_ctr: float,
                 decay_rate: float = 0.0):
        """
        Initialize an ad with its true CTR and optional decay rate.

        Args:
            ad_id: Unique identifier for the ad
            true_ctr: True click-through rate on the ad's first day
            decay_rate: Daily decay rate of CTR (simulates ad fatigue). After
                ``d`` whole days the CTR is ``true_ctr * (1 - decay_rate) ** d``.
        """
        self.ad_id = ad_id
        self.base_ctr = true_ctr
        self.decay_rate = decay_rate
        # Decay is measured from the moment this Ad object is created.
        self.start_time = datetime.now()

    def get_click(self, current_time: datetime) -> bool:
        """
        Simulate whether a user clicks, based on the ad's CTR at ``current_time``.

        Args:
            current_time: Simulated time of the impression.

        Returns:
            True with probability equal to the decayed CTR, otherwise False.
        """
        # timedelta.days floors, so a current_time even slightly before
        # start_time yields -1 and would *inflate* the CTR by
        # 1 / (1 - decay_rate). Treat any such time as day 0 instead.
        days_active = max(0, (current_time - self.start_time).days)
        current_ctr = self.base_ctr * (1 - self.decay_rate) ** days_active
        # bool() so the return value matches the annotation (not np.bool_).
        return bool(np.random.random() < current_ctr)
# Create sample ads with different CTRs and decay rates
# (arguments: ad_id, true_ctr, decay_rate)
ads = [
    Ad("Summer Sale", 0.08, 0.1),     # High initial CTR, fast decay
    Ad("Product Demo", 0.05, 0.05),   # Medium CTR, medium decay
    Ad("Brand Story", 0.03, 0.02),    # Low CTR, slow decay
    Ad("Limited Offer", 0.07, 0.15),  # High CTR, very fast decay
    Ad("Newsletter", 0.04, 0.01),     # Medium-low CTR, very slow decay
]
Introduction
Imagine you’re running a website with multiple ad variants. Which ad should you show to maximize clicks? This is a classic multi-armed bandit (MAB) problem. In this post, we’ll build a practical ad selection system using an MAB algorithm, complete with:
- Simulated user behavior
- Click-through rate (CTR) optimization
- Real-time learning and adaptation
- Performance visualization
The Ad Selection Problem
When a user visits your website, you have milliseconds to decide which ad to show. Each ad has an unknown click-through rate, and your goal is to:
- Find the best-performing ads (exploration)
- Show the best ads more frequently (exploitation)
- Adapt to changing user preferences over time
Let’s implement this scenario:
Implementing the Bandit Algorithm
We’ll use a Thompson Sampling approach, which is well-suited for online advertising because it:
- Balances exploration and exploitation naturally
- Handles uncertainty well
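- Can adapt to changing conditions — although the basic version below weights all past observations equally, so it reacts slowly when an ad’s CTR decays (discounting old counts or using a sliding window helps)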
class AdBandit:
    def __init__(self, ad_ids: List[str],
                 prior_alpha: float = 1.0,
                 prior_beta: float = 1.0):
        """
        Initialize the bandit algorithm with Beta distributions for each ad.

        Args:
            ad_ids: List of ad identifiers
            prior_alpha: Initial Beta ``alpha`` (pseudo-clicks) for every ad.
            prior_beta: Initial Beta ``beta`` (pseudo-non-clicks) for every ad.
                The defaults (1, 1) are a uniform prior over CTR; larger
                values encode prior knowledge, e.g. from similar ads.

        Raises:
            ValueError: If either prior is not positive (Beta requires > 0).
        """
        if prior_alpha <= 0 or prior_beta <= 0:
            raise ValueError("prior_alpha and prior_beta must be positive")
        self.ad_ids = ad_ids
        # Beta distribution parameters for each ad (posterior over its CTR)
        self.alphas = {ad_id: float(prior_alpha) for ad_id in ad_ids}
        self.betas = {ad_id: float(prior_beta) for ad_id in ad_ids}
        # Raw counts, tracked separately from the prior-shifted Beta params
        self.impressions = {ad_id: 0 for ad_id in ad_ids}
        self.clicks = {ad_id: 0 for ad_id in ad_ids}

    def select_ad(self) -> str:
        """
        Select an ad using Thompson Sampling.

        Draws one CTR sample from each ad's Beta distribution and returns the
        ad with the highest draw (ties go to the earliest ad in ``ad_ids``).

        Raises:
            ValueError: If the bandit was built with no ads.
        """
        if not self.ad_ids:
            raise ValueError("AdBandit has no ads to select from")
        samples = {
            ad_id: np.random.beta(self.alphas[ad_id], self.betas[ad_id])
            for ad_id in self.ad_ids
        }
        return max(samples, key=samples.get)

    def update(self, ad_id: str, clicked: bool):
        """
        Update the model based on whether the user clicked.

        A click adds 1 to the ad's ``alpha``; a non-click adds 1 to its
        ``beta``. Raises KeyError for an ``ad_id`` the bandit does not know.
        """
        self.impressions[ad_id] += 1
        if clicked:
            self.clicks[ad_id] += 1
            self.alphas[ad_id] += 1
        else:
            self.betas[ad_id] += 1

    def get_ctr(self, ad_id: str) -> float:
        """
        Return the empirical click-through rate (clicks / impressions) of an ad.

        Returns 0.0 for an ad with no impressions. This ignores the prior;
        it is not the posterior mean.
        """
        if self.impressions[ad_id] == 0:
            return 0.0
        return self.clicks[ad_id] / self.impressions[ad_id]
Running a Simulation
Let’s simulate a week of ad serving with 1000 visitors per day:
def run_simulation(ads: List[Ad],
                   bandit: AdBandit,
                   days: int = 7,
                   visitors_per_day: int = 1000) -> pd.DataFrame:
    """
    Run a simulation of ad serving and track performance.

    For each simulated day, ``visitors_per_day`` visitors arrive. For each
    visitor the bandit picks an ad, the ad's simulated CTR decides whether the
    visitor clicks, and the outcome is fed back to the bandit.

    Args:
        ads: Ads available to serve; their ``ad_id`` values must match the ids
            the bandit selects from.
        bandit: Bandit used to choose ads; it is updated in place.
        days: Number of simulated days.
        visitors_per_day: Number of visitors (impressions) per day.

    Returns:
        DataFrame with daily performance metrics: one row per (day, ad) with
        columns ``day`` (1-based), ``ad_id``, ``impressions``, ``clicks`` and
        ``ctr`` (0 on days when the ad received no impressions).

    Raises:
        KeyError: If the bandit selects an id that is not in ``ads``.
    """
    # Index ads by id once instead of scanning the list for every visitor.
    # setdefault keeps the first Ad for a duplicated id (as a linear scan would).
    ads_by_id: Dict[str, Ad] = {}
    for ad in ads:
        ads_by_id.setdefault(ad.ad_id, ad)

    results = []
    start_time = datetime.now()

    for day in range(days):
        current_time = start_time + timedelta(days=day)
        daily_impressions = {ad.ad_id: 0 for ad in ads}
        daily_clicks = {ad.ad_id: 0 for ad in ads}

        for _ in range(visitors_per_day):
            selected_ad_id = bandit.select_ad()
            selected_ad = ads_by_id[selected_ad_id]

            clicked = selected_ad.get_click(current_time)
            bandit.update(selected_ad_id, clicked)

            daily_impressions[selected_ad_id] += 1
            if clicked:
                daily_clicks[selected_ad_id] += 1

        # Record daily results
        for ad in ads:
            impressions = daily_impressions[ad.ad_id]
            clicks = daily_clicks[ad.ad_id]
            results.append({
                'day': day + 1,
                'ad_id': ad.ad_id,
                'impressions': impressions,
                'clicks': clicks,
                'ctr': clicks / impressions if impressions > 0 else 0,
            })

    return pd.DataFrame(results)
# Run simulation: one bandit arm per sample ad, default 7 days x 1000 visitors
bandit = AdBandit([ad.ad_id for ad in ads])
results_df = run_simulation(ads, bandit)
# Plot results: daily CTR per ad (left) and total impressions per ad (right)
plt.figure(figsize=(15, 5))

# Daily CTR by ad
plt.subplot(1, 2, 1)
for ad_id in results_df['ad_id'].unique():
    ad_data = results_df[results_df['ad_id'] == ad_id]
    plt.plot(ad_data['day'], ad_data['ctr'], 'o-', label=ad_id)
plt.xlabel('Day')
plt.ylabel('Click-Through Rate')
plt.title('Daily CTR by Ad')
plt.legend()
plt.grid(True)

# Total impressions by ad
plt.subplot(1, 2, 2)
total_impressions = results_df.groupby('ad_id')['impressions'].sum()
plt.bar(total_impressions.index, total_impressions.values)
plt.xlabel('Ad')
plt.ylabel('Total Impressions')
plt.title('Total Impressions by Ad')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()
Analyzing the Results
Let’s calculate some key performance metrics:
# Calculate overall performance metrics.
# Aggregate impressions and clicks per ad in ONE groupby so that each 'Ad'
# label and its totals come from the same row. (Pairing
# results_df['ad_id'].unique(), which is in order of first appearance, with
# groupby(...).values, which is sorted by ad_id, silently mislabels rows.)
totals = (
    results_df.groupby('ad_id', as_index=False)[['impressions', 'clicks']]
    .sum()
)
performance_df = pd.DataFrame({
    'Ad': totals['ad_id'],
    'Total Impressions': totals['impressions'],
    'Total Clicks': totals['clicks'],
    # Vectorised ratio: no groupby.apply, hence no DeprecationWarning.
    # An ad with zero total impressions gets NaN.
    'Overall CTR': totals['clicks'] / totals['impressions'],
})

performance_df = performance_df.sort_values('Overall CTR', ascending=False)
performance_df.style.format({
    'Overall CTR': '{:.2%}',
    'Total Impressions': '{:,.0f}',
    'Total Clicks': '{:,.0f}'
})
/var/folders/nq/y2lyg7p15txfm5ksrw6f90340000gn/T/ipykernel_59721/2762817216.py:6: DeprecationWarning:
DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
| | Ad | Total Impressions | Total Clicks | Overall CTR |
|---|---|---|---|---|
| 4 | Summer Sale | 3,156 | 184 | 5.83% |
| 1 | Limited Offer | 3,102 | 179 | 5.77% |
| 3 | Product Demo | 417 | 16 | 3.84% |
| 2 | Newsletter | 176 | 5 | 2.84% |
| 0 | Brand Story | 149 | 3 | 2.01% |
Key Insights and Implementation Tips
When implementing this system in production, consider:
- Cold Start Problem
- Start with a short exploration period for new ads
- Use prior knowledge (e.g., similar ad performance) for initial estimates
- Contextualization
- Extend the model to consider user segments
- Include time-of-day effects
- Account for seasonal variations
Here’s how to add basic contextualization:
class ContextualAdBandit:
    def __init__(self, ad_ids: List[str], contexts: List[str]):
        """
        Build one independent AdBandit per context.

        Args:
            ad_ids: List of ad identifiers, shared by every context
            contexts: List of context identifiers (e.g., ['mobile', 'desktop'])
        """
        per_context = {}
        for ctx in contexts:
            # A fresh AdBandit per context, so statistics are never shared.
            per_context[ctx] = AdBandit(ad_ids)
        self.bandits = per_context

    def select_ad(self, context: str) -> str:
        """Pick an ad using the bandit that belongs to ``context``."""
        chosen_bandit = self.bandits[context]
        return chosen_bandit.select_ad()

    def update(self, context: str, ad_id: str, clicked: bool):
        """Record a click / no-click outcome in ``context``'s bandit only."""
        chosen_bandit = self.bandits[context]
        chosen_bandit.update(ad_id, clicked)
# Example usage: one independent bandit per device type
contexts = ['mobile', 'desktop']
contextual_bandit = ContextualAdBandit([ad.ad_id for ad in ads], contexts)
Implementation Guidelines
- System Architecture
# Pseudocode for production implementation
class AdServer:
    def __init__(self):
        self.bandit = ContextualAdBandit(...)
        self.cache = RedisCache(...)

    async def select_ad(self, user_context):
        # Get user context features
        context = self.extract_context(user_context)

        # Select ad using bandit
        selected_ad = self.bandit.select_ad(context)

        # Log impression
        await self.log_impression(selected_ad, context)
        return selected_ad

    async def log_click(self, ad_id, context):
        # Update model with a success
        self.bandit.update(context, ad_id, clicked=True)
        # Log to analytics
        await self.log_to_analytics(...)

    async def log_no_click(self, ad_id, context):
        # Update model with a failure, e.g. once an impression's attribution
        # window closes without a click. The bandit only learns from outcomes
        # it is told about: if only clicks are reported, beta never grows and
        # every ad's posterior drifts toward CTR = 1. Report each impression
        # exactly once, as either a click or a no-click.
        self.bandit.update(context, ad_id, clicked=False)
        # Log to analytics
        await self.log_to_analytics(...)
- Monitoring Metrics
- CTR by ad and context
- Exploration rate
- Model convergence
- System latency
- Failing Gracefully
- Maintain a fallback ad selection strategy
- Cache recent model states
- Implement circuit breakers
Conclusion
Multi-armed bandits provide an elegant solution to the ad selection problem, combining:
- Automatic optimization of ad performance
- Real-time learning and adaptation
- Scalable implementation options
- Clear performance metrics
Consider starting with a simple implementation and gradually adding complexity as you validate the approach with real traffic.
References
- Chapelle, O., & Li, L. (2011). An empirical evaluation of Thompson sampling. Advances in Neural Information Processing Systems, 24.
- Li, L., Chu, W., Langford, J., & Schapire, R. E. (2010). A contextual-bandit approach to personalized news article recommendation. In Proceedings of the 19th International Conference on World Wide Web (WWW ’10), pp. 661–670.