Created
May 28, 2020 11:14
-
-
Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.
Simplified multiarmed bandit simulation for https://stats.stackexchange.com/q/468724/35989
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from tqdm import trange\n", | |
| "import numpy as np\n", | |
| "import scipy.stats as sp\n", | |
| "import matplotlib.pyplot as plt" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAdwUlEQVR4nO3df4wU95nn8fdjDgYT8PFr4IAZB/COOXsMHh8T4ijJmVsbM+vkjK2IE5H8S2eJ/GFLWcmJbG+C19HGku+y2dWhCyHkYnmy8sVB2uWMoiQ2ZtdrJyEmQwyYMWCGAcMw7Ay/HDz2ePDAc390DTRD9/SPqerq7vq8pFZ3V1dVP1/ofvo73/rWU+buiIhIMlwVdwAiIlI6SvoiIgmipC8ikiBK+iIiCaKkLyKSIEr6IiIJkjPpm9l4M9tuZrvMrN3MvhMsf9rMjpnZzuB2V9o2T5pZh5ntN7PlUTZARETyZ7nm6ZuZAZ9y9z4zGwv8Bvg60AL0ufvfDlv/RuBnwBJgNvAqcL27n48gfhERKcC/y7WCp34V+oKnY4PbSL8UK4AX3X0AOGRmHaR+ALZl22D69Ok+d+7cfGMWERFgx44dJ929tpBtciZ9ADMbA+wA/gz4gbu/aWZ/ATxqZg8AbcBj7n4GmAP8Pm3zrmBZVnPnzqWtra2QuEVEEs/M3it0m7wO5Lr7eXdvAuqAJWZ2E/BD4DqgCTgOfH8ojky7yBDsajNrM7O2EydOFBq3iIgUoaDZO+7+PvAa0OLuPcGPwQXgx6SGcCDVs69P26wO6M6wrw3u3uzuzbW1Bf11IiIiRcpn9k6tmU0OHl8N3AHsM7NZaavdC+wJHm8GVplZjZnNAxqA7eGGLSIixchnTH8W0BqM618FbHT3X5jZP5hZE6mhm8PA1wDcvd3MNgLvAIPAI5q5IyKj8cknn9DV1cXHH38cdyixGD9+PHV1dYwdO3bU+8o5ZbMUmpubXQdyRSSbQ4cOMWnSJKZNm0ZqFnlyuDunTp3igw8+YN68eZe9ZmY73L25kP3pjFwRKXsff/xxIhM+gJkxbdq00P7KUdIXkYqQxIQ/JMy2K+mLiMSstbWVhoYGGhoaaG1tjfS98jo5S0SknLzd9adQ97ew7t+Hsp/z588zZsyYgrY5ffo03/nOd2hra8PMWLx4MXfffTdTpkwJJabh1NMXEcnDPffcw+LFi2lsbGTDhg0Xl0+cOJGnnnqKz372s2zbto2JEyfy+OOPs3jxYu644w62b9/O0qVLmT9/Pps3b75ivy+//DLLli1j6tSpTJkyhWXLlvHrX/86snYo6YuI5OG5555jx44dtLW1sXbtWk6dOgXAhx9+yE033cSbb77JF77wBT788EOWLl3Kjh07mDRpEt/+9rfZsmULmzZt4qmnnrpiv8eOHaO+/tL5rHV1dRw7diyydmh4R0QkD2vXrmXTpk0AHD16lAMHDjBt2jTGjBnDV77ylYvrjRs3jpaWFgAWLlxITU0NY8eOZeHChRw+fPiK/WaaNh/lQWv19EVEcnjttdd49dVX2bZtG7t27eKWW265OIVy/Pjxl43jjx079mLSvuqqq6ipqbn4eHBw8Ip919XVcfTo0YvPu7q6mD17dmRtUdIXqUK9752NO4Sq8qc//YkpU6YwYcIE9u3bx+9///vcG+Vp+fLlvPLKK5w5c4YzZ87wyiuvsHx5dNee0vCOiEgOLS0trF+/nkWLFrFgwQJuvfXW0PY9depU1qxZw2c+8xkAnnrqKaZOnRra/odTGQaRkLWfaqdxWmMs79373llmfPqai/fVYu/evdxwww1xhxGrTP8GKsMgIiIjUtIXEUkQJX0RkQRR0hfJIOzT/EXKhZK+SA76AZBqoimbIjmMP7kbrpqYejL7lniDERkl9fRFqphO0qoMLS0tTJ48mS9/+cuRv5d6+iJSebrfCnd/If0FV0xpZYBvfvObfPTRR/zoRz8KJY6RqKcvkkW5jOX372mnf0973GEkXlSllQFuv/12Jk2aVJJ2KOmLhKj9lJJztYqqtHKpaXhHZMhlQwbzYwtDylNUpZVLLWdP38zGm9l2M9tlZu1m9p1g+VQz22JmB4L7KWnbPGlmHWa238yiKxcnErFyGeKReEVZWrnU8hneGQD+3N1vBpqAFjO7FXgC2OruDcDW4DlmdiOwCmgEWoB1Zlb4kQ2RStL9Vup2Yn/ckVykmTvhibK0cqnlTPqe0hc8HRvcHFgBDF22vRW4J3i8AnjR3Qfc/RDQASwJNWoRkRJqaWlhcHCQRYsWsWbNmlBLKwN88YtfZOXKlWzdupW6ujpefvnlUPefLq8x/aCnvgP4M+AH7v6mmc109+MA7n7czGYEq88B0n8Gu4JlIiLhKPFJcjU1NfzqV7/K+FpfX1/W508//fSI6w554403RhdgAfKavePu5929CagDlpjZTSOsnunijlcU7Tez1WbWZmZtJ06cyC9aEREZlYKmbLr7+8BrpMbqe8xsFkBw3xus1gXUp21WB3Rn2NcGd2929+ba2toiQhcJ34HezD2xjIbG8UM+UUhz8iVK+czeqTWzycHjq4E7gH3AZuDBYLUHgZeCx5uBVWZWY2bzgAZge9iBi5QzzdeXcpXPmP4soDUY178K2OjuvzCzbcBGM3sYOAKsBHD3djPbCLwDDAKPuPv5aMIXEZFC5Ez67r4buOKoibufAm7Pss0zwDOjjk4kLOlDMBVaKbN/TztX31TctXfTr5lbbdfPlcKoDIOISIIo6YvkIZ8DvO1nD6UenNifuoVdCRId5K1GO3fu5HOf+xyNjY0sWrSIn//855G+n2rvSLJFkJglemEfKG+cVtyw2XDFlFaeMGECP/3pT2loaKC7u5vFixezfPlyJk+eHEpMw6mnL1IGCunB5yq1rPIL0YiqtPL1119PQ0MDALNnz2bGjBlEee6Skr5ImVDd/PJWitLK27dv59y5c1x33XWRtUPDOyIhuDieHxPNyIle1KWVjx8/zv33309raytXXRVdf1xJX0Qkh/TSyhMmTGDp0qWhllY+e/YsX/rSl/jud78bejG34ZT0RaJUBecHSLSllc+dO8e9997LAw88wMqVK0PbbzYa0xfJU0F1ebIYPmafaxxfY/zlIcrSyhs3buT111/n+eefp6mpiaamJnbu3Bna/odTT1+kEJriWRbCmmKZryhLK993333cd999ow8yT+rpi5SYZulInJT0RUQSRMM7Up00DCOSkZK+SCCMA7VhKWb4J9uZuDpDV9Ip6Uvy5Pgr4EBvH0zP/lrDjImXLYv7xCyRQijpi3BlL3/8yd0xRSISLR3IFRGJ0XvvvcfixYtpamqisbGR9evXR/p+6umLJEi1jO+HPeW12CuSDVdMaeVZs2bxu9/9jpqaGvr6+rjpppu4++67mT17digxDaeevohIHqIqrTxu3LiL9XkGBga4cOFCpO1Q0hcpUDnN8pHSibK08tGjR1m0aBH19fU8/vjjkfXyQUlfRCQva9eu5eabb+bWW2+9WFoZyFla+bbbbstZWrm+vp7du3fT0dFBa2srPT09kbVDSV9kFEKZrnly36WblKX00sq7du3illtuCbW08pDZs2fT2NjIG2+8EVFL8kj6ZlZvZv9iZnvNrN3Mvh4sf9rMjpnZzuB2V9o2T5pZh5ntN7PlkUUvIlICUZZW7urqor+/H4AzZ87w29/+lgULFoS2/+Hymb0zCDzm7n80s0nADjPbErz29+7+t+krm9mNwCqgEZgNvGpm17v7+TADF4lTppO0cup+C052ph5P/4+Z10nv7WdbR0qupaWF9evXs2jRIhYsWBBqaeW9e/fy2GOPYWa4O9/4xjdYuHBhaPsfLmfSd/fjwPHg8QdmtheYM8ImK4AX3X0AOGRmHcASYFsI8YpI4Og/76T+z5uqZhpmIcKaYpmvKEsrL1u2jN27S3cyYEFj+mY2F7gFeDNY9KiZ7Taz58xsSrBsDnA0bbMuRv6REEmek/uCnv8I4/ga65cI5J30zWwi8I/AX7r7WeCHwHVAE6m/BL4/tGqGzT3D/labWZuZtZ04caLgwEUkdZLSwMHOuMOQCpJX0jezsaQS/gvu/k8A7t7j7ufd/QLwY1JDOJDq2denbV4HdA/fp7tvcPdmd2+ura0dTRskybrfunRLMCV+yVc+s3cM+Amw193/Lm35rLTV7gX2BI83A6vMrMbM5gENwPbwQhYJV1QnW5Vz9c0kHgeQlHxm73weuB9428yGrtb7V8BXzayJ1NDNYeBrAO7ebmYbgXdIzfx5RDN3pCRK3Ns/0NsH40v6liKjls/snd+QeZz+lyNs8wzwzCjiEpGQqXcvoDNyRUTKwtmzZ5kzZw6PPvpopO+j0soiFehk7ydxhxCrsP9qmfHpa0LZTzGllYesWbOG2267LZQ4RqKevohIHqIqrQywY8cOenp6uPPOOyNvh5K+SAn0v9tJ/7uaVlnJoiqtfOHCBR577DG+973vlaQdGt4REcnD2rVr2bRpE8DF0srTpk3LWVq5pqZmxNLK69at46677qK+vv6K16KgpC9SwZI+tl8q6aWVJ0yYwNKlS0Mrrbxt2zbeeOMN1q1bR19fH+fOnWPixIk8++yzkbRFSV9EJIcoSyu/8MILFx8///zztLW1RZbwQWP6IiI5tbS0MDg4yKJFi1izZk2opZVLTT19kSId/OgI9eOvLnp7HdgtXlhTLPMVZWnldA899BAPPfRQUTHmS0lfKk/Ci6tlM1R0rea6+Zc9Fkmn4R2RUTp6uj/uEC6jipsyEiV9kYRSLZ5kUtIXkYrgfsW1mBIjzLYr6YtI2Rs/fjynTp1KZOJ3d06dOsX48eHU8daBXBEpe3V1dXR1dZHUS6uOHz+eurq6UPalpC8iZW/s2LHMmzcv7jCqgoZ3RCKm+fhSTpT0RUQSRElfJARXzNXvOBbdm53cB+8fTt1ECqSkL1JhTp4u7spMIqADuSKjMuLZuB3H6J+YvCmGUt7U05dEO9A7cgGsaqezcpMnZ9I3s3oz+xcz22tm7Wb29WD5VDPbYmYHgvspads8aWYdZrbfzJZH2QCRRDi579JNZBTy6ekPAo+5+w3ArcAjZnYj8ASw1d0bgK3Bc4LXVgGNQAuwzsw0CCkiUgZyJn13P+7ufwwefwDsBeYAK4DWYLVW4J7g8QrgRXcfcPdDQAewJOzAJQG637p0qwRRztgRCUlBY/pmNhe4BXgTmOnuxyH1wwDMCFabAxxN26wrWCYiIjHLO+mb2UTgH4G/dPeRjv5YhmVXTGEws9Vm1mZmbUmtpyEiUmp5JX0zG0sq4b/g7v8ULO4xs1nB67OA3mB5F1Cftnkd0D18n+6+wd2b3b25tra22PhFilaKmTudfVd89EVilc/sHQN+Aux1979Le2kz8GDw+EHgpbTlq8ysxszmAQ3A9vBCFhGRYuVzctbngfuBt81sZ7Dsr4BngY1m9jBwBFgJ4O7tZrYReIfUzJ9H3P186JGLyOWlGCbPveLlgYOduk6uXCZn0nf335B5nB7g9izbPAM8M4q4REQkAjojVxIp6WfiSnKp9o5UhhDn6ivhS5Kppy8ikiDq6Ut5qZSzb0UqlHr6IhUiZx39oQurFHhxFVXaTBYlfRGRBNHwjkgRDn50JPMLKromZU49fZGQ9JwdoOfsQNxhiIxIPX2RClLU9XFP7L/0uHZBeMFIRVJPX0QkQdTTFxmlcYePxx2CSN7U0xepcgNHjjFwZOQDzJq2mRxK+iIJNHCwM+4QJCZK+iIiCaKkLyKSIEr6IiIJotk7ItWowPo7khzq6YuIJIiSvohIgmh4R+KnGvoiJaOevkhC6AQtgTySvpk9Z2a9ZrYnbdnTZnbMzHYGt7vSXnvSzDrMbL+ZLY8qcJFK0dnXHXcIIhfl09N/HmjJsPzv3b0puP0SwMxuBFYBjcE268ysiLKAIiIShZxJ391fB07nub8VwIvuPuDuh4AOYMko4hMJ1YHevrhDKEpRJZVFMhjNmP6jZrY7GP6ZEiybAxxNW6crWCYiZWDgYKfq7iRcsUn/h8B1QBNwHPh+sNwyrOuZdmBmq82szczaTpw4UWQYIvEYd/i4SipLRSoq6bt7j7ufd/cLwI+5NITTBdSnrVoHZDyK5e4b3L3Z3Ztra2uLCUNERApUVNI3s1lpT+8Fhmb2bAZWmVmNmc0DGoDtowtRqkb3W5duVaysr5P7/uFLt/TLKEpi5Dw5y8x+BiwFpptZF/DXwFIzayI1dHMY+BqAu7eb2UbgHWAQeMTdz0cTukg8Dn50hHHB42oY4hk42EnNdfPjDkNKJGfSd/evZlj8kxHWfwZ4ZjRBiYhINFSGQURSTuyHsUFKmH1LvLFIZJT0JVpVOH5fDUM6EJRlODs27jCkxJT0JR5V+GMgUglUcE1EJEGU9EXkMr3dg3GHIBFS0heRi5Twq5+SvohIgijpi4gkiGbviMjI0mdaaf5+xVNPX0QkQZT0RUQSRMM7InnyzvfyXrfn7AAzr6mJMBqR4qinL5KHQhK+SDlTT18kyd4/nLo/8UmsYUjpqKcviRD3BdE7+zJeQE6k5JT0JTHiTvwi5UBJX0QkQTSmLyJXUunrqqWkL1UvzGGdro//LbR9VSSdnVvxlPRFytjJ02OYPvV85O8zcOTYxcc1186J/P0kPkr6Ej4NDYiULR3IFSlzJ0+PiTsEqSI5k76ZPWdmvWa2J23ZVDPbYmYHgvspaa89aWYdZrbfzJZHFbiIiBQun57+80DLsGVPAFvdvQHYGjzHzG4EVgGNwTbrzEzdFImN5uaLXC5n0nf314HTwxavAFqDx63APWnLX3T3AXc/BHQAS0KKVaSk4q63E+ewji6bWL2KHdOf6e7HAYL7GcHyOcDRtPW6gmUiidNzdiDuEESuEPbsHcuwzDOuaLYaWA1w7bXXhhyGSDji7u2LhK3Ynn6Pmc0CCO57g+VdQH3aenVAxkpT7r7B3Zvdvbm2trbIMEREpBDFJv3NwIPB4weBl9KWrzKzGjObBzQA20cXooiIhCXn8I6Z/QxYCkw3sy7gr4FngY1m9jBwBFgJ4O7tZrYReAcYBB5x9+hPJ5T46YQskYqQM+m7+1ezvHR7lvWfAZ4ZTVAiIhINnZErIpIgSvoiIgmipC9SIrpkopQDVdmU4ungbU49ZweYeU1N3GEUpbd7kBmzlSKqjXr6ImWoHCprptfYl+qhpC8SMZVjkHKipC8ikiBK+lKVVFJZJDMdpZHcdDFsyUSfi4qkpC8il9EB3Oqm4R0RkQRRT18KU6Vz84fq5tv8T8ccSXlM15TqpaQvMkycF04pt4Tf/24nAFdfPz/mSCQsGt4REUkQ9fTlkiqbjVHMtE1dHvFyJ0+PYfpUXRKjmqinLyKSIOrpS2ZVesBWJOnU0xfJU9fH/1b0tj1nB1SDR8qCevoiEq4qOzZUbZT0JfFKffC2s6+b+RNnl/Q9w9T/bueVUzjzGQ7Uj0FZUNIXKaFKvHrWydNjqC92Yx0bKjsa0xeRnIZO0pLKN6qevpkdBj4AzgOD7t5sZlOBnwNzgcPAf3P3M6MLU6T6ldvZuFKdwujp/xd3b3L35uD5E8BWd28AtgbPRUSkDEQxvLMCaA0etwL3RPAeIlnpAioi2Y026TvwipntMLPVwbKZ7n4cILifMcr3EKka2ebqa2hHSmW0s3c+7+7dZjYD2GJm+/LdMPiRWA1w7bXXjjIMKZpmV4gkyqh6+u7eHdz3ApuAJUCPmc0CCO57s2y7wd2b3b25trZ2NGGIiEieik76ZvYpM5s09Bi4E9gDbAYeDFZ7EHhptEGKRME731NVTUmc0QzvzAQ2mdnQfv6vu//azP4AbDSzh4EjwMrRhykicdNc/epQdNJ3907g5gzLTwG3jyYoiVgVj+PnmrlT7GURR1NsrRoMHWhWbf3KpzIMSVHFiV5E8qcyDFI1ND9fJDclfZEcohjaOXq6P/R9llL/u50a469QGt4RkdJTmeXYqKcvVaHShnaSfmBY4qOkL4mk+fnFU8mIyqbhnWqmGTtlqZqulZvxKlpS1pT0K93wxK7x0Ypy8vSYqpr7PnRwt+gfAo31R05JX6qehnJELlHSrzYa0rmMEn7paApnZVDSF4lZ4g+MqqNSUpq9IyJ5S/wPVBVQ0peKF+Uc/Sjn07cfHoxs3yLZaHhHJEY9ZweYeU1N3GEULH3WUUnG8jWrJzRK+uVMH3SpYKOevimRUNKvFDrYVZByn7XT/8HVpHfwK7XHH6l8PvPqGBVMY/oiMeoZOFnRZ+jqwG7lUU9fKlpUB3FLURCtZ+Bk5O8RpXwTfuilGvRX76go6UtV8M73Cr4EYqn1f3A1AFdPquxa+vmI/UQtDftkpaQfF30oRSQGSvrlQH+u5uVAbx8NMyZefDxcOR+8Herlj6SSD+aWdeE4dbAuowO5UlGiTvalGMsf/gNQTWP7ucb504d9YhkC6n7r0i2hIuvpm1kL8L+AMcD/cfdno3ovqS7DD/yl9/CHnkN59+yLVcm9/ZFkS/ahzuVPcCIvRCRJ38zGAD8AlgFdwB/MbLO7vxPF+5UdffhC0/9uJ0yewYHevssO1kaZ8OO8lGHPwElmMqfik3/6cE8+Qz/Df+hLcnGWfL+nVTYkFFVPfwnQ4e6dAGb2IrACqMykn8+YoBJ96A709lFH5fbo+z+4mqsn9ec1ng+Xhnkyzduvph+Bipbte15BPwxRJf05wNG0513AZyN6r/wSbvp/SrYkruReEkN/0ndNnkHd+71A6s/7/nc7OXo6NZ2x/t1O/HT/ZR8iiOYHIN+e/VASz7QsW4IvNOGn233iGDNrpl9aJ+3HoBx/AIbG89PvC0n0w8f4058P9frT/wLI9HokwvrOl8kPRlRJ3zIs88tWMFsNrA6e9pnZ/ohiAZgOVPbRskuqqS1QXe1RW8pTNbUFLm9PwSenRJX0u4D6tOd1QHf6Cu6+AdgQ0ftfxsza3L25FO8VtWpqC1RXe9SW8lRNbYHRtyeqKZt/ABrMbJ6ZjQNWAZsjei8REclTJD19dx80s0eBl0lN2XzO3dujeC8REclfZPP03f2XwC+j2n+BSjKMVCLV1BaorvaoLeWpmtoCo2yPuXvutUREpCqoDIOISIJUVdI3s5Vm1m5mF8ysOW35XDPrN7OdwW192muLzextM+sws7Vmlmm6aSyytSd47ckg5v1mtjxtedm2Z4iZPW1mx9L+P+5Key1ju8qZmbUE8XaY2RNxx1MoMzscfGZ2mllbsGyqmW0xswPB/ZS448zGzJ4zs14z25O2LGv85fwZy9KWcL8v7l41N+AGYAHwGtCctnwusCfLNtuBz5E6t+BXwF/E3Y482nMjsAuoAeYBB4Ex5d6etPifBr6RYXnWdpXrjdREhYPAfGBcEP+NccdVYBsOA9OHLfufwBPB4yeA/xF3nCPE/5+B/5T+Hc8Wf7l/xrK0JdTvS1X19N19r7vnfZKXmc0CrnH3bZ76V/wpcE9kARZohPasAF509wF3PwR0AEvKvT15yNiumGPK5WLJEXc/BwyVHKl0K4DW4HErZfw5cvfXgdPDFmeLv6w/Y1nakk1RbamqpJ/DPDN7y8z+1cy+GCybQ+pEsiFdwbJyl6nMxRwqqz2Pmtnu4M/ZoT+9s7WrnFVizMM58IqZ7QjOlAeY6e7HAYL7GbFFV5xs8Vfq/1do35eKu4iKmb0K/IcML33L3V/Kstlx4Fp3P2Vmi4H/Z2aN5FEuImpFtidb3LG3Z8hI7QJ+CPwNqdj+Bvg+8N8po/gLUIkxD/d5d+82sxnAFjPbF3dAEarE/69Qvy8Vl/Td/Y4ithkABoLHO8zsIHA9qV/GurRVrygXEbVi2kP2Mhext2dIvu0ysx8Dvwie5izfUYYqMebLuHt3cN9rZptIDRH0mNksdz8eDBv2xhpk4bLFX3H/X+7eM/Q4jO9LIoZ3zKw2qPGPmc0HGoDO4M++D8zs1mCWywNAtt51OdkMrDKzGjObR6o92yulPcGXcMi9wNBMhYztKnV8BarokiNm9ikzmzT0GLiT1P/HZuDBYLUHKcPPUQ7Z4q+4z1jo35e4j1aHfOT7XlK/fgNAD/BysPwrQDupI91/BP5r2jbNwT/iQeB/E5ywVg63bO0JXvtWEPN+0mbolHN70mL8B+BtYHfwwZ2Vq13lfAPuAt4N4v5W3PEUGPv84HuxK/iOfCtYPg3YChwI7qfGHesIbfgZqSHcT4Lvy8MjxV/On7EsbQn1+6IzckVEEiQRwzsiIpKipC8ikiBK+iIiCaKkLyKSIEr6IiIJoqQvIpIgSvoiIgmipC8ikiD/H2UEAoTNazCCAAAAAElFTkSuQmCC\n", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "mu = [-10, -5, 1, 20, 35]\n", | |
| "sigma = [10, 40, 5, 20, 10]\n", | |
| "\n", | |
| "def multi_normal(mu, sigma, size=None):\n", | |
| " assert len(mu) == len(sigma)\n", | |
| " return np.stack([\n", | |
| " sp.norm(loc=mu[i], scale=sigma[i]).rvs(size=size)\n", | |
| " for i in range(len(mu))\n", | |
| " ], axis=1)\n", | |
| "\n", | |
| "def plot_dist(samples, **kwargs):\n", | |
| " for i in range(samples.shape[1]):\n", | |
| " plt.hist(samples[:, i], label=f'arm {i}', **kwargs)\n", | |
| " plt.legend()\n", | |
| " \n", | |
| "def draw_rewards(mu, sigma, size=None):\n", | |
| " return (multi_normal(mu, sigma, size=size) > 0.5).astype(int)\n", | |
| "\n", | |
| "plot_dist(multi_normal(mu, sigma, size=10_000), bins=100, alpha=0.2)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([ 3., 3., 4., 10., 10.]), array([0, 3, 3, 3, 3, 3, 3, 3, 3, 3]), 9)" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "def simulate_game(n, mu, sigma, strategy='maxwin'):\n", | |
| " \n", | |
| " rewards = np.zeros(shape=(n, len(mu)))\n", | |
| " choices = []\n", | |
| " collected_rewards = 0\n", | |
| " \n", | |
| " for i in range(n):\n", | |
| " \n", | |
| " if i == 0 or strategy == 'random':\n", | |
| " choice = np.random.randint(len(mu))\n", | |
| " if strategy == 'maxwin':\n", | |
| " choice = np.argmax(np.sum(rewards[:i, :], axis=0))\n", | |
| " elif strategy == 'meansd':\n", | |
| " choice = np.argmax(np.mean(rewards[:i, :], axis=0) / np.std(rewards[:i, :], axis=0))\n", | |
| " \n", | |
| " result = draw_rewards(mu, sigma, 1)[0, :]\n", | |
| " \n", | |
| " rewards[i, :] = result\n", | |
| " collected_rewards += result[choice]\n", | |
| " choices.append(choice)\n", | |
| " \n", | |
| " return np.sum(rewards, axis=0), np.array(choices), collected_rewards \n", | |
| "\n", | |
| "\n", | |
| "simulate_game(10, mu, sigma)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:18<00:00, 1.47it/s]\n", | |
| " 0%| | 0/1000 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "random: 22.556 vs 23.031 vs 29.829\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:13<00:00, 1.49it/s]\n", | |
| " 0%| | 0/1000 [00:00<?, ?it/s]/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.\n", | |
| " out=out, **kwargs)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " ret, rcount, out=ret, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:217: RuntimeWarning: Degrees of freedom <= 0 for slice\n", | |
| " keepdims=keepdims)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:186: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " arrmean, rcount, out=arrmean, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:207: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " ret, rcount, out=ret, casting='unsafe', subok=False)\n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n", | |
| " \n", | |
| "/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in true_divide\n", | |
| " \n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "maxwin: 23.251 vs 22.885 vs 29.782\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "100%|██████████| 1000/1000 [11:33<00:00, 1.44it/s]" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "meansd: 23.166 vs 22.885 vs 29.782\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "R = 1000\n", | |
| "N = 50\n", | |
| "k = 20\n", | |
| "\n", | |
| "for strategy in ['random', 'maxwin', 'meansd']:\n", | |
| " \n", | |
| " rewards = 0 # actual rewards given the choosen strategy\n", | |
| " oracle = 0 # best strategy if we hade an oracle, that knew the future outcome\n", | |
| " bestmu = 0 # choose arm based on the known, true expected reward\n", | |
| " \n", | |
| " for i in trange(R):\n", | |
| " np.random.seed(i)\n", | |
| " m = sp.uniform(-10, 2).rvs(k)\n", | |
| " s = sp.uniform(20, 200).rvs(k)\n", | |
| " \n", | |
| " total_rewards, _, collected_rewards = simulate_game(N, m, s, strategy=strategy)\n", | |
| " \n", | |
| " rewards += collected_rewards\n", | |
| " oracle += np.max(total_rewards)\n", | |
| " bestmu += total_rewards[np.argmax(m)]\n", | |
| " \n", | |
| " print(f'{strategy}: {rewards / R} vs {bestmu / R} vs {oracle / R}')" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.6" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 4 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment