{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple Reinforcement Learning: Exploration Strategies\n",
"This notebook contains implementations of various action-selection methods that can be used to encourage exploration during the learning process. To learn more about these methods, see the accompanying [Medium post](https://medium.com/p/d3a97b7cceaf/). Also see the interactive visualization [here](http://awjuliani.github.io/exploration/index.html).\n",
"\n",
"For more reinforcement learning tutorials see:\n",
"https://github.com/awjuliani/DeepRL-Agents"
]
},
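{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick standalone illustration (not part of the training pipeline below), the next cell sketches the e-greedy and Boltzmann selection rules on a hypothetical vector of Q-values using numpy only. The names `q_values`, `epsilon`, and `temperature` are illustrative; the actual strategies are implemented inside the training loop further down."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#Illustrative sketch only: standalone e-greedy and Boltzmann action selection\n",
"#over a hypothetical pair of Q-values (not used by the training loop below).\n",
"import numpy as np\n",
"\n",
"q_values = np.array([1.2, 0.8]) #Hypothetical Q-values for CartPole's two actions.\n",
"\n",
"def egreedy_action(q, epsilon):\n",
"    #With probability epsilon pick a random action, otherwise the greedy one.\n",
"    if np.random.rand() < epsilon:\n",
"        return np.random.randint(len(q))\n",
"    return np.argmax(q)\n",
"\n",
"def boltzmann_action(q, temperature):\n",
"    #Sample an action with probability proportional to exp(Q/temperature).\n",
"    probs = np.exp(q / temperature)\n",
"    probs = probs / np.sum(probs)\n",
"    return np.random.choice(len(q), p=probs)\n",
"\n",
"print egreedy_action(q_values, 0.1), boltzmann_action(q_values, 0.5)"
]
},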
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"import random\n",
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"import tensorflow.contrib.slim as slim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": false
},
"outputs": [],
"source": [
"env = gym.make('CartPole-v0')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The Deep Q-Network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class experience_buffer():\n",
"    def __init__(self, buffer_size = 10000):\n",
"        self.buffer = []\n",
"        self.buffer_size = buffer_size\n",
"\n",
"    def add(self,experience):\n",
"        #Drop the oldest experiences once the buffer is full.\n",
"        if len(self.buffer) + len(experience) >= self.buffer_size:\n",
"            self.buffer[0:(len(experience)+len(self.buffer))-self.buffer_size] = []\n",
"        self.buffer.extend(experience)\n",
"\n",
"    def sample(self,size):\n",
"        #Return a random batch of [s,a,r,s1,d] experiences.\n",
"        return np.reshape(np.array(random.sample(self.buffer,size)),[size,5])\n",
"\n",
"def updateTargetGraph(tfVars,tau):\n",
"    #Create ops that softly move each target-network variable toward its primary-network counterpart.\n",
"    total_vars = len(tfVars)\n",
"    op_holder = []\n",
"    for idx,var in enumerate(tfVars[0:total_vars/2]):\n",
"        op_holder.append(tfVars[idx+total_vars/2].assign((var.value()*tau) + ((1-tau)*tfVars[idx+total_vars/2].value())))\n",
"    return op_holder\n",
"\n",
"def updateTarget(op_holder,sess):\n",
"    for op in op_holder:\n",
"        sess.run(op)"
]
},
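{
"cell_type": "markdown",
"metadata": {},
"source": [
"A note on `updateTargetGraph`: the trainable variables are laid out as [primary network, target network] (the primary network is built first below), and each op softly moves a target-network variable toward its primary counterpart,\n",
"\n",
"$$\\theta_{target} \\leftarrow \\tau\\,\\theta_{primary} + (1-\\tau)\\,\\theta_{target},$$\n",
"\n",
"so with a small $\\tau$ the target network tracks the primary network slowly, which keeps the Q-learning targets stable."
]
},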
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementing the network itself"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class Q_Network():\n",
"    def __init__(self):\n",
"        #These lines establish the feed-forward part of the network used to choose actions.\n",
"        self.inputs = tf.placeholder(shape=[None,4],dtype=tf.float32)\n",
"        self.Temp = tf.placeholder(shape=None,dtype=tf.float32) #Temperature for Boltzmann exploration.\n",
"        self.keep_per = tf.placeholder(shape=None,dtype=tf.float32) #Dropout keep probability for Bayesian exploration.\n",
"\n",
"        hidden = slim.fully_connected(self.inputs,64,activation_fn=tf.nn.tanh,biases_initializer=None)\n",
"        hidden = slim.dropout(hidden,self.keep_per)\n",
"        self.Q_out = slim.fully_connected(hidden,2,activation_fn=None,biases_initializer=None)\n",
"\n",
"        self.predict = tf.argmax(self.Q_out,1)\n",
"        self.Q_dist = tf.nn.softmax(self.Q_out/self.Temp)\n",
"\n",
"        #Below we obtain the loss by taking the sum-of-squares difference between the target and predicted Q-values.\n",
"        self.actions = tf.placeholder(shape=[None],dtype=tf.int32)\n",
"        self.actions_onehot = tf.one_hot(self.actions,2,dtype=tf.float32)\n",
"\n",
"        self.Q = tf.reduce_sum(tf.mul(self.Q_out, self.actions_onehot), reduction_indices=1)\n",
"\n",
"        self.nextQ = tf.placeholder(shape=[None],dtype=tf.float32)\n",
"        loss = tf.reduce_sum(tf.square(self.nextQ - self.Q))\n",
"        trainer = tf.train.GradientDescentOptimizer(learning_rate=0.0005)\n",
"        self.updateModel = trainer.minimize(loss)"
]
},
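{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network above maps a 4-dimensional CartPole observation to two Q-values through a single 64-unit tanh layer with dropout. The loss is the sum of squared errors between the targets fed in through `nextQ` and the Q-values of the actions actually taken:\n",
"\n",
"$$L = \\sum_i \\big(targetQ_i - Q(s_i, a_i)\\big)^2.$$"
]
},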
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training the network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Set learning parameters\n",
"exploration = \"e-greedy\" #Exploration method. Choose between: greedy, random, e-greedy, boltzmann, bayesian.\n",
"y = .99 #Discount factor.\n",
"num_episodes = 20000 #Total number of episodes to train the network for.\n",
"tau = 0.001 #Rate at which the target network is updated toward the primary network at each training step.\n",
"batch_size = 32 #Size of training batch.\n",
"startE = 1 #Starting chance of random action.\n",
"endE = 0.1 #Final chance of random action.\n",
"annealing_steps = 200000 #How many steps of training to reduce startE to endE over.\n",
"pre_train_steps = 50000 #Number of environment steps taken before training updates begin."
]
},
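{
"cell_type": "markdown",
"metadata": {},
"source": [
"With these settings, once `pre_train_steps` environment steps have been collected, `e` decays linearly from `startE` toward `endE`:\n",
"\n",
"$$e_t = \\max\\Big(endE,\\; startE - t \\cdot \\tfrac{startE - endE}{annealing\\_steps}\\Big),$$\n",
"\n",
"where $t$ counts steps after pre-training. Depending on the strategy chosen above, `e` serves as the random-action probability (e-greedy), the softmax temperature (boltzmann), or is used to set the dropout keep probability (bayesian); the greedy and random strategies ignore it."
]
},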
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"q_net = Q_Network()\n",
"target_net = Q_Network()\n",
"\n",
"init = tf.initialize_all_variables()\n",
"trainables = tf.trainable_variables()\n",
"targetOps = updateTargetGraph(trainables,tau)\n",
"myBuffer = experience_buffer()\n",
"\n",
"\n",
"#Create lists to contain total rewards and steps per episode.\n",
"jList = []\n",
"jMeans = []\n",
"rList = []\n",
"rMeans = []\n",
"with tf.Session() as sess:\n",
"    sess.run(init)\n",
"    updateTarget(targetOps,sess)\n",
"    e = startE\n",
"    stepDrop = (startE - endE)/annealing_steps\n",
"    total_steps = 0\n",
"\n",
"    for i in range(num_episodes):\n",
"        s = env.reset()\n",
"        rAll = 0\n",
"        d = False\n",
"        j = 0\n",
"        while j < 999:\n",
"            j+=1\n",
"            if exploration == \"greedy\":\n",
"                #Choose the action with the maximum expected value.\n",
"                a,allQ = sess.run([q_net.predict,q_net.Q_out],feed_dict={q_net.inputs:[s],q_net.keep_per:1.0})\n",
"                a = a[0]\n",
"            if exploration == \"random\":\n",
"                #Choose an action randomly.\n",
"                a = env.action_space.sample()\n",
"            if exploration == \"e-greedy\":\n",
"                #Choose an action greedily from the Q-network, with probability e of a random action.\n",
"                if np.random.rand(1) < e or total_steps < pre_train_steps:\n",
"                    a = env.action_space.sample()\n",
"                else:\n",
"                    a,allQ = sess.run([q_net.predict,q_net.Q_out],feed_dict={q_net.inputs:[s],q_net.keep_per:1.0})\n",
"                    a = a[0]\n",
"            if exploration == \"boltzmann\":\n",
"                #Choose an action probabilistically, with weights relative to the Q-values.\n",
"                Q_d,allQ = sess.run([q_net.Q_dist,q_net.Q_out],feed_dict={q_net.inputs:[s],q_net.Temp:e,q_net.keep_per:1.0})\n",
"                a = np.random.choice(Q_d[0],p=Q_d[0])\n",
"                a = np.argmax(Q_d[0] == a)\n",
"            if exploration == \"bayesian\":\n",
"                #Choose an action using a sample from a dropout approximation of a Bayesian Q-network.\n",
"                a,allQ = sess.run([q_net.predict,q_net.Q_out],feed_dict={q_net.inputs:[s],q_net.keep_per:(1-e)+0.1})\n",
"                a = a[0]\n",
"\n",
"            #Get the new state and reward from the environment.\n",
"            s1,r,d,_ = env.step(a)\n",
"            myBuffer.add(np.reshape(np.array([s,a,r,s1,d]),[1,5]))\n",
"\n",
"            if e > endE and total_steps > pre_train_steps:\n",
"                e -= stepDrop\n",
"\n",
"            if total_steps > pre_train_steps and total_steps % 5 == 0:\n",
"                #We use the Double-DQN training algorithm (see the note after this cell).\n",
"                trainBatch = myBuffer.sample(batch_size)\n",
"                Q1 = sess.run(q_net.predict,feed_dict={q_net.inputs:np.vstack(trainBatch[:,3]),q_net.keep_per:1.0})\n",
"                Q2 = sess.run(target_net.Q_out,feed_dict={target_net.inputs:np.vstack(trainBatch[:,3]),target_net.keep_per:1.0})\n",
"                end_multiplier = -(trainBatch[:,4] - 1) #0 for terminal transitions, 1 otherwise.\n",
"                doubleQ = Q2[range(batch_size),Q1]\n",
"                targetQ = trainBatch[:,2] + (y*doubleQ * end_multiplier)\n",
"                _ = sess.run(q_net.updateModel,feed_dict={q_net.inputs:np.vstack(trainBatch[:,0]),q_net.nextQ:targetQ,q_net.keep_per:1.0,q_net.actions:trainBatch[:,1]})\n",
"                updateTarget(targetOps,sess)\n",
"\n",
"            rAll += r\n",
"            s = s1\n",
"            total_steps += 1\n",
"            if d == True:\n",
"                break\n",
"        jList.append(j)\n",
"        rList.append(rAll)\n",
"        if i % 100 == 0 and i != 0:\n",
"            r_mean = np.mean(rList[-100:])\n",
"            j_mean = np.mean(jList[-100:])\n",
"            if exploration == 'e-greedy':\n",
"                print \"Mean Reward: \" + str(r_mean) + \" Total Steps: \" + str(total_steps) + \" e: \" + str(e)\n",
"            if exploration == 'boltzmann':\n",
"                print \"Mean Reward: \" + str(r_mean) + \" Total Steps: \" + str(total_steps) + \" t: \" + str(e)\n",
"            if exploration == 'bayesian':\n",
"                print \"Mean Reward: \" + str(r_mean) + \" Total Steps: \" + str(total_steps) + \" p: \" + str(e)\n",
"            if exploration == 'random' or exploration == 'greedy':\n",
"                print \"Mean Reward: \" + str(r_mean) + \" Total Steps: \" + str(total_steps)\n",
"            rMeans.append(r_mean)\n",
"            jMeans.append(j_mean)\n",
"print \"Average reward per episode: \" + str(sum(rList)/num_episodes)"
]
},
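{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training update above uses the Double-DQN target: the primary network selects the next action and the target network evaluates it,\n",
"\n",
"$$targetQ = r + \\gamma \\, Q_{target}\\big(s',\\, \\arg\\max_{a'} Q_{primary}(s', a')\\big)\\,(1 - done),$$\n",
"\n",
"where `end_multiplier` supplies the $(1 - done)$ term so that terminal transitions contribute only their immediate reward, and $\\gamma$ is the discount factor `y`."
]
},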
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Some statistics on network performance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"#Mean reward per 100-episode block over the course of training.\n",
"plt.plot(rMeans)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [],
"source": [
"#Mean episode length per 100-episode block over the course of training.\n",
"plt.plot(jMeans)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}