Last active
March 22, 2018 06:55
-
-
Save analyticsindiamagazine/381977a5831835b4c0dddbb53e6c1b70 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Define the Next State and Error Correction Techniques\n",
"We will use a Gradient Descent optimizer to reduce the loss, with a learning rate of 0.1"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Placeholder that will receive the Bellman-target Q-values (1 state x 4 actions).\n",
"Q2 = tf.placeholder(dtype=tf.float32, shape=[1,4])\n",
"# Temporal-difference error between the target and the network's prediction,\n",
"# reduced to a scalar sum-of-squares loss.\n",
"td_error = Q2 - Q1\n",
"loss = tf.reduce_sum(tf.square(td_error))\n",
"# Plain gradient descent with a fixed learning rate of 0.1.\n",
"optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)\n",
"updatedweights = optimizer.minimize(loss)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Let's build our model"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Hyperparameters\n",
"gamma = 0.9      # discount factor for future rewards\n",
"epsilon = 0.1    # exploration rate, decayed each time a random action is taken\n",
"episodes = 2000\n",
"\n",
"totalReward = 0\n",
"\n",
"# `with` guarantees the session is released even if an episode raises.\n",
"with tf.Session() as session:\n",
"    # tf.initialize_all_variables() is deprecated; use global_variables_initializer().\n",
"    session.run(tf.global_variables_initializer())\n",
"    for i in range(episodes):\n",
"        state_now = env.reset()\n",
"        done = False\n",
"        reward = 0\n",
"        for j in range(100):\n",
"            # Forward pass: greedy action (argmax) and Q-values for the current state,\n",
"            # fed as a one-hot encoding of the 16 grid states.\n",
"            action , Y = session.run([output, Q1], feed_dict = {inputs : [np.eye(16)[state_now]]})\n",
"            # Epsilon-greedy exploration: occasionally take a random action instead.\n",
"            if epsilon > np.random.rand(1):\n",
"                action[0] = env.action_space.sample()\n",
"                epsilon -= 10**-3\n",
"            # Take the chosen action in the environment.\n",
"            state_next , reward, done, _ = env.step(action[0])\n",
"            # Q-values of the next state, used to build the Bellman target.\n",
"            Y1 = session.run(Q1, feed_dict = {inputs : [np.eye(16)[state_next]]})\n",
"            # Copy (not alias) the predictions before overwriting the target entry.\n",
"            change_Y = np.copy(Y)\n",
"            change_Y[0, action[0]] = reward + gamma*np.max(Y1)\n",
"            # One gradient step towards the Bellman target.\n",
"            _,new_weights = session.run([updatedweights,weights],feed_dict={inputs:[np.eye(16)[state_now]],Q2:change_Y})\n",
"            totalReward += reward\n",
"            state_now = state_next\n",
"            if reward == 1:\n",
"                print ('Episode {} was successful, Agent reached the Goal'.format(i))\n",
"            # Stop once the episode has terminated; calling env.step() on a\n",
"            # finished environment is undefined behaviour in gym.\n",
"            if done:\n",
"                break"
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment