Last active
August 29, 2015 14:01
-
-
Save jxnl/955f936ca034c90ba6f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Author: Jason Liu | |
| """ | |
| import random | |
class Bandit(object):
    """An epsilon-greedy bandit that searches for the best of several arms.

    Each arm pays a reward of 1 with a fixed probability that the search does
    not look at directly. On every play the bandit either explores a uniformly
    random arm or exploits the arm with the best observed success rate, and
    ``epsilon`` is multiplied by ``alpha`` after every play.

    Attributes:
        arms (list): Floats in [0, 1], the true success probability per arm.
        size (int): Number of arms the bandit has access to.
        epsilon (float): Threshold that controls how often a random arm is
            explored; higher means more exploration.
        alpha (float): Factor ``epsilon`` is multiplied by after each play.
        rewards (list): Total reward collected from each arm.
        count (list): Number of times each arm has been played.
    """

    def __init__(self, arms, epsilon=0.80, alpha=0.80):
        """Create a bandit over ``arms`` with no plays recorded yet.

        Args:
            arms (list): True success probability of each arm.
            epsilon (float): Initial exploration threshold.
            alpha (float): Multiplicative decay applied to ``epsilon``.
        """
        self.arms = arms
        self.size = len(arms)
        self.alpha = alpha
        self.epsilon = epsilon
        self.rewards = [0.0] * self.size
        self.count = [0] * self.size

    def sample_arm(self):
        """Explore a random arm or exploit the best arm seen so far."""
        if random.random() > self.epsilon:
            self.update(self.best_arm)
        else:
            self.update(random.randrange(self.size))

    @property
    def best_arm(self):
        """Return the index of the arm with the highest observed success rate.

        Arms are compared by ``rewards / count`` rather than by raw reward
        totals, so an arm is not favoured merely for having been played more
        often. An arm that has never been played is scored 0.0, and ties go
        to the lowest index.

        Raises:
            ValueError: If the bandit has no arms.
        """
        def success_rate(arm):
            plays = self.count[arm]
            return self.rewards[arm] / plays if plays else 0.0

        return max(range(self.size), key=success_rate)

    def update(self, arm):
        """Play ``arm`` once, record the outcome and decay ``epsilon``.

        Args:
            arm (int): Index of the arm to play.
        """
        # The play succeeds with the arm's true probability.
        if random.random() < self.arms[arm]:
            self.rewards[arm] += 1
        self.count[arm] += 1
        self.epsilon *= self.alpha

    def summary(self):
        """Print the true arm probabilities and the current best arm."""
        t = """Bandit Arms : {self.arms}\nBest Arm at : {best_arm}"""
        print(t.format(self=self, best_arm=self.best_arm))
def main():
    """Run a small demo: play two coins 40 times and print the outcome."""
    # Think of the arms as two coins with p(heads) = .8 and .5.
    coin_odds = [0.8, 0.50]
    bandit = Bandit(arms=coin_odds, epsilon=1, alpha=0.995)

    # Only 40 tries are allowed to find the best coin.
    tries_left = 40
    while tries_left:
        bandit.sample_arm()
        tries_left -= 1

    bandit.summary()
    print(bandit.count, bandit.epsilon)


if __name__ == "__main__":
    main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import java.util.Random; | |
| /** | |
| * Implements the Epsilon Greedy Bandit algorithm for finding | |
| * optimal probabilities of success given k binomial distributions. | |
| */ | |
public class EpsilonGreedyBandit {
    // True success probability of each arm; a private copy of the caller's array.
    private final double[] armProbabilities;
    // Probability of exploiting the current best arm on any single draw.
    private final float greedyness;
    // Per-arm count of successful draws.
    private int[] numberOfSuccesses;
    // Per-arm count of draws.
    private int[] numberOfSamples;
    // Per-arm successes / samples; stays 0.0 until the arm is first drawn.
    private double[] maximumLikelihoodEstimate;
    private Random randomGenerator = new Random();

    public static void main(String[] args){
        double[] newArms = {0.1, 0.2, 0.8, 0.4, 0.5};
        EpsilonGreedyBandit sampler = new EpsilonGreedyBandit(newArms, 0.90F);
        sampler.sampleArms(1000);
        for (int i=0; i<newArms.length; i++){
            System.out.println(String.format("Arm number %d has win rate %1.1f", i, newArms[i]));
            System.out.println(String.format("--- Sampled %d times\n--- Succeeded %d times", sampler.numberOfSamples[i], sampler.numberOfSuccesses[i]));
            System.out.println("--- maximumLikelihoodEstimate:" + sampler.maximumLikelihoodEstimate[i] + "\n\n");
        }
        System.out.println("The best arm is #"+sampler.bestArm());
    }

    /**
     * Create a bandit with no draws recorded yet.
     *
     * @param arms true success probability of each arm
     * @param g    probability of exploiting the best arm instead of exploring
     */
    public EpsilonGreedyBandit(double[] arms, float g){
        maximumLikelihoodEstimate = new double[arms.length];
        numberOfSuccesses = new int[arms.length];
        numberOfSamples = new int[arms.length];
        greedyness = g;
        // Copy the array so later changes by the caller cannot alter the
        // probabilities this bandit simulates.
        armProbabilities = arms.clone();
    }

    /**
     * Sample arms.
     * Either uniformly explore arms or exploit the best arm
     *
     * @param draws number of times we should sample the arms
     */
    public void sampleArms(int draws){
        for (int i=0; i < draws; i++){
            if (randomGenerator.nextFloat() < greedyness){
                // Exploitation
                update(bestArm());
            } else {
                // Exploration
                int target = randomGenerator.nextInt(armProbabilities.length);
                update(target);
            }
        }
    }

    /**
     * Determine the best arm.
     * Based off of the maximum likelihood estimate for a binomial distribution.
     * The scan starts from a randomly chosen arm and only moves to an arm with a
     * strictly higher estimate, so the random starting arm is kept whenever no
     * other arm beats it.
     *
     * @return index of arm with highest maximumLikelihoodEstimate
     */
    public int bestArm(){
        int bestIndex = randomGenerator.nextInt(armProbabilities.length);
        for (int i=0; i<maximumLikelihoodEstimate.length; i++){
            if (maximumLikelihoodEstimate[i] > maximumLikelihoodEstimate[bestIndex]){
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Update the selected arm.
     * Increment numberOfSuccesses if the target arm returns a successful event,
     * then refresh that arm's maximum likelihood estimate.
     *
     * @param targetArm Index of arm to be updated
     */
    private void update(int targetArm){
        numberOfSamples[targetArm]++;
        if (randomGenerator.nextFloat() < armProbabilities[targetArm]){
            numberOfSuccesses[targetArm]++;
        }
        maximumLikelihoodEstimate[targetArm] = (double)numberOfSuccesses[targetArm] / numberOfSamples[targetArm];
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment