An interactive IPython notebook for: Keras plays catch - https://gist.github.com/EderSantana/c7222daa328f0e885093
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Action End Test, Points: 0\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAFSCAYAAAB2cI2KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADV5JREFUeJzt3E2IlXX7wPHr/BVMx0QHdUQTQTeCYzIbQSaZjaPQxkR8\nC1820YNFoCtFSciNaIsgSkISXBqSDi4USxAxFAvEGISCFAoNmUxH0REJu59F0PP08B/PXPN2zpn7\n84GzCM99vK6yr7/7zMypFEVRBAAD8n+1HgCgkYgmQIJoAiSIJkCCaAIkiCZARjHMImLQj+7u7iFd\n36iPMu5t5/I8GnXv/lSG+/s0K5XKoK8timJI1zeqMu5t5/Jo1L37S6Pbc4AE0QRIEE2ABNEESBBN\ngATRBEgQTYAE0QRIEE2AhPEDedKBAwfi+++/j0qlEnv27InFixeP9FwA9anaz5J/++23xb/+9a+i\nKIrip59+KjZs2DBiP3s+1Osb9VHGve1cnkej7t2fqrfnV65ciRUrVkRExIIFC+LRo0fx5MmTapcB\njElVo3nv3r1obm7++5+nTZsW9+7dG9GhAOrVgN7T/G/VPhSpu7s7WltbBz3QMH/oUsMo4952Lo+x\ntHfVaM6cOfMfJ8uenp6YMWNGv88fyheJGvUjpIaqjHvbuTwade9BfzRce3t7nDt3LiIibty4ES0t\nLTFp0qThnQ6gQVQ9aba1tcWiRYti48aNMW7cuNi3b99ozAVQl3xyex0o4952Lo9G3dsntwMMA9EE\nSBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRI\nEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQ\nTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBNgATRBEgQTYAE0QRIEE2ABNEESBBN\ngATRBEgYP5AnHTp0KK5duxbPnz+Pt99+Ozo7O0d6LoC6VDWaV69ejZs3b8bx48ejt7c31qxZI5pA\naVWN5tKlS2PJkiURETFlypR4+vRpFEURlUplxIcDqDdV39OsVCrx0ksvRUTEiRMnoqOjQzCB0hrQ\ne5oREefPn4+TJ0/G0aNHR3IegLo2oGheunQpjhw5EkePHo3Jkye/8Lnd3d3R2to66IGKohj0tY2s\njHvbuTzG0t6Voso2jx8/jjfffDOOHTsWzc3N1V9wCLfuZX2vtIx727k8GnXv/tJY9aR55syZ6O3t\njR07dvy9/KFDh2LWrFnDPiRAvat60ky/oJNmWhn3tnN5NOre/aXRTwQBJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSI\nJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgm\nQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkDCgKL57Nmz\n6OzsjK6urpGeB6CuDSiahw8fjqlTp470LAB1r2o0b926Fbdu3YqOjo7RmAegrlWN5sGDB2P37t2j\nMQtA3XthNLu6uqKtrS3mzJkTERFFUYzKUAD1avyLfvHixYtx+/btuHDhQty9ezcmTJgQs2bNimXL\nlvV7TXd3d7S2tg56oLKGuYx727k8xtLelWKA23zyySfxyiuvxBtvvPHiF6xUBj1MURRDur5RlXFv\nO5dHo+7dXxp9nyZAwoBPmgN+QSfNtDLubefyaNS9nTQBhoFoAiSIJkCCaAIkiCZAgmgCJIgmQIJo\nAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgC\nJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAwoCiefr06Vi9enWs\nXbs2Ll68ONIzAdSvoooHDx4UK1euLPr6+orffvuteP/991/4/IgY9GOo1zfqo4x727k8j0bduz/j\no4rLly9He3t7TJw4MSZOnBj79++vdgnAmFX19vzOnTvx9OnT2L59e2zevDmuXLkyGnMB1KWqJ82i\nKKK3tzcOHz4ct2/fjq1bt8aFCxf6fX53d3e0trYOeqC/TvLlU8a97VweY2nvqtGcPn16tLW1RaVS\niblz50ZTU1Pcv38/mpub/9/nL168eNDDFEURlUpl0Nc3qjLubefyaNS9+wt91dvz9vb2uHr1ahRF\nEQ8ePIi+vr5+gwkw1lU9aba0tMSqVati/fr1UalUYt++faMxF0BdqhTD/GbDUI7hjXqMH6oy7m3n\n8mjUvQd9ew7Af4gmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJo\nAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgC\nJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIkiCZAgmgCJIgmQIJoAiSIJkCCaAIk\niCZAgmgCJIgmQIJoAiSIJkDC+GpP6Ovri127dsXDhw/jjz/+iHfffTdee+210ZgNoO5UjeapU6di\n/vz5sXPnzujp6Ylt27bF2bNnR2M2gLpT9fZ82rRp8eDBg4iIePjwYTQ3N4/4UAD1qlIURVHtSW+9\n9Vb88ssv8ejRozhy5Ei8+uqr/b9gpTLoYYqiGNL1jaqMe9u5PBp17/7SWPWkefr06Zg9e3Z89dVX\ncezYsfjggw+GfTiARlH1Pc1r167F8uXLIyJi4cKF0dPT88K/Obq7u6O1tXXQAw3g4DsmlXFvO5fH\nWNq7ajTnzZsX169fj87Ozrhz5040NTW98Ki9ePHiQQ/TqMf4oSrj3nYuj0bdu7/QV31Ps6+vL/bs\n2RO///57PH/+PHbs2BFLly7t9/ne08wr4952Lo9G3XvQ0cwSzbwy7m3n8mjUvQf9hSAA/kM0ARJE\nEyBBNAESRBMgQTQBEkQTIEE0ARJEEyBBNAESRBMgoeqnHFHdn3/+WRevMdqG+rEFz58/H9R148aN\nG9LvC0PhpAmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJ\nkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAgmgAJogmQIJoACaIJkCCaAAmiCZAwvtYDjAVF\nUQzp+kqlMujXqFQqQ/q9gZxKMdT/4wF
KxO05QIJoAiSIJkCCaAIkiCZAgmgCJNRNNA8cOBAbN26M\nTZs2RXd3d63HGRWHDh2KjRs3xrp16+Lrr7+u9Tij5tmzZ9HZ2RldXV21HmXUnD59OlavXh1r166N\nixcv1nqcEdfX1xfvvfdebN26NTZt2hTffPNNrUcaNnXxze3fffdd/Pzzz3H8+PG4efNm7N27N44f\nP17rsUbU1atX4+bNm3H8+PHo7e2NNWvWRGdnZ63HGhWHDx+OqVOn1nqMUdPb2xuffvppdHV1xZMn\nT+Ljjz+Ojo6OWo81ok6dOhXz58+PnTt3Rk9PT2zbti3Onj1b67GGRV1E88qVK7FixYqIiFiwYEE8\nevQonjx5Ek1NTTWebOQsXbo0lixZEhERU6ZMiadPn0ZRFGP+J3xu3boVt27dGvPR+G+XL1+O9vb2\nmDhxYkycODH2799f65FG3LRp0+LHH3+MiIiHDx9Gc3NzjScaPnVxe37v3r1//EudNm1a3Lt3r4YT\njbxKpRIvvfRSREScOHEiOjo6xnwwIyIOHjwYu3fvrvUYo+rOnTvx9OnT2L59e2zevDmuXLlS65FG\n3Ouvvx6//vprrFy5MrZs2RK7du2q9UjDpi5Omv+rTD/Zef78+Th58mQcPXq01qOMuK6urmhra4s5\nc+ZERHn+OxdFEb29vXH48OG4fft2bN26NS5cuFDrsUbU6dOnY/bs2fH555/HDz/8EHv37o0vv/yy\n1mMNi7qI5syZM/9xsuzp6YkZM2bUcKLRcenSpThy5EgcPXo0Jk+eXOtxRtzFixfj9u3bceHChbh7\n925MmDAhZs2aFcuWLav1aCNq+vTp0dbWFpVKJebOnRtNTU1x//79MXXL+r+uXbsWy5cvj4iIhQsX\nRk9Pz5h5+6kubs/b29vj3LlzERFx48aNaGlpiUmTJtV4qpH1+PHj+PDDD+Ozzz6Ll19+udbjjIqP\nPvooTpw4EV988UWsW7cu3nnnnTEfzIi//nxfvXo1iqKIBw8eRF9f35gOZkTEvHnz4vr16xHx19sT\nTU1NYyKYEXVy0mxra4tFixbFxo0bY9y4cbFv375ajzTizpw5E729vbFjx46//wY+dOhQzJo1q9aj\nMcxaWlpi1apVsX79+qhUKqX4871hw4bYs2dPbNmyJZ4/fz6mvvjlo+EAEuri9hygUYgmQIJoAiSI\nJkCCaAIkiCZAgmgCJIgmQMK/ARRcnZLdxlbtAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f3431d0d610>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"train(model)\n", | |
"test(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Using Theano backend.\n", | |
"ERROR (theano.sandbox.cuda): nvcc compiler not found on $PATH. Check your nvcc installation and try again.\n", | |
"ERROR:theano.sandbox.cuda:nvcc compiler not found on $PATH. Check your nvcc installation and try again.\n" | |
] | |
} | |
], | |
"source": [ | |
"%matplotlib inline\n", | |
"import seaborn\n", | |
"seaborn.set()\n", | |
"import json\n", | |
"import matplotlib.pyplot as plt\n", | |
"import numpy as np\n", | |
"import time\n", | |
"from keras.models import model_from_json\n", | |
"from qlearn import Catch\n", | |
"from PIL import Image\n", | |
"from IPython import display\n", | |
"last_frame_time = 0\n", | |
"translate_action = [\"Left\",\"Stay\",\"Right\",\"Create Ball\",\"End Test\"]\n", | |
"grid_size = 10\n", | |
"\n", | |
"def display_screen(action,points,input_t):\n", | |
" global last_frame_time\n", | |
" display.clear_output(wait=True)\n", | |
" print \"Action %s, Points: %d\" % (translate_action[action],points)\n", | |
" if(\"End\" not in translate_action[action]):\n", | |
" plt.imshow(input_t.reshape((grid_size,)*2),\n", | |
" interpolation='none', cmap='gray')\n", | |
" display.display(plt.gcf())\n", | |
" last_frame_time = set_max_fps(last_frame_time)\n", | |
"def set_max_fps(last_frame_time,FPS = 1):\n", | |
" current_milli_time = lambda: int(round(time.time() * 1000))\n", | |
" sleep_time = 1./FPS - (current_milli_time() - last_frame_time)\n", | |
" if sleep_time > 0:\n", | |
" time.sleep(sleep_time)\n", | |
" return current_milli_time()\n", | |
"def test(model):\n", | |
" global last_frame_time\n", | |
" plt.ion()\n", | |
" # Define environment, game\n", | |
" env = Catch(grid_size)\n", | |
" c = 0\n", | |
" last_frame_time = 0\n", | |
" points = 0\n", | |
" for e in range(10):\n", | |
" loss = 0.\n", | |
" env.reset()\n", | |
" game_over = False\n", | |
" # get initial input\n", | |
" input_t = env.observe()\n", | |
" display_screen(3,points,input_t)\n", | |
" c += 1\n", | |
" while not game_over:\n", | |
" input_tm1 = input_t\n", | |
" # get next action\n", | |
" q = model.predict(input_tm1)\n", | |
" action = np.argmax(q[0])\n", | |
" # apply action, get rewards and new state\n", | |
" input_t, reward, game_over = env.act(action)\n", | |
" points += reward\n", | |
" display_screen(action,points,input_t)\n", | |
" c += 1\n", | |
" display_screen(4,points,input_t)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import json\n", | |
"import numpy as np\n", | |
"from keras.models import Sequential\n", | |
"from keras.layers.core import Dense\n", | |
"from keras.optimizers import sgd\n", | |
"\n", | |
"\n", | |
"class Catch(object):\n", | |
" def __init__(self, grid_size=10):\n", | |
" self.grid_size = grid_size\n", | |
" self.reset()\n", | |
"\n", | |
" def _update_state(self, action):\n", | |
" \"\"\"\n", | |
" Input: action and states\n", | |
" Ouput: new states and reward\n", | |
" \"\"\"\n", | |
" state = self.state\n", | |
" if action == 0: # left\n", | |
" action = -1\n", | |
" elif action == 1: # stay\n", | |
" action = 0\n", | |
" else:\n", | |
" action = 1 # right\n", | |
" f0, f1, basket = state[0]\n", | |
" new_basket = min(max(1, basket + action), self.grid_size-1)\n", | |
" f0 += 1\n", | |
" out = np.asarray([f0, f1, new_basket])\n", | |
" out = out[np.newaxis]\n", | |
"\n", | |
" assert len(out.shape) == 2\n", | |
" self.state = out\n", | |
"\n", | |
" def _draw_state(self):\n", | |
" im_size = (self.grid_size,)*2\n", | |
" state = self.state[0]\n", | |
" canvas = np.zeros(im_size)\n", | |
" canvas[state[0], state[1]] = 1 # draw fruit\n", | |
" canvas[-1, state[2]-1:state[2] + 2] = 1 # draw basket\n", | |
" return canvas\n", | |
" \n", | |
" def _get_reward(self):\n", | |
" fruit_row, fruit_col, basket = self.state[0]\n", | |
" if fruit_row == self.grid_size-1:\n", | |
" if abs(fruit_col - basket) <= 1:\n", | |
" return 1\n", | |
" else:\n", | |
" return -1\n", | |
" else:\n", | |
" return 0\n", | |
"\n", | |
" def _is_over(self):\n", | |
" if self.state[0, 0] == self.grid_size-1:\n", | |
" return True\n", | |
" else:\n", | |
" return False\n", | |
"\n", | |
" def observe(self):\n", | |
" canvas = self._draw_state()\n", | |
" return canvas.reshape((1, -1))\n", | |
"\n", | |
" def act(self, action):\n", | |
" self._update_state(action)\n", | |
" reward = self._get_reward()\n", | |
" game_over = self._is_over()\n", | |
" return self.observe(), reward, game_over\n", | |
"\n", | |
" def reset(self):\n", | |
" n = np.random.randint(0, self.grid_size-1, size=1)\n", | |
" m = np.random.randint(1, self.grid_size-2, size=1)\n", | |
" self.state = np.asarray([0, n, m])[np.newaxis]\n", | |
"\n", | |
"\n", | |
"class ExperienceReplay(object):\n", | |
" def __init__(self, max_memory=100, discount=.9):\n", | |
" self.max_memory = max_memory\n", | |
" self.memory = list()\n", | |
" self.discount = discount\n", | |
"\n", | |
" def remember(self, states, game_over):\n", | |
" # memory[i] = [[state_t, action_t, reward_t, state_t+1], game_over?]\n", | |
" self.memory.append([states, game_over])\n", | |
" if len(self.memory) > self.max_memory:\n", | |
" del self.memory[0]\n", | |
"\n", | |
" def get_batch(self, model, batch_size=10):\n", | |
" len_memory = len(self.memory)\n", | |
" num_actions = model.output_shape[-1]\n", | |
" env_dim = self.memory[0][0][0].shape[1]\n", | |
" inputs = np.zeros((min(len_memory, batch_size), env_dim))\n", | |
" targets = np.zeros((inputs.shape[0], num_actions))\n", | |
" for i, idx in enumerate(np.random.randint(0, len_memory,\n", | |
" size=inputs.shape[0])):\n", | |
" state_t, action_t, reward_t, state_tp1 = self.memory[idx][0]\n", | |
" game_over = self.memory[idx][1]\n", | |
"\n", | |
" inputs[i:i+1] = state_t\n", | |
" # There should be no target values for actions not taken.\n", | |
" # Thou shalt not correct actions not taken #deep\n", | |
" targets[i] = model.predict(state_t)[0]\n", | |
" Q_sa = np.max(model.predict(state_tp1)[0])\n", | |
" if game_over: # if game_over is True\n", | |
" targets[i, action_t] = reward_t\n", | |
" else:\n", | |
" # reward_t + gamma * max_a' Q(s', a')\n", | |
" targets[i, action_t] = reward_t + self.discount * Q_sa\n", | |
" return inputs, targets\n", | |
"\n", | |
" \n", | |
"# parameters\n", | |
"epsilon = .1 # exploration\n", | |
"num_actions = 3 # [move_left, stay, move_right]\n", | |
"epoch = 1000\n", | |
"max_memory = 500\n", | |
"hidden_size = 100\n", | |
"batch_size = 1\n", | |
"grid_size = 10\n", | |
"\n", | |
"model = Sequential()\n", | |
"model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))\n", | |
"model.add(Dense(hidden_size, activation='relu'))\n", | |
"model.add(Dense(num_actions))\n", | |
"model.compile(sgd(lr=.2), \"mse\")\n", | |
" \n", | |
"# If you want to continue training from a previous model, just uncomment the line bellow\n", | |
"# model.load_weights(\"model.h5\")\n", | |
"\n", | |
"# Define environment/game\n", | |
"env = Catch(grid_size)\n", | |
"\n", | |
"# Initialize experience replay object\n", | |
"exp_replay = ExperienceReplay(max_memory=max_memory)\n", | |
"\n", | |
"def train(model):\n", | |
" # Train\n", | |
" win_cnt = 0\n", | |
" for e in range(1):\n", | |
" loss = 0.\n", | |
" env.reset()\n", | |
" game_over = False\n", | |
" # get initial input\n", | |
" input_t = env.observe()\n", | |
"\n", | |
" while not game_over:\n", | |
" input_tm1 = input_t\n", | |
" # get next action\n", | |
" if np.random.rand() <= epsilon:\n", | |
" action = np.random.randint(0, num_actions, size=1)\n", | |
" else:\n", | |
" q = model.predict(input_tm1)\n", | |
" action = np.argmax(q[0])\n", | |
"\n", | |
" # apply action, get rewards and new state\n", | |
" input_t, reward, game_over = env.act(action)\n", | |
" if reward == 1:\n", | |
" win_cnt += 1\n", | |
"\n", | |
" # store experience\n", | |
" exp_replay.remember([input_tm1, action, reward, input_t], game_over) \n", | |
" \n", | |
" # adapt model\n", | |
" inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)\n", | |
" \n", | |
" display_screen(action,3000,inputs[0]) \n", | |
" \n", | |
" loss += model.train_on_batch(inputs, targets)[0]\n", | |
" print(\"Epoch {:03d}/999 | Loss {:.4f} | Win count {}\".format(e, loss, win_cnt))\n" | |
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
} |
I got this error message at line
-> 166 loss += model.train_on_batch(inputs, targets)[0]
invalid index to scalar variable.
I am using Python 3 on Windows (Anaconda) with the TensorFlow backend.
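That error usually means train_on_batch returned a bare scalar rather than a list: with the TensorFlow backend and more recent Keras releases, a model compiled without extra metrics returns only the loss, so indexing it with [0] fails. A minimal sketch of a drop-in replacement for that line (the variable name "out" is just illustrative; the rest of the notebook is assumed unchanged):

    # train_on_batch returns [loss, *metrics] in older Keras, but a bare scalar
    # in newer versions when compile() receives no extra metrics.
    out = model.train_on_batch(inputs, targets)
    loss += float(out[0]) if isinstance(out, (list, tuple)) else float(out)

Note that the notebook targets a Python 2 kernel, so under Python 3 the print statement in display_screen would also need to be rewritten as a print() call.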