This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ThompsonSampler: | |
"""Thompson Sampling using a Beta distribution associated with each option. | |
The beta distribution will be updated when rewards associated with each option | |
are observed. | |
""" | |
def __init__(self, env, n_learning=0): | |
# boilier plate data storage | |
self.env = env |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler | |
from torchvision.datasets.folder import ImageFolder, default_loader | |
from torchvision.datasets.utils import check_integrity | |
from torchvision import transforms | |
from torchvision import models | |
import matplotlib.pyplot as plt |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# usr/bin/env python3 | |
# author: Conor McDonald | |
# torch==0.4.1 | |
# numpy==1.14.3 | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ThompsonSampler(BaseSampler): | |
def __init__(self, env): | |
super().__init__(env) | |
def choose_k(self): | |
# sample from posterior (this is the thompson sampling approach) | |
# this leads to more exploration because machines with > uncertainty can then be selected as the machine | |
self.theta = np.random.beta(self.a, self.b) | |
# select machine with highest posterior p of payout |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class eGreedy(BaseSampler): | |
def __init__(self, env, n_learning, e): | |
super().__init__(env, n_learning, e) | |
def choose_k(self): | |
# e% of the time take a random draw from machines | |
# random k for n learning trials, then the machine with highest theta | |
self.k = np.random.choice(self.variants) if self.i < self.n_learning else np.argmax(self.theta) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RandomSampler(BaseSampler):
    """Baseline sampler that ignores all observed rewards.

    Every call to ``choose_k`` draws one entry from ``self.variants`` with
    ``np.random.choice`` (no weights given, so every entry is equally likely).
    """

    def __init__(self, env):
        # No sampler-specific state; everything is set up by BaseSampler.
        super().__init__(env)

    def choose_k(self):
        """Pick a variant at random, record it on ``self.k`` and return it."""
        draw = np.random.choice(self.variants)
        self.k = draw
        return self.k
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import pandas as pd | |
from scipy.stats import beta | |
sns.set_style("whitegrid") | |
class Environment: | |
def __init__(self, variants, payouts, n_trials, variance=False): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#%load_ext cythonmagic | |
#%cython | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import pandas as pd | |
sns.set_style("whitegrid") | |
get_ipython().run_line_magic('matplotlib', 'inline') |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import pandas as pd | |
sns.set_style("whitegrid") | |
get_ipython().run_line_magic('matplotlib', 'inline') | |
from IPython.core.display import display, HTML |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# In[150]: | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns |