Skip to content

Instantly share code, notes, and snippets.

@twolodzko
Created May 28, 2020 11:14
Show Gist options
  • Select an option

  • Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.

Select an option

Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.
Simplified multi-armed bandit simulation for https://stats.stackexchange.com/q/468724/35989
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import trange\n",
"import numpy as np\n",
"import scipy.stats as sp\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"mu = [-10, -5, 1, 20, 35]\n",
"sigma = [10, 40, 5, 20, 10]\n",
"\n",
"def multi_normal(mu, sigma, size=None):\n",
"    '''Draw independent normal samples, one column per arm.\n",
"\n",
"    Arm i is sampled from N(mu[i], sigma[i]); arms are drawn one after\n",
"    another. ``rvs`` is called without ``random_state``, so the draws come\n",
"    from NumPy's global RNG and follow ``np.random.seed``.\n",
"\n",
"    Returns an array of shape (size, len(mu)) for an integer ``size``, or\n",
"    shape (len(mu),) when ``size`` is None (a single draw per arm).\n",
"    Raises ValueError if ``mu`` and ``sigma`` differ in length.\n",
"    '''\n",
"    if len(mu) != len(sigma):\n",
"        raise ValueError('mu and sigma must have the same length')\n",
"    draws = [\n",
"        sp.norm(loc=m, scale=s).rvs(size=size)\n",
"        for m, s in zip(mu, sigma)\n",
"    ]\n",
"    # With size=None every draw is a scalar, so there is no axis 1 to stack on.\n",
"    return np.stack(draws, axis=0 if size is None else 1)\n",
"\n",
"def plot_dist(samples, **kwargs):\n",
"    '''Overlay one histogram per column of ``samples`` and add a legend.\n",
"\n",
"    Extra keyword arguments are forwarded to ``plt.hist``.\n",
"    '''\n",
"    for arm, column in enumerate(samples.T):\n",
"        plt.hist(column, label=f'arm {arm}', **kwargs)\n",
"    plt.legend()\n",
"\n",
"def draw_rewards(mu, sigma, size=None, threshold=0.5):\n",
"    '''Return 0/1 rewards: 1 where a normal draw exceeds ``threshold``.\n",
"\n",
"    The result has the same shape as ``multi_normal(mu, sigma, size=size)``.\n",
"    '''\n",
"    return (multi_normal(mu, sigma, size=size) > threshold).astype(int)\n",
"\n",
"plot_dist(multi_normal(mu, sigma, size=10_000), bins=100, alpha=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 3., 3., 4., 10., 10.]), array([0, 3, 3, 3, 3, 3, 3, 3, 3, 3]), 9)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def simulate_game(n, mu, sigma, strategy='maxwin'):\n",
"    '''Play ``n`` rounds of a bandit game with the given choice strategy.\n",
"\n",
"    Each round an arm is chosen from the history of previous rounds, then\n",
"    rewards for every arm are drawn with ``draw_rewards`` and recorded;\n",
"    only the chosen arm's reward is collected.\n",
"\n",
"    Strategies: the first round (and every round under 'random') picks a\n",
"    uniformly random arm; afterwards 'maxwin' picks the arm with the largest\n",
"    reward total so far and 'meansd' the arm with the largest mean / std of\n",
"    its past rewards (ties go to the lowest index).\n",
"\n",
"    Returns ``(per_arm_totals, choices, collected_rewards)``.\n",
"    Raises ValueError for an unknown ``strategy``.\n",
"    '''\n",
"    if strategy not in ('random', 'maxwin', 'meansd'):\n",
"        raise ValueError(f'unknown strategy: {strategy!r}')\n",
"\n",
"    rewards = np.zeros(shape=(n, len(mu)))\n",
"    choices = []\n",
"    collected_rewards = 0\n",
"\n",
"    for i in range(n):\n",
"        history = rewards[:i, :]\n",
"        # Must be a single if/elif chain: otherwise the random first pick is\n",
"        # overwritten by argmax over an empty history (always arm 0).\n",
"        if i == 0 or strategy == 'random':\n",
"            choice = np.random.randint(len(mu))\n",
"        elif strategy == 'maxwin':\n",
"            choice = np.argmax(np.sum(history, axis=0))\n",
"        else:  # 'meansd'\n",
"            with np.errstate(divide='ignore', invalid='ignore'):\n",
"                score = np.mean(history, axis=0) / np.std(history, axis=0)\n",
"            # 0/0 (an arm that never paid out) is NaN, which argmax would treat\n",
"            # as the maximum; rank such arms last instead.\n",
"            choice = np.argmax(np.where(np.isnan(score), -np.inf, score))\n",
"\n",
"        result = draw_rewards(mu, sigma, 1)[0, :]\n",
"\n",
"        rewards[i, :] = result\n",
"        collected_rewards += result[choice]\n",
"        choices.append(choice)\n",
"\n",
"    return np.sum(rewards, axis=0), np.array(choices), collected_rewards\n",
"\n",
"\n",
"simulate_game(10, mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:18<00:00, 1.47it/s]\n",
" 0%| | 0/1000 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"random: 22.556 vs 23.031 vs 29.829\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:13<00:00, 1.49it/s]\n",
" 0%| | 0/1000 [00:00<?, ?it/s]/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.\n",
" out=out, **kwargs)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide\n",
" ret, rcount, out=ret, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:217: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
" keepdims=keepdims)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:186: RuntimeWarning: invalid value encountered in true_divide\n",
" arrmean, rcount, out=arrmean, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:207: RuntimeWarning: invalid value encountered in true_divide\n",
" ret, rcount, out=ret, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
" \n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in true_divide\n",
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"maxwin: 23.251 vs 22.885 vs 29.782\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:33<00:00, 1.44it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"meansd: 23.166 vs 22.885 vs 29.782\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"R = 1000  # number of simulated games per strategy\n",
"N = 50  # rounds per game\n",
"k = 20  # number of arms\n",
"\n",
"for strategy in ['random', 'maxwin', 'meansd']:\n",
"\n",
"    rewards = 0  # rewards actually collected with the chosen strategy\n",
"    oracle = 0  # best single arm in hindsight (largest per-arm total of the game)\n",
"    bestmu = 0  # always playing the arm with the highest latent mean m; since s\n",
"                # varies per arm this is not necessarily the best payout rate\n",
"\n",
"    for i in trange(R):\n",
"        # Seeding per game gives every strategy the same arm parameters m, s.\n",
"        np.random.seed(i)\n",
"        # scipy's uniform(loc, scale) is U[loc, loc + scale]: means are drawn\n",
"        # from [-10, -8] and standard deviations from [20, 220].\n",
"        m = sp.uniform(-10, 2).rvs(k)\n",
"        s = sp.uniform(20, 200).rvs(k)\n",
"\n",
"        total_rewards, _, collected_rewards = simulate_game(N, m, s, strategy=strategy)\n",
"\n",
"        rewards += collected_rewards\n",
"        oracle += np.max(total_rewards)\n",
"        bestmu += total_rewards[np.argmax(m)]\n",
"\n",
"    print(f'{strategy}: {rewards / R} vs {bestmu / R} vs {oracle / R}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment