Skip to content

Instantly share code, notes, and snippets.

@khanrc
Created January 30, 2018 14:12
Show Gist options
  • Save khanrc/fe36cd1e7e60f61c90b5a6d484fadb7a to your computer and use it in GitHub Desktop.
Save khanrc/fe36cd1e7e60f61c90b5a6d484fadb7a to your computer and use it in GitHub Desktop.
Gradient bandit algorithm implementation for 10-armed bandit testbed
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# True mean reward of each of the 10 arms, entered approximately as read off the example figure.\n",
"mean = np.array([0.2, -0.8, 1.5, 0.4, 1.2, -1.6, -0.2, -1.1, 0.8, -0.4])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.2, -0.8, 1.5, 0.4, 1.2, -1.6, -0.2, -1.1, 0.8, -0.4])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Display the per-arm mean rewards.\n",
"mean"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-1.51066622, 0.27000444, 2.15270565, -1.05204115, 1.0553461 ,\n",
" -1.86286236, -0.72597318, 0.73864357, 2.73526509, -1.37458381])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Example draw: one reward per arm, each arm's mean plus standard-normal noise.\n",
"np.random.randn(10) + mean"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def get_reward():\n",
"    \"\"\"Sample one reward for every arm of the bandit.\n",
"\n",
"    Returns:\n",
"        A 1-D array with one entry per element of the module-level `mean`;\n",
"        entry i is mean[i] plus an independent standard-normal draw.\n",
"    \"\"\"\n",
"    # Size the noise vector from `mean` instead of repeating the literal 10,\n",
"    # so the number of arms is defined in one place only.\n",
"    return np.random.randn(len(mean)) + mean"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def softmax(H):\n",
"    \"\"\"Turn a vector of preferences H into a probability distribution.\n",
"\n",
"    The maximum of H is subtracted before exponentiating, so the largest\n",
"    exponent passed to np.exp is 0 and the call cannot overflow; the shift\n",
"    cancels out in the normalisation.\n",
"    \"\"\"\n",
"    weights = np.exp(H - np.max(H))\n",
"    return weights / weights.sum()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def gradient_bandit(N, alpha=0.1, n_arms=10):\n",
"    \"\"\"Run the gradient bandit algorithm against get_reward().\n",
"\n",
"    Args:\n",
"        N: number of time steps to play.\n",
"        alpha: step size of the preference update (0.1 is an arbitrary,\n",
"            reasonable-looking default).\n",
"        n_arms: number of actions; must equal the length of the array\n",
"            returned by get_reward().\n",
"\n",
"    Returns:\n",
"        (policy, r_hist): the softmax of the final preferences, and the\n",
"        list of rewards received, one per step.\n",
"    \"\"\"\n",
"    H = np.zeros(n_arms)  # action preferences\n",
"    r_hist = []\n",
"    r_total = 0.0  # running sum of rewards: O(1) baseline update per step\n",
"    for _ in range(N):  # exactly N steps (range(1, N) played only N - 1)\n",
"        policy = softmax(H)  # policy pi\n",
"        # Sample the action A_t from the current policy.\n",
"        a = np.random.choice(n_arms, p=policy)\n",
"        r = get_reward()[a]  # R_t: reward of the chosen action\n",
"        r_hist.append(r)\n",
"        r_total += r\n",
"        # Baseline: average of all rewards received so far, including R_t.\n",
"        avg_r = r_total / len(r_hist)\n",
"        # H(a) <- H(a) + alpha * (R_t - baseline) * (1[a == A_t] - pi(a)):\n",
"        # every action moves by -pi(a), then the chosen one gets the extra +1.\n",
"        step = alpha * (r - avg_r)\n",
"        H -= step * policy\n",
"        H[a] += step\n",
"\n",
"    return softmax(H), r_hist"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Train with N=10000 and keep the final policy plus the per-step reward history.\n",
"opt_policy, r_hist = gradient_bandit(10000)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f3e4a7bff90>]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADj1JREFUeJzt3X+s3fVdx/Hnay247ofD2GsibVmb2KHN1EBOGEqii2MBpmmXaBSS+WMh6z8yUUkNqMEF/5o1i5rhtJlz2ZxDREJutNolDrPECOntOmFtrbnpNtoLhjtG0bgqFN/+cc8dp3dtz7nlXL7nfu7zkRDu93s+vd93vtBnvz3fc+5JVSFJasvruh5AkjR+xl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalBxl2SGmTcJalB67s68MaNG2vr1q1dHV6SVqVDhw59vaqmhq3rLO5bt25lZmamq8NL0qqU5GujrPNpGUlqkHGXpAYZd0lqkHGXpAYNjXuSTyR5NsmXL/B4kvxRktkkTyS5dvxjSpKWY5RXy3wS+CjwqQs8fguwvf/PO4CP9f+tFfbI4Tn2HjjO06fPcOUVG9hz09W895pNXY8laQIMvXKvqi8A37jIkl3Ap2rBY8AVSb53XAPq/B45PMc9Dz/J3OkzFDB3+gz3PPwkjxye63o0SRNgHM+5bwJODmyf6u/TCtp74DhnXnr5nH1nXnqZvQeOdzSRpEnymt5QTbI7yUySmfn5+dfy0M15+vSZZe2XtLaMI+5zwJaB7c39fd+mqvZVVa+qelNTQ989q4u48ooNy9ovaW0ZR9yngV/ov2rmeuCFqnpmDN9XF7HnpqvZcNm6c/ZtuGwde266uqOJJE2Soa+WSfJZ4J3AxiSngN8BLgOoqj8B9gPvAWaBbwLvX6lh9YrFV8X4ahlJ55Oq6uTAvV6v/MFhkrQ8SQ5VVW/YOt+hKkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNMu6S1CDjLkkNGinuSW5OcjzJbJK7z/P4VUkeTXI4yRNJ3jP+USVJoxoa9yTrgPuBW4AdwG1JdixZ9tvAg1V1DXAr8MfjHlSSNLpRrtyvA2ar6kRVvQg8AOxasqaA7+x//Rbg6fGNKElarvUjrNkEnBzYPgW8Y8maDwGfS/JB4I3AjWOZTpJ0ScZ1Q/U24JNVtRl4D/DpJN/2vZPsTjKTZGZ+fn5Mh5YkLTVK3OeALQPbm/v7Bt0OPAhQVf8CvB7YuPQbVdW+qupVVW9qaurSJpYkDTVK3A8C25NsS3I5CzdMp5eseQp4F0CSH2Ah7l6aS1JHhsa9qs4CdwAHgGMsvCrmSJL7kuzsL7sL+ECSfwU+C/xSVdVKDS1JurhRbqhSVfuB/Uv23Tvw9VHghvGOJkm6VL5DVZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUEjxT3JzUmOJ5lNcvcF1vxskqNJjiT5y/GOKUlajvXDFiRZB9wPvBs4BRxMMl1VRwfWbAfuAW6oqueTfM9KDSxJGm6UK/frgNmqOlFVLwIPALuWrPkAcH9VPQ9QVc+Od0xJ0nKMEvdNwMmB7VP9fYPeBrwtyT8neSzJzeMaUJK0fEOfllnG99kOvBPYDHwhyQ9W1enBRUl2A7sBrrrqqjEdWpK01ChX7nPAloHtzf19g04B01X1UlV9Bf
h3FmJ/jqraV1W9qupNTU1d6sySpCFGiftBYHuSbUkuB24FppeseYSFq3aSbGThaZoTY5xTkrQMQ+NeVWeBO4ADwDHgwao6kuS+JDv7yw4AzyU5CjwK7Kmq51ZqaEnSxaWqOjlwr9ermZmZTo4tSatVkkNV1Ru2zneoSlKDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KDRop7kpuTHE8ym+Tui6z76SSVpDe+ESVJyzU07knWAfcDtwA7gNuS7DjPujcDdwKPj3tISdLyjHLlfh0wW1UnqupF4AFg13nW/S7wYeB/xjifJOkSjBL3TcDJge1T/X3fkuRaYEtV/d0YZ5MkXaJXfUM1yeuAjwB3jbB2d5KZJDPz8/Ov9tCSpAsYJe5zwJaB7c39fYveDLwd+KckXwWuB6bPd1O1qvZVVa+qelNTU5c+tSTpokaJ+0Fge5JtSS4HbgWmFx+sqheqamNVba2qrcBjwM6qmlmRiSVJQw2Ne1WdBe4ADgDHgAer6kiS+5LsXOkBJUnLt36URVW1H9i/ZN+9F1j7zlc/liTp1fAdqpLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUIOMuSQ0y7pLUoJHinuTmJMeTzCa5+zyP/3qSo0meSPKPSd46/lElSaMaGvck64D7gVuAHcBtSXYsWXYY6FXVDwEPAb837kElSaMb5cr9OmC2qk5U1YvAA8CuwQVV9WhVfbO/+RiwebxjSpKWY5S4bwJODmyf6u+7kNuBvz/fA0l2J5lJMjM/Pz/6lJKkZRnrDdUk7wN6wN7zPV5V+6qqV1W9qampcR5akjRg/Qhr5oAtA9ub+/vOkeRG4LeAH6+q/x3PeJKkSzHKlftBYHuSbUkuB24FpgcXJLkG+FNgZ1U9O/4xJUnLMTTuVXUWuAM4ABwDHqyqI0nuS7Kzv2wv8Cbgr5N8Kcn0Bb6dJOk1MMrTMlTVfmD/kn33Dnx945jnkiS9Cr5DVZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUHGXZIaZNwlqUEjxT3JzUmOJ5lNcvd5Hv+OJH/Vf/zxJFvHPagkaXTrhy1Isg64H3g3cAo4mGS6qo4OLLsdeL6qvi/JrcCHgZ8b97CPHJ5j74HjPH36DFdesYE9N13Ne6/ZNO7DrJo5JsWknI9JmGMSZnAO54AR4g5cB8xW1QmAJA8Au4DBuO8CPtT/+iHgo0lSVTWuQR85PMc9Dz/JmZdeBmDu9BnuefhJgNf0P9KkzDEpJuV8TMIckzCDczjHolGeltkEnBzYPtXfd941VXUWeAH47nEMuGjvgePfOimLzrz0MnsPHB/nYVbNHJNiUs7HJMwxCTM4h3Msek1vqCbZnWQmycz8/Pyyfu3Tp88sa/9KmZQ5JsWknI9JmGMSZnAO51g0StzngC0D25v7+867Jsl64C3Ac0u/UVXtq6peVfWmpqaWNeiVV2xY1v6VMilzTIpJOR+TMMckzOAczrFolLgfBLYn2ZbkcuBWYHrJmmngF/tf/wzw+XE+3w6w56ar2XDZunP2bbhsHXtuunqch1k1c0yKSTkfkzDHJMzgHM6xaOgN1a
o6m+QO4ACwDvhEVR1Jch8wU1XTwJ8Bn04yC3yDhT8AxmrxhkPXd7wnZY5JMSnnYxLmmIQZnMM5FmXMF9gj6/V6NTMz08mxJWm1SnKoqnrD1vkOVUlqkHGXpAYZd0lqkHGXpAYZd0lqUGevlkkyD3ztEn/5RuDrYxxntfN8nMvz8QrPxblaOB9vraqh7wLtLO6vRpKZUV4KtFZ4Ps7l+XiF5+Jca+l8+LSMJDXIuEtSg1Zr3Pd1PcCE8Xycy/PxCs/FudbM+ViVz7lLki5utV65S5IuYtXFfdiHda8VSbYkeTTJ0SRHktzZ9UyTIMm6JIeT/G3Xs3QtyRVJHkryb0mOJfmRrmfqSpJf6/8++XKSzyZ5fdczrbRVFfeBD+u+BdgB3JZkR7dTdeYscFdV7QCuB355DZ+LQXcCx7oeYkL8IfAPVfX9wA+zRs9Lkk3ArwC9qno7Cz+6fOw/lnzSrKq4M/Bh3VX1IrD4Yd1rTlU9U1Vf7H/9Xyz8xl2bP1S+L8lm4CeBj3c9S9eSvAX4MRY+a4GqerGqTnc7VafWAxv6nxT3BuDpjudZcast7qN8WPeak2QrcA3weLeTdO4PgN8A/q/rQSbANmAe+PP+01QfT/LGrofqQlXNAb8PPAU8A7xQVZ/rdqqVt9ririWSvAn4G+BXq+o/u56nK0l+Cni2qg51PcuEWA9cC3ysqq4B/htYk/eoknwXC3/D3wZcCbwxyfu6nWrlrba4j/Jh3WtGkstYCPtnqurhrufp2A3AziRfZeHpup9I8hfdjtSpU8Cpqlr829xDLMR+LboR+EpVzVfVS8DDwI92PNOKW21xH+XDuteEJGHh+dRjVfWRrufpWlXdU1Wbq2orC/9ffL6qmr86u5Cq+g/gZJLFT19+F3C0w5G69BRwfZI39H/fvIs1cHN56AdkT5ILfVh3x2N15Qbg54Enk3ypv+83q2p/hzNpsnwQ+Ez/QugE8P6O5+lEVT2e5CHgiyy8yuwwa+Cdqr5DVZIatNqelpEkjcC4S1KDjLskNci4S1KDjLskNci4S1KDjLskNci4S1KD/h/wx9+wLz5bowAAAABJRU5ErkJggg==\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f3e4c816690>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Plot the final policy: one marker per arm (x = arm index, y = selection probability).\n",
"plt.plot(opt_policy, 'o')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.01%\t0.01%\t99.84%\t0.02%\t0.04%\t0.01%\t0.02%\t0.01%\t0.03%\t0.01%\t"
]
}
],
"source": [
"# Print each action probability as a percentage, tab-separated on one line.\n",
"# NOTE: this is Python 2 print-statement syntax (the notebook's kernelspec is\n",
"# python2); the trailing comma keeps the output on the same line.\n",
"for p in opt_policy:\n",
" print \"{:.2%}\\t\".format(p),"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment