Last active: September 28, 2021 13:43
-
-
Save comckay/74a98a30911e03a47ef1340e15e3bc1d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import numpy as np | |
class UCB1:
    """Select among named models with the UCB1 multi-armed bandit rule.

    Each call to :meth:`select_model` returns one model name and counts it as
    a try; each call to :meth:`reward_model` credits a model with a reward.
    Untried models are always selected first (in list order); afterwards the
    model maximising ``mean_reward + sqrt(2 * ln(total_tries) / model_tries)``
    is chosen.

    Attributes:
        models: Model names, in the order given to the constructor.
        model_successes: Cumulative reward credited to each model
            (aligned with ``models``).
        model_tries: Number of times each model has been selected
            (aligned with ``models``).
    """

    def __init__(self, models: List[str]):
        """Create a bandit over ``models``.

        Args:
            models: Distinct model names. The list is copied, so later
                changes to the caller's list do not affect the bandit.

        Raises:
            ValueError: If ``models`` contains duplicate names. A duplicate's
                slot could never be tried (lookups resolve to one position),
                which would stall the bandit in its exploration phase forever.
        """
        self.models = list(models)
        # name -> position; O(1) lookups instead of list.index on every call.
        self._model_index = {name: i for i, name in enumerate(self.models)}
        if len(self._model_index) != len(self.models):
            raise ValueError(f"duplicate model names in {self.models!r}")
        n_models = len(self.models)
        self.model_successes = np.zeros(n_models)
        self.model_tries = np.zeros(n_models)

    def _index_of(self, model: str) -> int:
        """Return the position of ``model``; raise ValueError if unknown."""
        try:
            return self._model_index[model]
        except KeyError:
            raise ValueError(f"model {model} not recognized") from None

    def _increment_model_tries(self, model: str) -> None:
        """Count one more try for ``model``."""
        self.model_tries[self._index_of(model)] += 1

    def _get_model_with_max_ucb(self) -> str:
        """Return the model with the highest UCB1 score.

        Untried models score +inf (maximally optimistic), so the first untried
        model in list order wins; among finite scores, ties also go to the
        earliest model. Scores are only computed for tried models, avoiding
        0/0 and log(0) warnings.

        Raises:
            ValueError: If there are no models.
        """
        if not self.models:
            raise ValueError("no models to select from")
        tried = self.model_tries > 0
        scores = np.full(len(self.models), np.inf)
        if tried.any():
            tries = self.model_tries[tried]
            means = self.model_successes[tried] / tries
            bonus = np.sqrt(2 * np.log(self.model_tries.sum()) / tries)
            scores[tried] = means + bonus
        return self.models[int(np.nanargmax(scores))]

    def select_model(self) -> str:
        """Pick the next model to try and record the try.

        Every model is tried once, in list order, before UCB1 scores are
        compared (untried models have an infinite score).

        Returns:
            The selected model name.

        Raises:
            ValueError: If the bandit has no models.
        """
        model = self._get_model_with_max_ucb()
        self._increment_model_tries(model)
        return model

    def reward_model(self, model: str, reward: float = 1.0) -> None:
        """Credit ``model`` with ``reward`` (default 1.0, i.e. one success).

        UCB1's standard regret guarantee assumes rewards lie in [0, 1].

        Args:
            model: A name passed to the constructor.
            reward: Amount added to the model's cumulative reward.

        Raises:
            ValueError: If ``model`` is not one of the bandit's models.
        """
        self.model_successes[self._index_of(model)] += reward
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.