Skip to content

Instantly share code, notes, and snippets.

@morrisalp
Last active July 6, 2022 07:35
Show Gist options
  • Save morrisalp/5aa3788b090990570a5cb83d5a45fbc8 to your computer and use it in GitHub Desktop.
Demo for illustrating multi-armed bandit algorithms. The methods of the class Agent are placeholders and may be changed.
import numpy as np
class Agent:
    """Baseline bandit agent: picks arms uniformly at random.

    The methods are placeholders meant to be overridden by smarter
    bandit strategies (epsilon-greedy, UCB, Thompson sampling, ...).
    """

    def __init__(self, n_machines=10):
        # Number of completed interactions with the environment so far.
        self.t = 0
        self.n_machines = n_machines

    def pick_machine(self):
        """Return the index of the machine to pull next (uniform random)."""
        return np.random.randint(self.n_machines)

    def get_reward(self, reward, machine_index):
        """Record the reward observed for the pulled machine.

        The baseline agent keeps no statistics; it only advances its clock.
        """
        self.t = self.t + 1
class Environment:
    """A stochastic multi-armed bandit with Bernoulli reward arms.

    Machine i pays reward 1 with probability p_i and 0 otherwise, where
    each p_i is drawn once, up front, from Uniform([0, 1]).
    """

    def __init__(self, n_machines=10):
        self.n_machines = n_machines
        # One Bernoulli success probability per machine, fixed for the run.
        self.params = np.random.uniform(size=n_machines)

    def _interact(self, machine_index):
        """Pull machine `machine_index` and return its 0/1 reward.

        Raises:
            ValueError: if machine_index is outside [0, n_machines).
        """
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently allow out-of-range indexing.
        if not 0 <= machine_index < self.n_machines:
            raise ValueError('Bad machine index')
        p = self.params[machine_index]
        # Sample from Bernoulli(p), i.e. Binomial with n=1.
        return np.random.binomial(n=1, p=p)

    def run(self, time_steps=250):
        """Simulate `time_steps` agent/environment interactions.

        Returns the total reward accumulated by a freshly created Agent.
        """
        total_reward = 0
        # Bug fix: tell the agent how many machines this environment has.
        # Previously `Agent()` always assumed the default of 10 machines,
        # which is wrong (and can crash) whenever n_machines != 10.
        agent = Agent(self.n_machines)
        for _ in range(time_steps):
            machine_index = agent.pick_machine()
            reward = self._interact(machine_index)
            agent.get_reward(reward, machine_index)
            total_reward += reward
        return total_reward
if __name__ == "__main__":
    # Fixed seed so the demo prints a reproducible total reward.
    np.random.seed(0)
    env = Environment()
    print(env.run())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment