Created September 14, 2020 20:19
-
-
Save kobus-v-schoor/3f432de917ee653f13f5ca8fe828e801 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Vanilla policy-gradient agent that learns to keep the pole upright in the CartPole environment\n", | |
"# Adapted from https://github.com/awjuliani/DeepRL-Agents/blob/master/Vanilla-Policy.ipynb" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import gym\n", | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"from tensorflow import keras\n", | |
"from tensorflow.keras.layers import Dense\n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"from tqdm import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Environment id, kept as a constant so it is easy to find and change\n", | |
"ENV_ID = 'CartPole-v0'\n", | |
"env = gym.make(ENV_ID)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Discounting propagates each reward back to the actions that preceded it,\n", | |
"# so earlier actions in an episode are credited for later rewards.\n", | |
"gamma = 0.99\n", | |
"\n", | |
"\n", | |
"def discount_rewards(rewards, discount=None):\n", | |
"    \"\"\"Return the discounted return for every step of one episode.\n", | |
"\n", | |
"    Element t of the result is rewards[t] + discount * result[t + 1], i.e. the\n", | |
"    sum of the rewards from step t onward, each weighted by discount**k.\n", | |
"    A new list is returned; the input sequence is left unmodified.\n", | |
"\n", | |
"    rewards: per-step rewards of a single episode, in time order.\n", | |
"    discount: discount factor; defaults to the module-level ``gamma``.\n", | |
"    \"\"\"\n", | |
"    if discount is None:\n", | |
"        discount = gamma\n", | |
"    discounted = list(rewards)\n", | |
"    # walk backwards so discounted[t + 1] is already final when step t is updated\n", | |
"    for t in reversed(range(len(discounted) - 1)):\n", | |
"        discounted[t] += discounted[t + 1] * discount\n", | |
"    return discounted\n", | |
"\n", | |
"\n", | |
"# `learning_rate` is the supported keyword; `lr` is a deprecated alias\n", | |
"optimizer = keras.optimizers.Adam(learning_rate=1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Policy network: maps an observation to one probability per discrete action.\n", | |
"# The softmax output layer makes the outputs usable as action probabilities.\n", | |
"model = keras.models.Sequential([\n", | |
"    Dense(8, input_shape=env.observation_space.shape, activation='relu'),\n", | |
"    Dense(env.action_space.n, activation='softmax')\n", | |
"])\n", | |
"\n", | |
"# Small constant keeping log() finite if a softmax probability underflows to 0\n", | |
"LOG_EPS = 1e-8\n", | |
"\n", | |
"\n", | |
"def loss():\n", | |
"    \"\"\"Policy-gradient loss for the current batch of experience.\n", | |
"\n", | |
"    Takes no arguments because it is handed to optimizer.minimize as a\n", | |
"    zero-argument callable. It reads the module-level batch that the training\n", | |
"    loop builds before each update: states (model inputs), idx ((row, action)\n", | |
"    pairs selecting the probability of each action actually taken) and\n", | |
"    rewards (discounted returns). Minimising it raises the log-probability of\n", | |
"    actions followed by high returns.\n", | |
"    \"\"\"\n", | |
"    taken_probs = tf.gather_nd(model(states), idx)\n", | |
"    # match the model's float dtype explicitly instead of relying on implicit conversion\n", | |
"    returns = tf.cast(rewards, taken_probs.dtype)\n", | |
"    return -tf.reduce_mean(tf.math.log(taken_probs + LOG_EPS) * returns)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 5000/5000 [03:05<00:00, 26.91it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"# total number of episodes\n", | |
"num_episodes = 5000\n", | |
"# how frequently (in number of episodes) the network should be trained\n", | |
"update_freq = 10\n", | |
"\n", | |
"# stores the number of steps the pole was kept upright for each episode\n", | |
"steps = []\n", | |
"\n", | |
"# experience collected since the last update; states, rewards and idx are read by `loss`\n", | |
"states = []\n", | |
"actions = []\n", | |
"rewards = []\n", | |
"\n", | |
"for episode in tqdm(range(num_episodes)):\n", | |
"    state = env.reset()\n", | |
"    done = False\n", | |
"\n", | |
"    # rewards for just this episode (discounted once the episode ends)\n", | |
"    temp_rewards = []\n", | |
"\n", | |
"    while not done:\n", | |
"        # record current state\n", | |
"        states.append(state)\n", | |
"\n", | |
"        # sample an action from the distribution given by the policy network;\n", | |
"        # the model expects a batch, so the single state is wrapped in one.\n", | |
"        # The action count comes from the env so it matches the model's output width.\n", | |
"        probs = model(np.array([state])).numpy().flatten()\n", | |
"        action = np.random.choice(env.action_space.n, p=probs)\n", | |
"        actions.append(action)\n", | |
"\n", | |
"        # advance the environment; this unpacking assumes the 4-value step API\n", | |
"        state, reward, done, info = env.step(action)\n", | |
"        temp_rewards.append(reward)\n", | |
"\n", | |
"    # record number of steps\n", | |
"    steps.append(len(temp_rewards))\n", | |
"\n", | |
"    # record discounted rewards\n", | |
"    rewards += discount_rewards(temp_rewards)\n", | |
"\n", | |
"    # train after every update_freq episodes, and on the final (possibly partial)\n", | |
"    # batch so that no collected experience is discarded\n", | |
"    if (episode + 1) % update_freq == 0 or episode == num_episodes - 1:\n", | |
"        # get states and rewards into suitable shape for model input\n", | |
"        states = np.array(states)\n", | |
"        rewards = np.array(rewards)\n", | |
"\n", | |
"        # (row, action) index pairs selecting the probability of each action taken\n", | |
"        idx = np.array(list(enumerate(actions)))\n", | |
"\n", | |
"        # update the policy network\n", | |
"        optimizer.minimize(loss, model.trainable_variables)\n", | |
"\n", | |
"        # start a fresh batch\n", | |
"        states = []\n", | |
"        rewards = []\n", | |
"        actions = []" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x7f0310569d30>]" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# raw episode lengths; the trailing ';' suppresses the Line2D repr\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(steps)\n", | |
"ax.set(title='Steps per episode', xlabel='episode', ylabel='steps pole kept upright');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x7f031024aef0>]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# plot steps averaged over consecutive, non-overlapping windows of episodes\n", | |
"window = 100\n", | |
"avg = [sum(steps[i - window:i]) / window for i in range(window, len(steps) + 1, window)]\n", | |
"\n", | |
"# x = first episode of each window; derived from len(avg) so x and y always\n", | |
"# have equal length, even when the episode count is not a multiple of window\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(range(0, len(avg) * window, window), avg)\n", | |
"ax.set(title=f'Mean steps per {window}-episode window', xlabel='episode', ylabel='mean steps');" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.