Created
October 17, 2014 00:37
-
-
Save jxnl/c3fba177efed9593d5c1 to your computer and use it in GitHub Desktop.
Multivariate testing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Author: Jason Liu | |
| """ | |
| import random | |
class BernoulliArm(object):
    """An arm that returns a reward of either 0 or 1 (False or True).

    Attributes:
        name (str): Label used to identify the arm.
        p (float): Probability that a single draw pays out.
    """

    def __init__(self, name, probability=None):
        """Create an arm.

        Args:
            name: Label for the arm.
            probability: Payout probability. When omitted (None), a uniform
                random value from ``random.random()`` is used instead.
        """
        self.name = name
        # Compare against None rather than truthiness so that an explicit
        # probability of 0.0 is honoured instead of being replaced by a
        # random value.
        if probability is not None:
            self.p = probability
        else:
            self.p = random.random()

    def draw(self):
        """Play the arm once; return True with probability ``p``."""
        return random.random() < self.p
class GreedyBandit(object):
    """An epsilon-greedy bandit that searches arms of unknown payout.

    Attributes:
        arms (list): Arm objects; each must expose ``name``, ``p`` and a
            ``draw()`` method returning the reward of one play.
        size (int): Number of arms.
        epsilon (float): Exploration threshold. ``sample`` plays a random arm
            whenever a uniform draw is not greater than this value.
        alpha (float): Factor ``epsilon`` is multiplied by after every play.
        rewards (list): Cumulative reward collected from each arm.
        count (list): Number of times each arm has been played.
    """

    def __init__(self, arms, epsilon=0.80, alpha=0.80):
        self.arms = arms
        self.size = len(arms)
        self.alpha = alpha
        self.epsilon = epsilon
        self.rewards = [0.0] * self.size
        self.count = [0] * self.size

    def sample(self):
        """Explore a random arm or exploit the current best one."""
        if random.random() > self.epsilon:
            self.update(self.best_arm)
        else:
            random_arm = random.randint(0, self.size - 1)
            self.update(random_arm)

    @property
    def best_arm(self):
        """Return the index of the arm with the highest cumulative reward.

        Ties resolve to the lowest index.
        """
        # NOTE(review): this ranks arms by total reward, not by mean reward
        # per play, so a frequently played arm can outrank a better one.
        # Behaviour kept as-is; confirm whether that is intended.
        return self.rewards.index(max(self.rewards))

    def update(self, idx):
        """Play arm ``idx`` once, record the outcome and decay epsilon."""
        self.count[idx] += 1
        self.rewards[idx] += self.arms[idx].draw()
        # Annealing heuristic: shrink epsilon after every play.
        self.epsilon *= self.alpha

    def _sampled_probs(self):
        """Return the observed reward rate per arm (0.0 if never played)."""
        return [
            reward / plays if plays else 0.0
            for reward, plays in zip(self.rewards, self.count)
        ]

    def summary(self):
        """Print a summary of results."""
        true_probs = [arm.p for arm in self.arms]
        # NOTE(review): the original computes sum(reward_i * p_i) under the
        # "Expected 50/50" label; the formula is preserved, intent unverified.
        weighted_rewards = sum(
            reward * prob for reward, prob in zip(self.rewards, true_probs)
        )

        print("Available Arms : {}".format([arm.name for arm in self.arms]))
        print("Best Arm Was : {}".format(self.arms[self.best_arm].name))
        print("Samples Where : {}".format(list(self.count)))
        print("Rewards Where : {}".format(list(self.rewards)))
        print("Total rewards : {}".format(sum(self.rewards)))
        print("Expected 50/50 : {}".format(weighted_rewards))
        print("Probabilities : {}".format(true_probs))
        # Built as a list (not a lazy map) so the values are actually shown,
        # and guarded so an arm that was never played cannot divide by zero.
        print("Sampled Probs : {}".format(self._sampled_probs()))
def main():
    """Run a 40-play epsilon-greedy experiment on two arms and print it."""
    button_names = ("red button", "blue button")
    # Each arm is created without a probability, so it picks its own.
    arms = [BernoulliArm(label) for label in button_names]
    bandit = GreedyBandit(arms, alpha=0.995, epsilon=1)

    # You have 40 tries to find the best arm.
    remaining = 40
    while remaining > 0:
        bandit.sample()
        remaining -= 1

    bandit.summary()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment