An incomplete attempt at contextual bandits with sklearn and coba.
import random
from typing import Any, Dict, Hashable, List, Optional, Tuple

import numpy as np
from coba.primitives import Learner
from scipy.special import softmax
from sklearn.base import BaseEstimator, clone


class BaseOnline(Learner):
    def __init__(
        self,
        base_estimator: BaseEstimator,
        replay_buffer: int = 1000000,
        target_samples: int | float | None = None,
        min_samples: int = 4,
    ) -> None:
        self.base_estimator = base_estimator
        self.replay_buffer = replay_buffer
        self.target_samples = target_samples  # currently unused in this file
        self.min_samples = min_samples
        # Per-action replay memory of (stacked context+action features, reward) pairs.
        self.memory_by_action: Dict[Any, List[Tuple[np.ndarray, float]]] = {}
        self._classes: Optional[Dict[Any, int]] = None
        self._classes_is_numeric: Optional[bool] = None
        self._n_classes: Optional[int] = None
        # Subclasses are expected to define self.num_actions and self._models.

    @property
    def params(self) -> Dict[str, Hashable]:
        base_estimator_params = self.base_estimator.get_params()
        return {
            "base_estimator_name": type(self.base_estimator).__name__,
            "base_estimator_params": base_estimator_params,
            "replay_buffer": self.replay_buffer,
            "target_samples": self.target_samples,
            "min_samples": self.min_samples,
            "classes": self._classes,
        }

    def _validate_features(self, x) -> np.ndarray:
        if x is None:
            return np.array([0.0])
        return np.array(x)

    def _get_context_and_action_features(self, context: "Context", actions: "Actions"):
        """Helper method to get validated context and action features."""
        context_features = self._validate_features(context)
        actions_features = [self._validate_features(x) for x in actions]
        return context_features, actions_features

    def _get_insufficient_sample_actions(self):
        """Helper method to get the actions that still lack min_samples observations.

        Each under-sampled action appears once per missing observation, so sampling
        uniformly from the result favours the most starved actions.
        """
        sampled_actions = []
        for act in range(self.num_actions):
            n_seen = len(self.memory_by_action.get(act, []))
            if n_seen < self.min_samples:
                sampled_actions.extend([act] * (self.min_samples - n_seen))
        return sampled_actions

    def _should_take_random_action(self, additional_condition: bool = False):
        """Helper method to determine whether a random action should be taken.

        Note: the default for additional_condition must be False; a default of True
        would force callers that pass no argument to always act randomly.
        """
        return (
            len(self.memory_by_action) == 0
            or len(self._models) == 0
            or not all(
                len(self.memory_by_action.get(act_idx, [])) >= self.min_samples
                for act_idx in range(self.num_actions)
            )
            or additional_condition
        )


class EpsilonGreedy(BaseOnline):
    def __init__(
        self,
        estimator: BaseEstimator,
        num_actions: int,
        epsilon: float = 0.05,
        decay: float = 0.9,
        *,
        replay_buffer: int = 1000000,
        target_samples: int | float | None = None,
        min_samples: int = 4,
    ) -> None:
        super().__init__(estimator, replay_buffer, target_samples, min_samples)
        self.num_actions = num_actions
        self.epsilon = epsilon
        self.decay = decay
        self._models: Dict[Any, BaseEstimator] = {}

    @property
    def params(self) -> Dict[str, Hashable]:
        params = super().params
        params.update({"epsilon": self.epsilon, "decay": self.decay})
        return params

    def _calculate_estimates(self, context_features: np.ndarray, actions_features: List[np.ndarray]):
        """Calculate estimates for all actions."""
        estimates = {}
        for act_idx in range(self.num_actions):
            model = self._models.get(act_idx, None)
            if model is None:
                continue
            action_features = actions_features[act_idx]
            estimates[act_idx] = model.predict(np.hstack([context_features, action_features]).reshape(1, -1))[0]
        return estimates

    def score(self, context: "Context", actions: "Actions"):
        """Return the index of the highest-estimate action, or None if no model is fitted yet."""
        context_features, actions_features = self._get_context_and_action_features(context, actions)
        estimates = self._calculate_estimates(context_features, actions_features)
        if len(estimates) == 0:
            return None
        return max(estimates, key=estimates.get)

    def predict(self, context: "Context", actions: "Actions"):
        context_features, actions_features = self._get_context_and_action_features(context, actions)
        if self._should_take_random_action(np.random.rand() < self.epsilon):
            if not all(
                len(self.memory_by_action.get(act_idx, [])) >= self.min_samples
                for act_idx in range(self.num_actions)
            ):
                # Prioritise actions that still lack min_samples observations.
                sampled_actions = self._get_insufficient_sample_actions()
                selected_action = np.random.choice(sampled_actions)
            else:
                selected_action = random.choice(range(self.num_actions))
        else:
            estimates = self._calculate_estimates(context_features, actions_features)
            selected_action = max(estimates, key=estimates.get)
        return (
            actions[selected_action],
            {"selected_action": selected_action},
        )

    def learn(
        self,
        context: "Context",
        action: "Action",
        reward: "Reward",
        probability: float,
        selected_action: int,
    ) -> None:
        context_features = self._validate_features(context)
        action_features = self._validate_features(action)
        if selected_action not in self.memory_by_action:
            self.memory_by_action[selected_action] = []
        if selected_action not in self._models:
            self._models[selected_action] = clone(self.base_estimator)
        self.memory_by_action[selected_action].append((np.hstack([context_features, action_features]), reward))
        if len(self.memory_by_action[selected_action]) > self.replay_buffer:
            # Drop the oldest observation to keep the buffer bounded.
            self.memory_by_action[selected_action].pop(0)
        if len(self.memory_by_action[selected_action]) >= self.min_samples:
            # Refit the per-action model from scratch on the full buffer.
            self._models[selected_action] = clone(self.base_estimator, safe=True)
            X = np.vstack([x[0] for x in self.memory_by_action[selected_action]])
            y = np.hstack([x[1] for x in self.memory_by_action[selected_action]])
            self._models[selected_action].fit(X, y)
        # Only decay epsilon once every action has been sampled enough.
        if all(len(mem) >= self.min_samples for mem in self.memory_by_action.values()):
            self.epsilon *= self.decay
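

# --- Hedged usage sketch (not part of the original gist) ---------------------
# A minimal smoke test for EpsilonGreedy on an assumed toy two-arm problem
# where arm 1 pays ~1.0 and arm 0 pays ~0.2. The Ridge regressor, the one-hot
# action features, and the reward function are illustrative assumptions only.
def _demo_epsilon_greedy(steps: int = 200) -> None:
    from sklearn.linear_model import Ridge

    learner = EpsilonGreedy(Ridge(), num_actions=2, epsilon=0.2)
    actions = [(1.0, 0.0), (0.0, 1.0)]  # one-hot action features
    for _ in range(steps):
        context = (np.random.rand(),)
        chosen, info = learner.predict(context, actions)
        # Arm 1 is better on average; noise forces the model to actually learn.
        reward = (1.0 if info["selected_action"] == 1 else 0.2) + 0.1 * np.random.randn()
        learner.learn(context, chosen, reward, probability=1.0, selected_action=info["selected_action"])
    print("final epsilon:", learner.epsilon, "| preferred arm:", learner.score(None, actions))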


class BootstrappedUCB(BaseOnline):
    def __init__(
        self,
        estimator: BaseEstimator,
        num_actions: int,
        num_bootstraps: int = 10,
        percentile: int = 80,
        *,
        replay_buffer: int = 1000000,
        target_samples: int | float | None = None,
        min_samples: int = 4,
        bootstrap_ratio: float = 0.5,
        use_importance_weights: bool = False,
    ) -> None:
        super().__init__(estimator, replay_buffer, target_samples, min_samples)
        self.num_actions = num_actions
        self.num_bootstraps = num_bootstraps
        self.percentile = percentile
        self.bootstrap_ratio = bootstrap_ratio
        self.use_importance_weights = use_importance_weights
        self._models: Dict[int, List[BaseEstimator]] = {}

    @property
    def params(self) -> Dict[str, Hashable]:
        params = super().params
        params.update(
            {
                "num_bootstraps": self.num_bootstraps,
                "percentile": self.percentile,
                "bootstrap_ratio": self.bootstrap_ratio,
                "use_importance_weights": self.use_importance_weights,
            }
        )
        return params

    def _calculate_estimates(self, context_features: np.ndarray, actions_features: List[np.ndarray]):
        """Calculate estimates for all actions using bootstrapped models.

        The per-action score is the given percentile of the bootstrap ensemble's
        predictions, which serves as the upper confidence bound.
        """
        estimates = {}
        for act_idx in range(self.num_actions):
            models = self._models.get(act_idx, None)
            if models is None:
                continue
            action_features = actions_features[act_idx]
            features = np.hstack([context_features, action_features]).reshape(1, -1)
            bootstrapped_estimates = [model.predict(features)[0] for model in models]
            estimates[act_idx] = np.percentile(bootstrapped_estimates, self.percentile)
        return estimates

    def score(self, context: "Context", actions: "Actions"):
        """Return the index of the highest-estimate action, or None if no model is fitted yet."""
        context_features, actions_features = self._get_context_and_action_features(context, actions)
        estimates = self._calculate_estimates(context_features, actions_features)
        if len(estimates) == 0:
            return None
        return max(estimates, key=estimates.get)

    def predict(self, context: "Context", actions: "Actions"):
        context_features, actions_features = self._get_context_and_action_features(context, actions)
        if self._should_take_random_action():
            if not all(
                len(self.memory_by_action.get(act_idx, [])) >= self.min_samples
                for act_idx in range(self.num_actions)
            ):
                sampled_actions = self._get_insufficient_sample_actions()
                selected_action = np.random.choice(sampled_actions)
            else:
                selected_action = random.choice(range(self.num_actions))
        else:
            estimates = self._calculate_estimates(context_features, actions_features)
            selected_action = max(estimates, key=estimates.get)
        return (
            actions[selected_action],
            {"selected_action": selected_action},
        )

    def learn(
        self,
        context: "Context",
        action: "Action",
        reward: "Reward",
        probability: float,
        selected_action: int,
    ) -> None:
        context_features = self._validate_features(context)
        action_features = self._validate_features(action)
        if selected_action not in self.memory_by_action:
            self.memory_by_action[selected_action] = []
        if selected_action not in self._models:
            # Initialize with a single estimator that will be replaced once fitted.
            self._models[selected_action] = [clone(self.base_estimator)]
        self.memory_by_action[selected_action].append((np.hstack([context_features, action_features]), reward))
        if len(self.memory_by_action[selected_action]) > self.replay_buffer:
            # Drop the oldest observation to keep the buffer bounded.
            self.memory_by_action[selected_action].pop(0)
        if len(self.memory_by_action[selected_action]) >= self.min_samples:
            X = np.vstack([x[0] for x in self.memory_by_action[selected_action]])
            y = np.hstack([x[1] for x in self.memory_by_action[selected_action]])
            n_bootstrap_samples = int(len(X) * self.bootstrap_ratio)
            models = []
            for idx in range(self.num_bootstraps):
                if self.use_importance_weights and len(self._models[selected_action]) == self.num_bootstraps:
                    # Bias resampling towards rows the previous ensemble member
                    # scored highly, sharpened by a low softmax temperature.
                    weights = self._models[selected_action][idx].predict(X)
                    weights = softmax(weights / 0.01, axis=0)
                    bootstrap_indices = np.random.choice(
                        range(len(X)), size=n_bootstrap_samples, replace=True, p=weights
                    )
                else:
                    bootstrap_indices = np.random.choice(range(len(X)), size=n_bootstrap_samples, replace=True)
                X_bootstrap = X[bootstrap_indices]
                y_bootstrap = y[bootstrap_indices]
                model = clone(self.base_estimator, safe=True)
                model.fit(X_bootstrap, y_bootstrap)
                models.append(model)
            self._models[selected_action] = models
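

# --- Hedged usage sketch (not part of the original gist) ---------------------
# The same assumed toy problem, driven through BootstrappedUCB. percentile=80
# makes the learner optimistic: it acts on the 80th percentile of the bootstrap
# ensemble's predictions rather than the mean, so arms whose models still
# disagree keep getting explored.
def _demo_bootstrapped_ucb(steps: int = 200) -> None:
    from sklearn.linear_model import Ridge

    learner = BootstrappedUCB(Ridge(), num_actions=2, num_bootstraps=5, percentile=80)
    actions = [(1.0, 0.0), (0.0, 1.0)]  # one-hot action features
    for _ in range(steps):
        context = (np.random.rand(),)
        chosen, info = learner.predict(context, actions)
        reward = (1.0 if info["selected_action"] == 1 else 0.2) + 0.1 * np.random.randn()
        learner.learn(context, chosen, reward, probability=1.0, selected_action=info["selected_action"])
    print("preferred arm:", learner.score(None, actions))


if __name__ == "__main__":
    _demo_epsilon_greedy()
    _demo_bootstrapped_ucb()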