import numpy as np
import pandas as pd


class ThompsonSampler:
    """Thompson Sampling using a Beta distribution associated with each option.

    The Beta distribution for an option is updated whenever a reward for that
    option is observed.
    """

    def __init__(self, env, n_learning=0):
        # boilerplate data storage
        self.env = env
        self.n_learning = n_learning
        self.options = env.options
        self.n_trials = env.n_trials
        self.payouts = env.payouts
        self.option_i = np.zeros(env.n_trials)
        self.r_i = np.zeros(env.n_trials)
        self.thetas = np.zeros(self.n_trials)
        self.data = None
        self.reward = 0
        self.total_reward = 0
        self.option = 0
        self.trial = 0
        # parameters of the Beta distribution, one (alpha, beta) pair per option
        self.alpha = np.ones(env.n_options)
        self.beta = np.ones(env.n_options)
        # estimated payout rates
        self.theta = np.zeros(env.n_options)

    def choose_option(self):
        """Select the option to be presented."""
        # sample from the posterior (this is the Thompson Sampling step);
        # options with greater uncertainty can still produce the largest draw,
        # which is what drives exploration
        self.theta = np.random.beta(self.alpha, self.beta)
        # during the initial learning phase, pick an option at random;
        # afterwards, select the option with the highest sampled payout probability
        if self.trial < self.n_learning:
            self.option = np.random.choice(self.options)
        else:
            self.option = self.options[np.argmax(self.theta)]
        return self.option

    def update_model(self):
        """Update the parameters of each option's distribution with the observed
        successes and failures.
        """
        # update rule: (a, b) <- (a, b) + (r, 1 - r),
        # where a, b are the alpha, beta parameters of the option's Beta distribution
        self.alpha[self.option] += self.reward
        # i.e. only increment b on a miss: 1 - 0 = 1, 1 - 1 = 0
        self.beta[self.option] += 1 - self.reward
        # store the option presented, its sampled theta, and the reward for each trial
        self.thetas[self.trial] = self.theta[self.option]
        self.option_i[self.trial] = self.option
        self.r_i[self.trial] = self.reward
        self.trial += 1

    def collect_data(self):
        self.data = pd.DataFrame(dict(option=self.option_i, reward=self.r_i))

    def __str__(self):
        return "ThompsonSampler"