Skip to content

Instantly share code, notes, and snippets.

@jxnl
Last active August 29, 2015 14:01
Show Gist options
  • Save jxnl/955f936ca034c90ba6f9 to your computer and use it in GitHub Desktop.
Save jxnl/955f936ca034c90ba6f9 to your computer and use it in GitHub Desktop.
"""
Author: Jason Liu
"""
import random
class Bandit(object):
    """An epsilon-greedy bandit that learns which arm pays out most often.

    Each arm is a Bernoulli trial with a hidden success probability. On every
    pull the bandit either explores (plays a uniformly random arm) or
    exploits (plays the arm with the best observed reward rate). ``epsilon``
    is the probability of exploring; it is multiplied by ``alpha`` after
    every pull, so the bandit grows greedier over time.

    Attributes:
        arms (list): Floats in [0, 1] that represent the true probabilities
        size (int): Number of arms the bandit has access to
        epsilon (float): In [0, 1]; current probability of exploring
        alpha (float): Multiplicative decay applied to epsilon after each pull
        rewards (list): Total reward obtained from each arm
        count (list): Number of times each arm was played
    """

    def __init__(self, arms, epsilon=0.80, alpha=0.80):
        """Create a bandit with no observations yet.

        Args:
            arms (list): True success probability of each arm
            epsilon (float): Initial probability of exploring
            alpha (float): Factor epsilon is multiplied by after each pull
        """
        self.arms = arms
        self.size = len(arms)
        self.alpha = alpha
        self.epsilon = epsilon
        self.rewards = [0.0] * self.size
        self.count = [0] * self.size

    def sample_arm(self):
        """Explore a random arm or exploit the best one seen so far."""
        if random.random() > self.epsilon:
            self.update(self.best_arm)
        else:
            self.update(random.randrange(self.size))

    @property
    def best_arm(self):
        """Return the index of the arm with the highest observed reward rate.

        Arms are ranked by rewards / plays rather than by raw reward totals:
        a total favours whichever arm happened to be played most, not the arm
        that succeeds most often. An arm that was never played has rate 0.0.
        Ties resolve to the lowest index.
        """
        rates = [
            float(reward) / plays if plays else 0.0
            for reward, plays in zip(self.rewards, self.count)
        ]
        return rates.index(max(rates))

    def update(self, arm):
        """Play an arm, record the outcome, and decay epsilon.

        Args:
            arm (int): Index of the arm you wish to play
        """
        # Bernoulli draw against the arm's true success probability.
        if random.random() < self.arms[arm]:
            self.rewards[arm] += 1
        # Every pull counts and makes the bandit greedier, success or not.
        self.count[arm] += 1
        self.epsilon *= self.alpha

    def summary(self):
        """Print the true arm probabilities and the index of the best arm."""
        t = """Bandit Arms : {self.arms}\nBest Arm at : {best_arm}"""
        print(t.format(self=self, best_arm=self.best_arm))
def main():
    """Demo: give a two-armed bandit 40 pulls, then print what it learned."""
    # Think of these as two coins with p(heads) = .8 and .5.
    coin_probabilities = [0.8, 0.50]
    bandit = Bandit(arms=coin_probabilities, epsilon=1, alpha=0.995)

    # You have 40 tries to find the best coin.
    pulls_left = 40
    while pulls_left > 0:
        bandit.sample_arm()
        pulls_left -= 1

    bandit.summary()
    print(bandit.count, bandit.epsilon)


if __name__ == "__main__":
    main()
import java.util.Random;
/**
 * Epsilon-greedy bandit that estimates the success probability of
 * k Bernoulli arms and reports which arm looks best.
 */
public class EpsilonGreedyBandit {
    /** True success probability of each arm. */
    private final double[] armProbabilities;
    /** Probability of exploiting the current best arm on a given pull. */
    private final float greediness;
    /** Per-arm count of successful pulls. */
    private int[] successes;
    /** Per-arm count of all pulls. */
    private int[] pulls;
    /** Per-arm observed success rate (successes / pulls). */
    private double[] estimates;
    private Random rng = new Random();

    /** Demo: sample five arms 1000 times and print per-arm statistics. */
    public static void main(String[] args) {
        double[] newArms = {0.1, 0.2, 0.8, 0.4, 0.5};
        EpsilonGreedyBandit sampler = new EpsilonGreedyBandit(newArms, 0.90F);
        sampler.sampleArms(1000);

        int arm = 0;
        for (double winRate : newArms) {
            System.out.println(String.format("Arm number %d has win rate %1.1f", arm, winRate));
            System.out.println(String.format("--- Sampled %d times\n--- Succeeded %d times", sampler.pulls[arm], sampler.successes[arm]));
            System.out.println("--- maximumLikelihoodEstimate:" + sampler.estimates[arm] + "\n\n");
            arm++;
        }
        System.out.println("The best arm is #" + sampler.bestArm());
    }

    /**
     * Create a bandit with no observations.
     *
     * @param arms true success probability of each arm
     * @param g    probability of exploiting the best arm on each pull
     */
    public EpsilonGreedyBandit(double[] arms, float g) {
        int armCount = arms.length;
        estimates = new double[armCount];
        successes = new int[armCount];
        pulls = new int[armCount];
        greediness = g;
        armProbabilities = arms;
    }

    /**
     * Pull arms repeatedly.
     * Each pull either exploits the current best arm or explores an arm
     * chosen uniformly at random.
     *
     * @param numberOfSamples number of times we should sample the arms
     */
    public void sampleArms(int numberOfSamples) {
        int remaining = numberOfSamples;
        while (remaining-- > 0) {
            boolean exploit = rng.nextFloat() < greediness;
            int chosen = exploit ? bestArm() : rng.nextInt(armProbabilities.length);
            update(chosen);
        }
    }

    /**
     * Find the arm with the highest observed success rate.
     * The scan starts from a randomly chosen arm and only moves to a
     * strictly better one.
     *
     * @return index of the arm with the highest estimate
     */
    public int bestArm() {
        int leader = rng.nextInt(armProbabilities.length);
        double leaderEstimate = estimates[leader];
        for (int arm = 0; arm < successes.length; arm++) {
            if (estimates[arm] > leaderEstimate) {
                leader = arm;
                leaderEstimate = estimates[arm];
            }
        }
        return leader;
    }

    /**
     * Pull one arm and fold the outcome into its statistics.
     *
     * @param targetArm index of the arm to pull
     */
    private void update(int targetArm) {
        pulls[targetArm] += 1;
        boolean success = rng.nextFloat() < armProbabilities[targetArm];
        if (success) {
            successes[targetArm] += 1;
        }
        estimates[targetArm] = (double) successes[targetArm] / pulls[targetArm];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment