Created
May 28, 2020 11:14
-
-
Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.
Simplified multiarmed bandit simulation for https://stats.stackexchange.com/q/468724/35989
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from tqdm import trange\n", | |
| "import numpy as np\n", | |
| "import scipy.stats as sp\n", | |
| "import matplotlib.pyplot as plt" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "mu = [-10, -5, 1, 20, 35]\n", | |
| "sigma = [10, 40, 5, 20, 10]\n", | |
| "\n", | |
| "def multi_normal(mu, sigma, size=None):\n", | |
| " assert len(mu) == len(sigma)\n", | |
| " return np.stack([\n", | |
| " sp.norm(loc=mu[i], scale=sigma[i]).rvs(size=size)\n", | |
| " for i in range(len(mu))\n", | |
| " ], axis=1)\n", | |
| "\n", | |
| "def plot_dist(samples, **kwargs):\n", | |
| " for i in range(samples.shape[1]):\n", | |
| " plt.hist(samples[:, i], label=f'arm {i}', **kwargs)\n", | |
| " plt.legend()\n", | |
| " \n", | |
| "def draw_rewards(mu, sigma, size=None):\n", | |
| " return (multi_normal(mu, sigma, size=size) > 0.5).astype(int)\n", | |
| "\n", | |
| "plot_dist(multi_normal(mu, sigma, size=10_000), bins=100, alpha=0.2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([ 3., 3., 4., 10., 10.]), array([0, 3, 3, 3, 3, 3, 3, 3, 3, 3]), 9)" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def simulate_game(n, mu, sigma, strategy='maxwin'):\n", | |
| " \n", | |
| " rewards = np.zeros(shape=(n, len(mu)))\n", | |
| " choices = []\n", | |
| " collected_rewards = 0\n", | |
| " \n", | |
| " for i in range(n):\n", | |
| " \n", | |
| " if i == 0 or strategy == 'random':\n", | |
| " choice = np.random.randint(len(mu))\n", | |
| " if strategy == 'maxwin':\n", | |
| " choice = np.argmax(np.sum(rewards[:i, :], axis=0))\n", | |
| " elif strategy == 'meansd':\n", | |
| " choice = np.argmax(np.mean(rewards[:i, :], axis=0) / np.std(rewards[:i, :], axis=0))\n", | |
| " \n", | |
| " result = draw_rewards(mu, sigma, 1)[0, :]\n", | |
| " \n", | |
| " rewards[i, :] = result\n", | |
| " collected_rewards += result[choice]\n", | |
| " choices.append(choice)\n", | |
| " \n", | |
| " return np.sum(rewards, axis=0), np.array(choices), collected_rewards \n", | |
| "\n", | |
| "\n", | |
| "simulate_game(10, mu, sigma)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:18<00:00, 1.47it/s]\n", | |
| " 0%| | 0/1000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "random: 22.556 vs 23.031 vs 29.829\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:13<00:00, 1.49it/s]\n", | |
| " 0%| | 0/1000 [00:00<?, ?it/s]/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.\n", | |
| " out=out, **kwargs)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " ret, rcount, out=ret, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:217: RuntimeWarning: Degrees of freedom <= 0 for slice\n", | |
| " keepdims=keepdims)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:186: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " arrmean, rcount, out=arrmean, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:207: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " ret, rcount, out=ret, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n", | |
| " \n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " \n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "maxwin: 23.251 vs 22.885 vs 29.782\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:33<00:00, 1.44it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "meansd: 23.166 vs 22.885 vs 29.782\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "R = 1000\n", | |
| "N = 50\n", | |
| "k = 20\n", | |
| "\n", | |
| "for strategy in ['random', 'maxwin', 'meansd']:\n", | |
| " \n", | |
| " rewards = 0 # actual rewards given the choosen strategy\n", | |
| " oracle = 0 # best strategy if we hade an oracle, that knew the future outcome\n", | |
| " bestmu = 0 # choose arm based on the known, true expected reward\n", | |
| " \n", | |
| " for i in trange(R):\n", | |
| " np.random.seed(i)\n", | |
| " m = sp.uniform(-10, 2).rvs(k)\n", | |
| " s = sp.uniform(20, 200).rvs(k)\n", | |
| " \n", | |
| " total_rewards, _, collected_rewards = simulate_game(N, m, s, strategy=strategy)\n", | |
| " \n", | |
| " rewards += collected_rewards\n", | |
| " oracle += np.max(total_rewards)\n", | |
| " bestmu += total_rewards[np.argmax(m)]\n", | |
| " \n", | |
| " print(f'{strategy}: {rewards / R} vs {bestmu / R} vs {oracle / R}')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment