Skip to content

Instantly share code, notes, and snippets.

@jxnl
Created October 17, 2014 00:37
Show Gist options
  • Save jxnl/c3fba177efed9593d5c1 to your computer and use it in GitHub Desktop.
Save jxnl/c3fba177efed9593d5c1 to your computer and use it in GitHub Desktop.
Multivariate testing
"""
Author: Jason Liu
"""
import random
class BernoulliArm(object):
    """An arm that returns either 0 or 1 as its reward.

    Attributes:
        name: Label used to identify the arm.
        p: Probability that a single draw is rewarded.
    """

    def __init__(self, name, probability=None):
        """Create an arm.

        Args:
            name: Label used to identify the arm.
            probability: Chance of a reward on each draw. When omitted
                (``None``) a value is picked with ``random.random()``.
        """
        self.name = name
        # Compare against None rather than relying on truthiness, so that an
        # explicit probability of 0 is honoured instead of being replaced by
        # a random value.
        if probability is not None:
            self.p = probability
        else:
            self.p = random.random()

    def draw(self):
        """Play the arm once; return True (reward) with probability ``p``."""
        return random.random() < self.p
class GreedyBandit(object):
    """An epsilon-greedy bandit that searches among arms of unknown payout.

    Attributes:
        arms (list): Arm objects; each must expose ``name``, ``p`` and a
            ``draw()`` method returning the reward of one play.
        size (int): Number of arms.
        epsilon (float): Between [0-1]; chance of exploring a random arm
            instead of exploiting the best one.
        alpha (float): Factor ``epsilon`` is multiplied by after every play.
        rewards (list): Total reward collected from each arm.
        count (list): Number of times each arm was played.
    """

    def __init__(self, arms, epsilon=0.80, alpha=0.80):
        """Store the arms and reset the per-arm statistics."""
        self.arms = arms
        self.size = len(arms)
        self.alpha = alpha
        self.epsilon = epsilon
        self.rewards = [0.0] * self.size
        self.count = [0] * self.size

    def sample(self):
        """Explore a random arm or exploit the best one seen so far."""
        if random.random() > self.epsilon:
            self.update(self.best_arm)
        else:
            random_arm = random.randint(0, self.size - 1)
            self.update(random_arm)

    def _estimates(self):
        """Return the observed mean reward per arm (0.0 if never played)."""
        return [float(reward) / plays if plays else 0.0
                for reward, plays in zip(self.rewards, self.count)]

    @property
    def best_arm(self):
        """Return the index of the arm with the highest mean reward.

        Arms are ranked by reward per play, not by total reward, so an arm
        is not favoured merely because it was played more often. Ties go to
        the lowest index.
        """
        estimates = self._estimates()
        return estimates.index(max(estimates))

    def update(self, idx):
        """Play arm ``idx`` once, record the outcome and decay epsilon."""
        self.count[idx] += 1
        self.rewards[idx] += self.arms[idx].draw()
        # Annealing heuristic: explore less as more plays are made.
        self.epsilon *= self.alpha

    def summary(self):
        """Print a summary of results."""
        t1 = "Available Arms : {}"
        t2 = "Best Arm Was : {}"
        t3 = "Samples Where : {}"
        t4 = "Rewards Where : {}"
        t5 = "Total rewards : {}"
        t6 = "Expected 50/50 : {}"
        t7 = "Probabilities : {}"
        t8 = "Sampled Probs : {}"
        true_probs = [arm.p for arm in self.arms]
        print(t1.format([arm.name for arm in self.arms]))
        print(t2.format(self.arms[self.best_arm].name))
        print(t3.format(list(self.count)))
        print(t4.format(list(self.rewards)))
        print(t5.format(sum(self.rewards)))
        # NOTE(review): the label suggests the expected reward of an even
        # split, but the value is each arm's reward total weighted by its
        # true probability. Computation kept as-is; confirm the intent.
        print(t6.format(sum(reward * prob
                            for reward, prob in zip(self.rewards, true_probs))))
        print(t7.format(true_probs))
        # Real lists are formatted (not lazy iterators) and unplayed arms are
        # reported as 0.0 rather than dividing by a zero count.
        print(t8.format(self._estimates()))
def main():
    """Run a 40-play demo on two randomly weighted arms and print a summary."""
    buttons = [BernoulliArm(label) for label in ("red button", "blue button")]
    bandit = GreedyBandit(buttons, epsilon=1, alpha=0.995)
    # You have 40 tries to find the best coin.
    tries_left = 40
    while tries_left > 0:
        bandit.sample()
        tries_left -= 1
    bandit.summary()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment