Skip to content

Instantly share code, notes, and snippets.

@conormm
Created September 29, 2018 18:52
Show Gist options
  • Save conormm/c0e2e57afdafc047bf13bd5bf8ae1844 to your computer and use it in GitHub Desktop.
Save conormm/c0e2e57afdafc047bf13bd5bf8ae1844 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import beta
# Module-level side effect: select seaborn's "whitegrid" style at import time.
sns.set_style("whitegrid")
class Environment:
    """Multi-armed bandit environment that pays out Bernoulli rewards.

    Each variant (arm) ``k`` has a payout ``payouts[k]`` which is used as the
    success probability of a single Bernoulli draw whenever that arm is
    chosen.

    Attributes:
        variants: The arms, stored exactly as passed in.
        payouts: Private float ``ndarray`` of per-arm payout probabilities.
        n_trials: Number of rounds played by each call to :meth:`run`.
        total_reward: Sum of all rewards paid out so far. It is only zeroed
            in ``__init__``, so it keeps accumulating across ``run`` calls.
        n_k: Number of arms (``len(variants)``).
        shape: ``(n_k, n_trials)``.
    """

    def __init__(self, variants, payouts, n_trials, variance=False):
        """Set up the environment.

        Args:
            variants: Sized collection of arms; only its length is used here.
            payouts: Per-arm success probabilities, indexable by arm index.
            n_trials: Number of rounds per simulation run.
            variance: When true, perturb every payout with Gaussian noise
                (mean 0, standard deviation 0.04) and clip the result to the
                interval [0, 0.2].
        """
        self.variants = variants
        # Always keep our own float array: the attribute then has the same
        # type in both branches and never aliases the caller's sequence.
        base_payouts = np.array(payouts, dtype=float)
        if variance:
            noise = np.random.normal(0, 0.04, size=len(variants))
            self.payouts = np.clip(base_payouts + noise, 0, .2)
        else:
            self.payouts = base_payouts
        self.n_trials = n_trials
        self.total_reward = 0
        self.n_k = len(variants)
        self.shape = (self.n_k, n_trials)

    def run(self, agent):
        """Run the simulation with the agent.

        The agent must provide ``choose_k()`` (returns the index of the arm
        to play), ``update()`` and ``collect_data()``, and must accept having
        its ``reward`` attribute assigned before every ``update()`` call.

        Args:
            agent: The sampler/agent to play ``n_trials`` rounds with.

        Returns:
            ``self.total_reward`` after this run. Because the counter is not
            reset here, this includes rewards from any earlier runs on the
            same environment.
        """
        for _ in range(self.n_trials):
            # agent makes a choice
            x_chosen = agent.choose_k()
            # environment returns a 0/1 reward drawn with the arm's payout
            reward = np.random.binomial(1, p=self.payouts[x_chosen])
            # agent learns of the reward ...
            agent.reward = reward
            # ... and updates its parameters based on it
            agent.update()
            self.total_reward += reward
        # once all trials are done, let the agent assemble its history
        agent.collect_data()
        return self.total_reward
class BaseSampler:
    """Common state and bookkeeping for bandit agents.

    The class allocates the per-trial history buffers and per-arm parameter
    arrays, and can export the history as a DataFrame. The decision logic
    (``choose_k``) and the learning step (``update``) are left to subclasses.
    """

    def __init__(self, env, n_samples=None, n_learning=None, e=0.05):
        """Initialise the agent's state from an environment.

        Args:
            env: Object exposing ``n_k``, ``variants``, ``n_trials`` and
                ``payouts``.
            n_samples: Stored as-is and used as the second entry of
                ``self.shape``; not otherwise used in this class.
            n_learning: Stored as-is; not used in this class.
            e: Stored as-is; ``self.exploit`` is set to ``1 - e``.
                Presumably an epsilon-greedy exploration rate - confirm
                against the subclasses.
        """
        self.env = env
        self.shape = (env.n_k, n_samples)
        self.variants = env.variants
        self.n_trials = env.n_trials
        self.payouts = env.payouts

        # Per-trial history buffers, one slot per trial.
        self.ad_i = np.zeros(self.n_trials)      # exported as column "ad"
        self.r_i = np.zeros(self.n_trials)       # exported as column "reward"
        self.thetas = np.zeros(self.n_trials)
        self.regret_i = np.zeros(self.n_trials)  # exported as column "regret"
        self.thetaregret = np.zeros(self.n_trials)

        # Per-arm parameters. a/b start at one; presumably the parameters of
        # a Beta(1, 1) prior per arm - confirm against the subclasses.
        self.a = np.ones(env.n_k)
        self.b = np.ones(env.n_k)
        self.theta = np.zeros(env.n_k)

        self.data = None       # filled in by collect_data()
        self.reward = 0        # most recent reward
        self.total_reward = 0
        self.k = 0
        self.i = 0

        self.n_samples = n_samples
        self.n_learning = n_learning
        self.e = e
        # One uniform(0, 1) draw per trial, generated up front.
        self.ep = np.random.uniform(0, 1, size=self.n_trials)
        self.exploit = (1 - e)

    def choose_k(self):
        """Return the index of the arm to play; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement choose_k()"
        )

    def update(self):
        """Update the agent from ``self.reward``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement update()"
        )

    def collect_data(self):
        """Store the per-trial history in ``self.data`` as a DataFrame.

        The frame has the columns ``ad``, ``reward`` and ``regret``, taken
        from ``ad_i``, ``r_i`` and ``regret_i`` respectively.
        """
        self.data = pd.DataFrame(dict(ad=self.ad_i, reward=self.r_i, regret=self.regret_i))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment