Skip to content

Instantly share code, notes, and snippets.

@twolodzko
Created May 28, 2020 11:14
Show Gist options
  • Select an option

  • Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.

Select an option

Save twolodzko/4bd10a490d49b9e4d31fd9a38eb089f6 to your computer and use it in GitHub Desktop.
Simplified multiarmed bandit simulation for https://stats.stackexchange.com/q/468724/35989
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from tqdm import trange\n",
"import numpy as np\n",
"import scipy.stats as sp\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAdwUlEQVR4nO3df4wU95nn8fdjDgYT8PFr4IAZB/COOXsMHh8T4ijJmVsbM+vkjK2IE5H8S2eJ/GFLWcmJbG+C19HGku+y2dWhCyHkYnmy8sVB2uWMoiQ2ZtdrJyEmQwyYMWCGAcMw7Ay/HDz2ePDAc390DTRD9/SPqerq7vq8pFZ3V1dVP1/ofvo73/rWU+buiIhIMlwVdwAiIlI6SvoiIgmipC8ikiBK+iIiCaKkLyKSIEr6IiIJkjPpm9l4M9tuZrvMrN3MvhMsf9rMjpnZzuB2V9o2T5pZh5ntN7PlUTZARETyZ7nm6ZuZAZ9y9z4zGwv8Bvg60AL0ufvfDlv/RuBnwBJgNvAqcL27n48gfhERKcC/y7WCp34V+oKnY4PbSL8UK4AX3X0AOGRmHaR+ALZl22D69Ok+d+7cfGMWERFgx44dJ929tpBtciZ9ADMbA+wA/gz4gbu/aWZ/ATxqZg8AbcBj7n4GmAP8Pm3zrmBZVnPnzqWtra2QuEVEEs/M3it0m7wO5Lr7eXdvAuqAJWZ2E/BD4DqgCTgOfH8ojky7yBDsajNrM7O2EydOFBq3iIgUoaDZO+7+PvAa0OLuPcGPwQXgx6SGcCDVs69P26wO6M6wrw3u3uzuzbW1Bf11IiIiRcpn9k6tmU0OHl8N3AHsM7NZaavdC+wJHm8GVplZjZnNAxqA7eGGLSIixchnTH8W0BqM618FbHT3X5jZP5hZE6mhm8PA1wDcvd3MNgLvAIPAI5q5IyKj8cknn9DV1cXHH38cdyixGD9+PHV1dYwdO3bU+8o5ZbMUmpubXQdyRSSbQ4cOMWnSJKZNm0ZqFnlyuDunTp3igw8+YN68eZe9ZmY73L25kP3pjFwRKXsff/xxIhM+gJkxbdq00P7KUdIXkYqQxIQ/JMy2K+mLiMSstbWVhoYGGhoaaG1tjfS98jo5S0SknLzd9adQ97ew7t+Hsp/z588zZsyYgrY5ffo03/nOd2hra8PMWLx4MXfffTdTpkwJJabh1NMXEcnDPffcw+LFi2lsbGTDhg0Xl0+cOJGnnnqKz372s2zbto2JEyfy+OOPs3jxYu644w62b9/O0qVLmT9/Pps3b75ivy+//DLLli1j6tSpTJkyhWXLlvHrX/86snYo6YuI5OG5555jx44dtLW1sXbtWk6dOgXAhx9+yE033cSbb77JF77wBT788EOWLl3Kjh07mDRpEt/+9rfZsmULmzZt4qmnnrpiv8eOHaO+/tL5rHV1dRw7diyydmh4R0QkD2vXrmXTpk0AHD16lAMHDjBt2jTGjBnDV77ylYvrjRs3jpaWFgAWLlxITU0NY8eOZeHChRw+fPiK/WaaNh/lQWv19EVEcnjttdd49dVX2bZtG7t27eKWW265OIVy/Pjxl43jjx079mLSvuqqq6ipqbn4eHBw8Ip919XVcfTo0YvPu7q6mD17dmRtUdIXqUK9752NO4Sq8qc//YkpU6YwYcIE9u3bx+9///vcG+Vp+fLlvPLKK5w5c4YzZ87wyiuvsHx5dNee0vCOiEgOLS0trF+/nkWLFrFgwQJuvfXW0PY9depU1qxZw2c+8xkAnnrqKaZOnRra/odTGQaRkLWfaqdxWmMs79373llmfPqai/fVYu/evdxwww1xhxGrTP8GKsMgIiIjUtIXEUkQJX0RkQRR0hfJIOzT/EXKhZK+SA76AZBqoimbIjmMP7kbrpqYejL7lniDERkl9fRFqphO0qoMLS0tTJ48mS9/+cuRv5d6+iJSebrfCnd/If0FV0xpZYBvfvObfPTRR/zoRz8KJY6RqKcvkkW5jOX372mnf0973GEkXlSllQFuv/12Jk2aVJJ2KOmLhKj9lJJztYqqtHKpaXhHZMhlQwbzYwtDylNUpZVLLWdP38zGm9l2M9tlZu1m9p1g+VQz22JmB4L7KWnbPGlmHWa238yiKxcnErFyGeKReEVZWrnU8hneGQD+3N1vBpqAFjO7FXgC2OruDcDW4DlmdiOwCmgEWoB1Zlb4kQ2RStL9Vup2Yn/ckVykmTvhibK0cqnlTPqe0hc8HRvcHFgBDF22vRW4J3i8AnjR3Qfc/RDQASwJNWoRkRJqaWlhcHCQRYsWsWbNmlBLKwN88YtfZOXKlWzdupW6ujpefvnlUPefLq8x/aCnvgP4M+AH7v6mmc109+MA7n7czGYEq88B0n8Gu4JlIiLhKPFJcjU1NfzqV7/K+FpfX1/W508//fSI6w554403RhdgAfKavePu5929CagDlpjZTSOsnunijlcU7Tez1WbWZmZtJ06cyC9aEREZlYKmbLr7+8BrpMbqe8xsFkBw3xus1gXUp21WB3Rn2NcGd2929+ba2toiQhcJ34HezD2xjIbG8UM+UUhz8iVK+czeqTWzycHjq4E7gH3AZuDBYLUHgZeCx5uBVWZWY2bzgAZge9iBi5QzzdeXcpXPmP4soDUY178K2OjuvzCzbcBGM3sYOAKsBHD3djPbCLwDDAKPuPv5aMIXEZFC5Ez67r4buOKoibufAm7Pss0zwDOjjk4kLOlDMBVaKbN/TztX31TctXfTr5lbbdfPlcKoDIOISIIo6YvkIZ8DvO1nD6UenNifuoVdCRId5K1GO3fu5HOf+xyNjY0sWrSIn//855G+n2rvSLJFkJglemEfKG+cVtyw2XDFlFaeMGECP/3pT2loaKC7u5vFixezfPlyJk+eHEpMw6mnL1IGCunB5yq1rPIL0YiqtPL1119PQ0MDALNnz2bGjBlEee6Skr5ImVDd/PJWitLK27dv59y5c1x33XWRtUPDOyIhuDieHxPNyIle1KWVjx8/zv33309raytXXRVdf1xJX0Qkh/TSyhMmTGDp0qWhllY+e/YsX/rSl/jud78bejG34ZT0RaJUBecHSLSllc+dO8e9997LAw88wMqVK0PbbzYa0xfJU0F1ebIYPmafaxxfY/zlIcrSyhs3buT111/n+eefp6mpiaamJnbu3Bna/odTT1+kEJriWRbCmmKZryhLK993333cd999ow8yT+rpi5SYZulInJT0RUQSRMM7Up00DCOSkZK+SCCMA7VhKWb4J9uZuDpDV9Ip6Uvy5Pgr4EBvH0zP/lrDjImXLYv7xCyRQijpi3BlL3/8yd0xRSISLR3IFRGJ0XvvvcfixYtpamqisbGR9evXR/p+6umLJEi1jO+HPeW12CuSDVdMaeVZs2bxu9/9jpqaGvr6+rjpppu4++67mT17digxDaeevohIHqIqrTxu3LiL9XkGBga4cOFCpO1Q0hcpUDnN8pHSibK08tGjR1m0aBH19fU8/vjjkfXyQUlfRCQva9eu5eabb+bWW2+9WFoZyFla+bbbbstZWrm+vp7du3fT0dFBa2srPT09kbVDSV9kFEKZrnly36WblKX00sq7du3illtuCbW08pDZs2fT2NjIG2+8EVFL8kj6ZlZvZv9iZnvNrN3Mvh4sf9rMjpnZzuB2V9o2T5pZh5ntN7PlkUUvIlICUZZW7urqor+/H4AzZ87w29/+lgULFoS2/+Hymb0zCDzm7n80s0nADjPbErz29+7+t+krm9mNwCqgEZgNvGpm17v7+TADF4lTppO0cup+C052ph5P/4+Z10nv7WdbR0qupaWF9evXs2jRIhYsWBBqaeW9e/fy2GOPYWa4O9/4xjdYuHBhaPsfLmfSd/fjwPHg8QdmtheYM8ImK4AX3X0AOGRmHcASYFsI8YpI4Og/76T+z5uqZhpmIcKaYpmvKEsrL1u2jN27S3cyYEFj+mY2F7gFeDNY9KiZ7Taz58xsSrBsDnA0bbMuRv6REEmek/uCnv8I4/ga65cI5J30zWwi8I/AX7r7WeCHwHVAE6m/BL4/tGqGzT3D/labWZuZtZ04caLgwEUkdZLSwMHOuMOQCpJX0jezsaQS/gvu/k8A7t7j7ufd/QLwY1JDOJDq2denbV4HdA/fp7tvcPdmd2+ura0dTRskybrfunRLMCV+yVc+s3cM+Amw193/Lm35rLTV7gX2BI83A6vMrMbM5gENwPbwQhYJV1QnW5Vz9c0kHgeQlHxm73weuB9428yGrtb7V8BXzayJ1NDNYeBrAO7ebmYbgXdIzfx5RDN3pCRK3Ns/0NsH40v6liKjls/snd+QeZz+lyNs8wzwzCjiEpGQqXcvoDNyRUTKwtmzZ5kzZw6PPvpopO+j0soiFehk7ydxhxCrsP9qmfHpa0LZTzGllYesWbOG2267LZQ4RqKevohIHqIqrQywY8cOenp6uPPOOyNvh5K+SAn0v9tJ/7uaVlnJoiqtfOHCBR577DG+973vlaQdGt4REcnD2rVr2bRpE8DF0srTpk3LWVq5pqZmxNLK69at46677qK+vv6K16KgpC9SwZI+tl8q6aWVJ0yYwNKlS0Mrrbxt2zbeeOMN1q1bR19fH+fOnWPixIk8++yzkbRFSV9EJIcoSyu/8MILFx8///zztLW1RZbwQWP6IiI5tbS0MDg4yKJFi1izZk2opZVLTT19kSId/OgI9eOvLnp7HdgtXlhTLPMVZWnldA899BAPPfRQUTHmS0lfKk/Ci6tlM1R0rea6+Zc9Fkmn4R2RUTp6uj/uEC6jipsyEiV9kYRSLZ5kUtIXkYrgfsW1mBIjzLYr6YtI2Rs/fjynTp1KZOJ3d06dOsX48eHU8daBXBEpe3V1dXR1dZHUS6uOHz+eurq6UPalpC8iZW/s2LHMmzcv7jCqgoZ3RCKm+fhSTpT0RUQSRElfJARXzNXvOBbdm53cB+8fTt1ECqSkL1JhTp4u7spMIqADuSKjMuLZuB3H6J+YvCmGUt7U05dEO9A7cgGsaqezcpMnZ9I3s3oz+xcz22tm7Wb29WD5VDPbYmYHgvspads8aWYdZrbfzJZH2QCRRDi579JNZBTy6ekPAo+5+w3ArcAjZnYj8ASw1d0bgK3Bc4LXVgGNQAuwzsw0CCkiUgZyJn13P+7ufwwefwDsBeYAK4DWYLVW4J7g8QrgRXcfcPdDQAewJOzAJQG637p0qwRRztgRCUlBY/pmNhe4BXgTmOnuxyH1wwDMCFabAxxN26wrWCYiIjHLO+mb2UTgH4G/dPeRjv5YhmVXTGEws9Vm1mZmbUmtpyEiUmp5JX0zG0sq4b/g7v8ULO4xs1nB67OA3mB5F1Cftnkd0D18n+6+wd2b3b25tra22PhFilaKmTudfVd89EVilc/sHQN+Aux1979Le2kz8GDw+EHgpbTlq8ysxszmAQ3A9vBCFhGRYuVzctbngfuBt81sZ7Dsr4BngY1m9jBwBFgJ4O7tZrYReIfUzJ9H3P186JGLyOWlGCbPveLlgYOduk6uXCZn0nf335B5nB7g9izbPAM8M4q4REQkAjojVxIp6WfiSnKp9o5UhhDn6ivhS5Kppy8ikiDq6Ut5qZSzb0UqlHr6IhUiZx39oQurFHhxFVXaTBYlfRGRBNHwjkgRDn50JPMLKromZU49fZGQ9JwdoOfsQNxhiIxIPX2RClLU9XFP7L/0uHZBeMFIRVJPX0QkQdTTFxmlcYePxx2CSN7U0xepcgNHjjFwZOQDzJq2mRxK+iIJNHCwM+4QJCZK+iIiCaKkLyKSIEr6IiIJotk7ItWowPo7khzq6YuIJIiSvohIgmh4R+KnGvoiJaOevkhC6AQtgTySvpk9Z2a9ZrYnbdnTZnbMzHYGt7vSXnvSzDrMbL+ZLY8qcJFK0dnXHXcIIhfl09N/HmjJsPzv3b0puP0SwMxuBFYBjcE268ysiLKAIiIShZxJ391fB07nub8VwIvuPuDuh4AOYMko4hMJ1YHevrhDKEpRJZVFMhjNmP6jZrY7GP6ZEiybAxxNW6crWCYiZWDgYKfq7iRcsUn/h8B1QBNwHPh+sNwyrOuZdmBmq82szczaTpw4UWQYIvEYd/i4SipLRSoq6bt7j7ufd/cLwI+5NITTBdSnrVoHZDyK5e4b3L3Z3Ztra2uLCUNERApUVNI3s1lpT+8Fhmb2bAZWmVmNmc0DGoDtowtRqkb3W5duVaysr5P7/uFLt/TLKEpi5Dw5y8x+BiwFpptZF/DXwFIzayI1dHMY+BqAu7eb2UbgHWAQeMTdz0cTukg8Dn50hHHB42oY4hk42EnNdfPjDkNKJGfSd/evZlj8kxHWfwZ4ZjRBiYhINFSGQURSTuyHsUFKmH1LvLFIZJT0JVpVOH5fDUM6EJRlODs27jCkxJT0JR5V+GMgUglUcE1EJEGU9EXkMr3dg3GHIBFS0heRi5Twq5+SvohIgijpi4gkiGbviMjI0mdaaf5+xVNPX0QkQZT0RUQSRMM7InnyzvfyXrfn7AAzr6mJMBqR4qinL5KHQhK+SDlTT18kyd4/nLo/8UmsYUjpqKcviRD3BdE7+zJeQE6k5JT0JTHiTvwi5UBJX0QkQTSmLyJXUunrqqWkL1UvzGGdro//LbR9VSSdnVvxlPRFytjJ02OYPvV85O8zcOTYxcc1186J/P0kPkr6Ej4NDYiULR3IFSlzJ0+PiTsEqSI5k76ZPWdmvWa2J23ZVDPbYmYHgvspaa89aWYdZrbfzJZHFbiIiBQun57+80DLsGVPAFvdvQHYGjzHzG4EVgGNwTbrzEzdFImN5uaLXC5n0nf314HTwxavAFqDx63APWnLX3T3AXc/BHQAS0KKVaSk4q63E+ewji6bWL2KHdOf6e7HAYL7GcHyOcDRtPW6gmUiidNzdiDuEESuEPbsHcuwzDOuaLYaWA1w7bXXhhyGSDji7u2LhK3Ynn6Pmc0CCO57g+VdQH3aenVAxkpT7r7B3Zvdvbm2trbIMEREpBDFJv3NwIPB4weBl9KWrzKzGjObBzQA20cXooiIhCXn8I6Z/QxYCkw3sy7gr4FngY1m9jBwBFgJ4O7tZrYReAcYBB5x9+hPJ5T46YQskYqQM+m7+1ezvHR7lvWfAZ4ZTVAiIhINnZErIpIgSvoiIgmipC9SIrpkopQDVdmU4ungbU49ZweYeU1N3GEUpbd7kBmzlSKqjXr6ImWoHCprptfYl+qhpC8SMZVjkHKipC8ikiBK+lKVVFJZJDMdpZHcdDFsyUSfi4qkpC8il9EB3Oqm4R0RkQRRT18KU6Vz84fq5tv8T8ccSXlM15TqpaQvMkycF04pt4Tf/24nAFdfPz/mSCQsGt4REUkQ9fTlkiqbjVHMtE1dHvFyJ0+PYfpUXRKjmqinLyKSIOrpS2ZVesBWJOnU0xfJU9fH/1b0tj1nB1SDR8qCevoiEq4qOzZUbZT0JfFKffC2s6+b+RNnl/Q9w9T/bueVUzjzGQ7Uj0FZUNIXKaFKvHrWydNjqC92Yx0bKjsa0xeRnIZO0pLKN6qevpkdBj4AzgOD7t5sZlOBnwNzgcPAf3P3M6MLU6T6ldvZuFKdwujp/xd3b3L35uD5E8BWd28AtgbPRUSkDEQxvLMCaA0etwL3RPAeIlnpAioi2Y026TvwipntMLPVwbKZ7n4cILifMcr3EKka2ebqa2hHSmW0s3c+7+7dZjYD2GJm+/LdMPiRWA1w7bXXjjIMKZpmV4gkyqh6+u7eHdz3ApuAJUCPmc0CCO57s2y7wd2b3b25trZ2NGGIiEieik76ZvYpM5s09Bi4E9gDbAYeDFZ7EHhptEGKRME731NVTUmc0QzvzAQ2mdnQfv6vu//azP4AbDSzh4EjwMrRhykicdNc/epQdNJ3907g5gzLTwG3jyYoiVgVj+PnmrlT7GURR1NsrRoMHWhWbf3KpzIMSVHFiV5E8qcyDFI1ND9fJDclfZEcohjaOXq6P/R9llL/u50a469QGt4RkdJTmeXYqKcvVaHShnaSfmBY4qOkL4mk+fnFU8mIyqbhnWqmGTtlqZqulZvxKlpS1pT0K93wxK7x0Ypy8vSYqpr7PnRwt+gfAo31R05JX6qehnJELlHSrzYa0rmMEn7paApnZVDSF4lZ4g+MqqNSUpq9IyJ5S/wPVBVQ0peKF+Uc/Sjn07cfHoxs3yLZaHhHJEY9ZweYeU1N3GEULH3WUUnG8jWrJzRK+uVMH3SpYKOevimRUNKvFDrYVZByn7XT/8HVpHfwK7XHH6l8PvPqGBVMY/oiMeoZOFnRZ+jqwG7lUU9fKlpUB3FLURCtZ+Bk5O8RpXwTfuilGvRX76go6UtV8M73Cr4EYqn1f3A1AFdPquxa+vmI/UQtDftkpaQfF30oRSQGSvrlQH+u5uVAbx8NMyZefDxcOR+8Herlj6SSD+aWdeE4dbAuowO5UlGiTvalGMsf/gNQTWP7ucb504d9YhkC6n7r0i2hIuvpm1kL8L+AMcD/cfdno3ovqS7DD/yl9/CHnkN59+yLVcm9/ZFkS/ahzuVPcCIvRCRJ38zGAD8AlgFdwB/MbLO7vxPF+5UdffhC0/9uJ0yewYHevssO1kaZ8OO8lGHPwElmMqfik3/6cE8+Qz/Df+hLcnGWfL+nVTYkFFVPfwnQ4e6dAGb2IrACqMykn8+YoBJ96A709lFH5fbo+z+4mqsn9ec1ng+Xhnkyzduvph+Bipbte15BPwxRJf05wNG0513AZyN6r/wSbvp/SrYkruReEkN/0ndNnkHd+71A6s/7/nc7OXo6NZ2x/t1O/HT/ZR8iiOYHIN+e/VASz7QsW4IvNOGn233iGDNrpl9aJ+3HoBx/AIbG89PvC0n0w8f4058P9frT/wLI9HokwvrOl8kPRlRJ3zIs88tWMFsNrA6e9pnZ/ohiAZgOVPbRskuqqS1QXe1RW8pTNbUFLm9PwSenRJX0u4D6tOd1QHf6Cu6+AdgQ0ftfxsza3L25FO8VtWpqC1RXe9SW8lRNbYHRtyeqKZt/ABrMbJ6ZjQNWAZsjei8REclTJD19dx80s0eBl0lN2XzO3dujeC8REclfZPP03f2XwC+j2n+BSjKMVCLV1BaorvaoLeWpmtoCo2yPuXvutUREpCqoDIOISIJUVdI3s5Vm1m5mF8ysOW35XDPrN7OdwW192muLzextM+sws7Vmlmm6aSyytSd47ckg5v1mtjxtedm2Z4iZPW1mx9L+P+5Key1ju8qZmbUE8XaY2RNxx1MoMzscfGZ2mllbsGyqmW0xswPB/ZS448zGzJ4zs14z25O2LGv85fwZy9KWcL8v7l41N+AGYAHwGtCctnwusCfLNtuBz5E6t+BXwF/E3Y482nMjsAuoAeYBB4Ex5d6etPifBr6RYXnWdpXrjdREhYPAfGBcEP+NccdVYBsOA9OHLfufwBPB4yeA/xF3nCPE/5+B/5T+Hc8Wf7l/xrK0JdTvS1X19N19r7vnfZKXmc0CrnH3bZ76V/wpcE9kARZohPasAF509wF3PwR0AEvKvT15yNiumGPK5WLJEXc/BwyVHKl0K4DW4HErZfw5cvfXgdPDFmeLv6w/Y1nakk1RbamqpJ/DPDN7y8z+1cy+GCybQ+pEsiFdwbJyl6nMxRwqqz2Pmtnu4M/ZoT+9s7WrnFVizMM58IqZ7QjOlAeY6e7HAYL7GbFFV5xs8Vfq/1do35eKu4iKmb0K/IcML33L3V/Kstlx4Fp3P2Vmi4H/Z2aN5FEuImpFtidb3LG3Z8hI7QJ+CPwNqdj+Bvg+8N8po/gLUIkxD/d5d+82sxnAFjPbF3dAEarE/69Qvy8Vl/Td/Y4ithkABoLHO8zsIHA9qV/GurRVrygXEbVi2kP2Mhext2dIvu0ysx8Dvwie5izfUYYqMebLuHt3cN9rZptIDRH0mNksdz8eDBv2xhpk4bLFX3H/X+7eM/Q4jO9LIoZ3zKw2qPGPmc0HGoDO4M++D8zs1mCWywNAtt51OdkMrDKzGjObR6o92yulPcGXcMi9wNBMhYztKnV8BarokiNm9ikzmzT0GLiT1P/HZuDBYLUHKcPPUQ7Z4q+4z1jo35e4j1aHfOT7XlK/fgNAD/BysPwrQDupI91/BP5r2jbNwT/iQeB/E5ywVg63bO0JXvtWEPN+0mbolHN70mL8B+BtYHfwwZ2Vq13lfAPuAt4N4v5W3PEUGPv84HuxK/iOfCtYPg3YChwI7qfGHesIbfgZqSHcT4Lvy8MjxV/On7EsbQn1+6IzckVEEiQRwzsiIpKipC8ikiBK+iIiCaKkLyKSIEr6IiIJoqQvIpIgSvoiIgmipC8ikiD/H2UEAoTNazCCAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"mu = [-10, -5, 1, 20, 35]\n",
"sigma = [10, 40, 5, 20, 10]\n",
"\n",
"def multi_normal(mu, sigma, size=None):\n",
" assert len(mu) == len(sigma)\n",
" return np.stack([\n",
" sp.norm(loc=mu[i], scale=sigma[i]).rvs(size=size)\n",
" for i in range(len(mu))\n",
" ], axis=1)\n",
"\n",
"def plot_dist(samples, **kwargs):\n",
" for i in range(samples.shape[1]):\n",
" plt.hist(samples[:, i], label=f'arm {i}', **kwargs)\n",
" plt.legend()\n",
" \n",
"def draw_rewards(mu, sigma, size=None):\n",
" return (multi_normal(mu, sigma, size=size) > 0.5).astype(int)\n",
"\n",
"plot_dist(multi_normal(mu, sigma, size=10_000), bins=100, alpha=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([ 3., 3., 4., 10., 10.]), array([0, 3, 3, 3, 3, 3, 3, 3, 3, 3]), 9)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def simulate_game(n, mu, sigma, strategy='maxwin'):\n",
" \n",
" rewards = np.zeros(shape=(n, len(mu)))\n",
" choices = []\n",
" collected_rewards = 0\n",
" \n",
" for i in range(n):\n",
" \n",
" if i == 0 or strategy == 'random':\n",
" choice = np.random.randint(len(mu))\n",
" if strategy == 'maxwin':\n",
" choice = np.argmax(np.sum(rewards[:i, :], axis=0))\n",
" elif strategy == 'meansd':\n",
" choice = np.argmax(np.mean(rewards[:i, :], axis=0) / np.std(rewards[:i, :], axis=0))\n",
" \n",
" result = draw_rewards(mu, sigma, 1)[0, :]\n",
" \n",
" rewards[i, :] = result\n",
" collected_rewards += result[choice]\n",
" choices.append(choice)\n",
" \n",
" return np.sum(rewards, axis=0), np.array(choices), collected_rewards \n",
"\n",
"\n",
"simulate_game(10, mu, sigma)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:18<00:00, 1.47it/s]\n",
" 0%| | 0/1000 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"random: 22.556 vs 23.031 vs 29.829\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:13<00:00, 1.49it/s]\n",
" 0%| | 0/1000 [00:00<?, ?it/s]/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.\n",
" out=out, **kwargs)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:154: RuntimeWarning: invalid value encountered in true_divide\n",
" ret, rcount, out=ret, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:217: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
" keepdims=keepdims)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:186: RuntimeWarning: invalid value encountered in true_divide\n",
" arrmean, rcount, out=arrmean, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:207: RuntimeWarning: invalid value encountered in true_divide\n",
" ret, rcount, out=ret, casting='unsafe', subok=False)\n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: divide by zero encountered in true_divide\n",
" \n",
"/home/tymek/.miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:14: RuntimeWarning: invalid value encountered in true_divide\n",
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"maxwin: 23.251 vs 22.885 vs 29.782\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [11:33<00:00, 1.44it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"meansd: 23.166 vs 22.885 vs 29.782\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"R = 1000\n",
"N = 50\n",
"k = 20\n",
"\n",
"for strategy in ['random', 'maxwin', 'meansd']:\n",
" \n",
" rewards = 0 # actual rewards given the choosen strategy\n",
" oracle = 0 # best strategy if we hade an oracle, that knew the future outcome\n",
" bestmu = 0 # choose arm based on the known, true expected reward\n",
" \n",
" for i in trange(R):\n",
" np.random.seed(i)\n",
" m = sp.uniform(-10, 2).rvs(k)\n",
" s = sp.uniform(20, 200).rvs(k)\n",
" \n",
" total_rewards, _, collected_rewards = simulate_game(N, m, s, strategy=strategy)\n",
" \n",
" rewards += collected_rewards\n",
" oracle += np.max(total_rewards)\n",
" bestmu += total_rewards[np.argmax(m)]\n",
" \n",
" print(f'{strategy}: {rewards / R} vs {bestmu / R} vs {oracle / R}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment