Last active
July 6, 2022 07:35
-
-
Save morrisalp/5aa3788b090990570a5cb83d5a45fbc8 to your computer and use it in GitHub Desktop.
Demo for illustrating multi-armed bandit algorithms. The methods of the class Agent are placeholders and may be changed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class Agent:
    """Bandit-playing agent.

    The strategy here is a placeholder: every pull is chosen uniformly at
    random, and observed rewards are ignored apart from advancing the clock.
    Subclasses / edits are expected to replace these methods with a real
    bandit algorithm (epsilon-greedy, UCB, Thompson sampling, ...).
    """

    def __init__(self, n_machines=10):
        # Number of rewards observed so far (the current time step).
        self.t = 0
        self.n_machines = n_machines

    def pick_machine(self):
        """Return the index of the machine to pull next (uniform placeholder)."""
        return np.random.randint(self.n_machines)

    def get_reward(self, reward, machine_index):
        """Observe `reward` from pulling `machine_index`; placeholder only ticks the clock."""
        self.t += 1
class Environment:
    """A stochastic multi-armed bandit with Bernoulli machine rewards.

    Machine i pays out 1 with probability p_i and 0 otherwise; the p_i's
    are drawn once, at construction time, from Uniform([0, 1]).
    """

    def __init__(self, n_machines=10):
        self.n_machines = n_machines
        # One Bernoulli success probability per machine.
        self.params = np.random.uniform(size=n_machines)

    def _interact(self, machine_index):
        """Pull machine `machine_index` and return its sampled reward (0 or 1).

        Raises ValueError for an out-of-range index.  (Was an `assert`,
        which is stripped under `python -O`.)
        """
        if not 0 <= machine_index < self.n_machines:
            raise ValueError('Bad machine index')
        p = self.params[machine_index]
        # Bernoulli(p) sample, i.e. Binomial with n=1.
        return np.random.binomial(n=1, p=p)

    def run(self, time_steps=250):
        """Simulate `time_steps` agent/environment interactions; return total reward.

        Bug fix: the agent is now constructed with this environment's
        machine count instead of the Agent default, so a non-default
        `n_machines` no longer mismatches the agent's action space.
        """
        total_reward = 0
        agent = Agent(self.n_machines)
        for _ in range(time_steps):
            machine_index = agent.pick_machine()
            reward = self._interact(machine_index)
            agent.get_reward(reward, machine_index)
            total_reward += reward
        return total_reward
if __name__ == "__main__":
    # Fixed seed so the demo prints the same total reward every run.
    np.random.seed(0)
    env = Environment()
    print(env.run())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment