Created
December 28, 2012 16:18
-
-
Save dehowell/4399270 to your computer and use it in GitHub Desktop.
Multi-armed bandit simulation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import csv | |
import math | |
import random | |
import sys | |
import numpy | |
# How many levers (arms) the simulated bandit exposes to every strategy.
NUMBER_OF_LEVERS = 10
class BaseLever(object):
    '''Abstract lever that pays out with the probability given by mean().'''

    def mean(self):
        '''Return the current win probability; subclasses must override.

        Raises:
            NotImplementedError: always, in this base class.
        '''
        # This used to fall through and return None, which made won()
        # compare a float against None instead of failing loudly.
        raise NotImplementedError('subclasses must implement mean()')

    def won(self):
        '''Pull the lever once and return True if the pull pays out.'''
        # mean() is evaluated exactly once per pull; a subclass may change
        # its own state inside it.
        return random.random() <= self.mean()
class FixedLever(BaseLever):
    '''Lever whose win probability is drawn at random on construction.'''

    def __init__(self):
        # Draw the payout probability once, from random.random().
        initial_probability = random.random()
        self.m = initial_probability

    def mean(self):
        '''Return the stored win probability.'''
        return self.m
class FickleLever(FixedLever):
    '''Lever that occasionally replaces its win probability with a new draw.'''

    def mean(self):
        '''Return the win probability, re-drawing it first on roughly 10% of calls.'''
        # The first draw decides whether to re-roll; the second (only when
        # re-rolling) supplies the new probability.
        should_redraw = random.random() < .1
        if should_redraw:
            self.m = random.random()
        return self.m
# The shared levers that every strategy pulls from.
# Uses `range` instead of the Python 2-only `xrange` so the script also runs
# on Python 3; the loop variable is unused, hence `_`.
LEVERS = [FickleLever() for _ in range(NUMBER_OF_LEVERS)]
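# NOTE: disabled regret calculation, kept for reference. It depends on a
# MEANS sequence that is not defined in the visible code, so it cannot be
# re-enabled as-is.
# def regret(rewards, round_number):
#     optimal = math.ceil(max(MEANS) * round_number)
#     return optimal - rewards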
class Strategy(object):
    '''Base class for lever-selection strategies.

    Keeps per-lever loss/win counts in ``memory`` and the observed win ratio
    in ``ratios``. Subclasses supply the policy by overriding
    ``choose_lever``.
    '''

    def __init__(self):
        # memory[i] == [losses, wins] recorded for lever i.
        self.memory = numpy.zeros((NUMBER_OF_LEVERS, 2))
        # ratios[i] == wins / pulls for lever i; stays 0 until first pulled.
        self.ratios = numpy.zeros(NUMBER_OF_LEVERS)

    def choose_lever(self, round_number=None):
        '''Return the index of the lever to pull for ``round_number``.

        ``play`` always passes the round number, so the base signature
        accepts it too; the base class itself has no policy.

        Raises:
            NotImplementedError: always, in this base class.
        '''
        raise NotImplementedError('subclasses must implement choose_lever()')

    def random_lever(self):
        '''Return a lever index chosen uniformly at random.'''
        return random.randint(0, NUMBER_OF_LEVERS - 1)

    def best_lever(self):
        '''Return the index of the lever with the highest observed win ratio.

        Ties go to the highest index, which is what the earlier
        sort-and-take-last implementation produced.
        '''
        # max() keeps the first maximal item it sees, so scanning the indices
        # in reverse yields the highest index among equal ratios.
        return max(reversed(range(NUMBER_OF_LEVERS)),
                   key=lambda k: self.ratios[k])

    def play(self, round_number):
        '''Pull one lever, record the outcome, and return total wins so far.'''
        lever = self.choose_lever(round_number)
        # Column 1 counts wins, column 0 counts losses.
        outcome_column = 1 if LEVERS[lever].won() else 0
        self.memory[lever, outcome_column] += 1
        # The lever has been pulled at least once here, so the sum is >= 1.
        self.ratios[lever] = self.memory[lever, 1] / self.memory[lever].sum()
        return self.rewards()

    def rewards(self):
        '''Return the total number of wins recorded across all levers.'''
        return self.memory[:, 1].sum()
class RoundRobin(Strategy):
    '''Strategy that cycles through the levers in index order.'''

    def choose_lever(self, round_number):
        '''Return round_number modulo the number of levers.'''
        _, lever_index = divmod(round_number, NUMBER_OF_LEVERS)
        return lever_index
class EpsilonFirst(Strategy):
    '''Pick random levers while round_number <= explore, then the best one.'''

    def __init__(self, explore):
        super(EpsilonFirst, self).__init__()
        # Last round number (inclusive) on which a random lever is chosen.
        self.explore = explore

    def choose_lever(self, round_number):
        '''Return a random lever during exploration, else the best lever.'''
        still_exploring = round_number <= self.explore
        return self.random_lever() if still_exploring else self.best_lever()
class EpsilonGreedy(Strategy):
    '''Pick a random lever with probability ~epsilon, else the best lever.'''

    def __init__(self, epsilon):
        super(EpsilonGreedy, self).__init__()
        # Threshold compared against a fresh random.random() draw each round.
        self.epsilon = epsilon

    def choose_lever(self, round_number):
        '''Return the lever for this round; round_number itself is unused.'''
        draw = random.random()
        if draw > self.epsilon:
            return self.best_lever()
        return self.random_lever()
# Strategy instances under comparison, keyed by the label that is written to
# the 'strategy' column of the CSV output.
STRATEGIES = dict([
    ('roundrobin', RoundRobin()),
    ('epsilonfirst-10', EpsilonFirst(10)),
    ('epsilonfirst-100', EpsilonFirst(100)),
    ('epsilongreedy-0.1', EpsilonGreedy(0.1)),
    ('epsilongreedy-0.01', EpsilonGreedy(0.01)),
])
# Play every strategy once per round and stream one CSV row per
# (round, strategy) with that strategy's cumulative reward.
writer = csv.writer(sys.stdout)
writer.writerow(['round', 'strategy', 'rewards'])
# Rounds are numbered from 1. `range` replaces the Python 2-only `xrange` so
# the script also runs on Python 3.
for n in range(1, 3000 + 1):
    # items() yields label and instance together, avoiding a lookup per round.
    for strategy, player in STRATEGIES.items():
        rewards = player.play(n)
        writer.writerow([n, strategy, rewards])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment