@danielhavir
Last active July 8, 2022 18:25
ε-greedy action selection, sample-average action-value estimates
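"""
k-armed bandit testbed: ε-greedy action selection with incremental
action-value estimates.

Each estimate is updated as Q(a) ← Q(a) + α * (R - Q(a)), where α is the
sample-average step size 1/N(a) by default, or a constant --alpha when given.
"""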
import argparse
import multiprocessing as mp
from functools import partial
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
STEPS = 10000


class KArmedBandit:
    def __init__(self, k, sigma=1, non_stationary=False):
        self.k = k
        self.mus = np.random.uniform(-2, 2, (k,))
        self.sigmas = np.ones(k) * sigma
        self.non_stationary = non_stationary
        self._step = 0
        print("μ:", self.mus)

    def step(self, a):
        self._step += 1
        r = np.random.normal() * self.sigmas[a] + self.mus[a]
        if self.non_stationary:
            self.mus += np.random.normal(size=self.mus.shape[0]) * 0.01
        done = self._step == STEPS
        return self._step, r, done

    def reset(self):
        self._step = 0

class Agent:
    def __init__(
        self,
        num_actions,
        epsilon,
        min_each_action=1,
        interpolate_eps=False,
        step_size=None,
    ):
        if interpolate_eps:
            self.epsilon = np.linspace(epsilon, 0, STEPS)
        else:
            self.epsilon = epsilon
        self.interpolate_eps = interpolate_eps
        self.min_each_action = min_each_action
        self.step_size = step_size
        self.num_action_taken = np.zeros(num_actions)
        self.qs = np.zeros(num_actions)
        self.reward_collected = []
        self._step = 0

    @property
    def total_reward(self):
        if len(self.reward_collected) == 0:
            return 0
        else:
            return self.reward_collected[-1]

    def sample(self):
        # Force every action to be tried at least `min_each_action` times first.
        if np.min(self.num_action_taken) < self.min_each_action:
            self._step += 1
            return np.argmin(self.num_action_taken)
        epsilon = self.epsilon[self._step] if self.interpolate_eps else self.epsilon
        # Advance the step counter on every call so the annealed ε schedule
        # tracks the actual time step, not only the greedy steps.
        self._step += 1
        # ε-greedy: explore uniformly with probability ε, otherwise act greedily.
        if np.random.uniform(0, 1) <= epsilon:
            return np.random.randint(0, self.qs.shape[0])
        return np.argmax(self.qs)

    def update(self, a, r):
        self.num_action_taken[a] += 1
        # Sample-average step size 1/N(a) by default; a constant step size
        # (set via --alpha) weights recent rewards more heavily.
        step_size = (
            1 / self.num_action_taken[a] if self.step_size is None else self.step_size
        )
        self.qs[a] += step_size * (r - self.qs[a])
        # reward_collected keeps the running (cumulative) reward.
        if len(self.reward_collected) == 0:
            self.reward_collected.append(r)
        else:
            self.reward_collected.append(self.reward_collected[-1] + r)


def episode(episode_num, env, **kwargs):
    env.reset()
    agent = Agent(**kwargs)
    done = False
    trajectory = []
    while not done:
        a = agent.sample()
        trajectory.append(a)
        step, r, done = env.step(a)
        agent.update(a, r)
    print(f"Episode {episode_num} reward {agent.total_reward}")
    return trajectory, agent.reward_collected


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--min_each_action", type=int, default=1)
    parser.add_argument("-e", "--epsilon", type=float, default=0.0)
    parser.add_argument("-a", "--alpha", type=float, default=None)
    parser.add_argument("-s", "--env_sigma", type=float, default=1.0)
    parser.add_argument("--interpolate_eps", action="store_true")
    parser.add_argument("--non_stationary_env", action="store_true")
    args = parser.parse_args()

    k = 10
    env = KArmedBandit(
        k=k, sigma=args.env_sigma, non_stationary=args.non_stationary_env
    )
    num_episodes = 100

    fig, axs = plt.subplots(4, figsize=(7, 8))
    # Panel 0: true action values with their standard deviations.
    axs[0].errorbar(np.arange(1, k + 1), env.mus, yerr=env.sigmas, xerr=0.1)

    total_reward = []
    # Run episodes in parallel; each worker process gets its own copy of env.
    with mp.Pool(mp.cpu_count()) as pool:
        result = pool.map(
            partial(
                episode,
                env=env,
                num_actions=k,
                epsilon=args.epsilon,
                min_each_action=args.min_each_action,
                interpolate_eps=args.interpolate_eps,
                step_size=args.alpha,
            ),
            range(1, num_episodes + 1),
        )

    # Panel 1: cumulative reward per episode; panel 2: action chosen at each step.
    for e in range(num_episodes):
        trajectory, reward_collected = result[e]
        axs[1].plot(reward_collected)
        axs[2].plot(trajectory)
        total_reward.append(reward_collected[-1])

    reward_mean = np.mean(total_reward)
    reward_std = np.std(total_reward)
    print(f"Mean: {reward_mean}, Std: {reward_std}")

    # Panel 3: normal fit to the distribution of total reward across episodes.
    xs = np.arange(reward_mean - 5 * reward_std, reward_mean + 5 * reward_std, 5)
    pdf = norm.pdf(xs, reward_mean, reward_std)
    axs[3].plot(xs, pdf)
    axs[3].vlines(x=reward_mean, ymin=0, ymax=pdf.max(), color="r")
    plt.show()


if __name__ == "__main__":
    main()
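
# Example invocations (the file name "bandit.py" is just illustrative):
#   python bandit.py -e 0.1                              # ε = 0.1, sample-average estimates
#   python bandit.py -e 0.1 -a 0.1 --non_stationary_env  # constant step size, drifting values
#   python bandit.py -e 0.5 --interpolate_eps            # ε annealed linearly from 0.5 to 0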