Created
September 29, 2018 19:19
-
-
Save conormm/a75ea71b2e54b6ad30e48586f8bcea87 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class eGreedy(BaseSampler):
    """Epsilon-greedy sampler for a multi-armed bandit.

    For the first ``n_learning`` trials an arm is drawn at random; after that
    the arm with the highest current estimate ``theta`` is played, except on
    trials flagged for exploration, where a random arm is drawn instead.

    All state used here (``variants``, ``theta``, ``a``, ``b``, ``ep``,
    ``exploit``, ``i``, ``k``, ``reward`` and the per-trial history arrays)
    is expected to be created by ``BaseSampler``.
    """

    def __init__(self, env, n_learning, e):
        """Forward the configuration to ``BaseSampler`` unchanged.

        Args:
            env: Bandit environment, handed straight to ``BaseSampler``.
            n_learning: Length of the initial random warm-up phase
                (presumably stored as ``self.n_learning`` by the base
                class -- confirm).
            e: Exploration setting (presumably the epsilon from which the
                base class derives ``self.exploit`` -- confirm).
        """
        # NOTE(review): arguments are passed positionally; verify they line
        # up with the parameter order of BaseSampler.__init__.
        super().__init__(env, n_learning, e)

    def choose_k(self):
        """Select the arm for trial ``self.i`` and store it in ``self.k``.

        Returns:
            The selected arm (the same value that is stored in ``self.k``).
        """
        # Warm-up: a random arm for the first n_learning trials, afterwards
        # the arm with the highest estimated payout.
        if self.i < self.n_learning:
            self.k = np.random.choice(self.variants)
        else:
            self.k = np.argmax(self.theta)

        # Exploration override: when this trial's entry in `ep` exceeds the
        # `exploit` threshold, replace the choice above with a random arm.
        # Presumably `ep` holds uniform(0, 1) draws and `exploit` is 1 - e,
        # which makes this branch fire with probability e -- confirm in
        # BaseSampler. The draw order is kept as in the original so seeded
        # runs select the same arms.
        if self.ep[self.i] > self.exploit:
            self.k = np.random.choice(self.variants)

        return self.k

    def update(self):
        """Fold the latest reward into the estimates and log the trial.

        Reads ``self.k`` (the arm just played) and ``self.reward``, refreshes
        the per-arm estimates, writes the per-trial history at index
        ``self.i`` and then advances ``self.i`` by one.
        """
        arm = self.k

        # a accumulates reward and b counts plays per arm; their ratio is the
        # estimated payout of each arm.
        self.a[arm] += self.reward
        self.b[arm] += 1
        self.theta = self.a / self.b

        self.thetas[self.i] = self.theta[arm]
        # Regret of this choice: gap between the best *current* estimate and
        # the estimate of the arm that was played. The previous code took the
        # max over `thetas` (the per-trial history), which compared against
        # stale past estimates and not-yet-written slots.
        self.thetaregret[self.i] = np.max(self.theta) - self.theta[arm]

        self.ad_i[self.i] = arm
        self.r_i[self.i] = self.reward
        self.i += 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment