Created
September 29, 2018 19:21
-
-
Save conormm/277f0290478aecb43505da7c7a5959bc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ThompsonSampler(BaseSampler):
    """Bandit sampler that chooses a variant by Thompson sampling.

    Each round draws one value per variant from a Beta(a, b) distribution
    and plays the variant with the largest draw; the played variant's
    (a, b) pair is then updated with the observed reward.

    The state this class reads and writes (``a``, ``b``, ``variants``,
    ``reward``, ``i``, ``thetas``, ``thetaregret``, ``ad_i``, ``r_i``) is
    not initialised here -- presumably ``BaseSampler.__init__`` sets it up
    and the caller assigns ``reward`` between ``choose_k`` and ``update``;
    confirm against ``BaseSampler``.
    """

    def __init__(self, env):
        """Delegate all setup to ``BaseSampler`` with the given ``env``."""
        super().__init__(env)

    def choose_k(self):
        """Sample the posterior and return the variant to play this round.

        Stores the per-variant draws in ``self.theta`` and the chosen
        variant in ``self.k``.

        Returns:
            The entry of ``self.variants`` whose sampled value is largest.
        """
        # Sample from the posterior (this is the Thompson sampling step).
        # Variants with more uncertainty have wider posteriors, so they can
        # still produce the largest draw and get selected -- this is what
        # drives exploration.
        self.theta = np.random.beta(self.a, self.b)
        # Select the variant with the highest sampled payout probability.
        self.k = self.variants[np.argmax(self.theta)]
        return self.k

    def update(self):
        """Fold ``self.reward`` into the posterior and log this round.

        Must run after ``choose_k`` (it reads ``self.k`` and ``self.theta``).
        Advances the round counter ``self.i`` by one.
        """
        # NOTE(review): ``self.k`` is used as an index into ``a``, ``b`` and
        # ``theta`` -- assumes the entries of ``self.variants`` are the
        # integer positions 0..n-1; confirm against ``BaseSampler``.

        # Posterior update: (a, b) = (a, b) + (r, 1 - r).
        self.a[self.k] += self.reward
        # For a 0/1 reward, b only grows on a miss: 1 - 0 = 1, 1 - 1 = 0.
        self.b[self.k] += 1 - self.reward

        # Record the sampled value of the variant that was played.
        self.thetas[self.i] = self.theta[self.k]
        # NOTE(review): the baseline here is the max over the recorded
        # history ``self.thetas`` (including the value just stored), not
        # over the current draw ``self.theta`` -- confirm this is the
        # intended definition of regret.
        self.thetaregret[self.i] = np.max(self.thetas) - self.theta[self.k]

        # Per-round log of which variant was played and what it paid.
        self.ad_i[self.i] = self.k
        self.r_i[self.i] = self.reward
        self.i += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment