Last active
March 4, 2018 01:43
-
-
Save rajshah4/aa6c67944f4a43a7c9a1204301788e0c to your computer and use it in GitHub Desktop.
RNN_Addition_1stgrade
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Teaching a computer to add (using memorization)**\n", | |
"The goal here is to use Recurrent Neural Networks to learn addition. For more background, see my [blog post](http://projects.rajivshah.com/blog/2016/04/05/rnn_addition/). This code was partially derived from [yankev/tensorflow_example](https://github.com/yankev/tensorflow_example)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"#Import basic libraries\n", | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"#from tensorflow.models.rnn import rnn_cell\n", | |
"#from tensorflow.models.rnn import rnn\n", | |
"#from tensorflow.models.rnn import seq2seq\n", | |
"from numpy import sum\n", | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#Defining some hyper-params\n", | |
"num_units = 50 #number of hidden units in each LSTM layer\n", | |
"input_size = 1 #size of each input step; also the output size of the final linear layer\n", | |
"batch_size = 50 #sequences per batch\n", | |
"seq_len = 15 #number of time steps fed to the RNN (longest possible sequence)\n", | |
"drop_out = 0.6 #probability of KEEPING each LSTM output (passed as output_keep_prob), not a drop rate" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"#Creates our random sequences\n", | |
"def gen_data(min_length=5, max_length=15, n_batch=50):\n", | |
"    # Build a batch of random digit sequences and their sums.\n", | |
"    # Returns X of shape (n_batch, max_length, 1) holding digits 0-9, where every\n", | |
"    # position from a per-row random length onward is zeroed out (padding),\n", | |
"    # and y of shape (n_batch,) holding the (float) sum of each row of X.\n", | |
"    X = np.random.randint(10, size=(n_batch, max_length, 1))\n", | |
"    for n in range(n_batch):\n", | |
"        # Length is drawn from [min_length, max_length): the upper bound is\n", | |
"        # exclusive, so at least the last step of every row is padding.\n", | |
"        length = np.random.randint(min_length, max_length)\n", | |
"        X[n, length:, 0] = 0\n", | |
"    # Padding positions are zero, so summing each whole row gives the target\n", | |
"    y = X[:, :, 0].sum(axis=1).astype(np.float64)\n", | |
"    return (X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"### Model Construction\n", | |
"num_layers = 2\n", | |
"# Build a distinct LSTM cell for each layer: [cell] * num_layers would put the\n", | |
"# same cell object in every layer, which newer TensorFlow versions reject.\n", | |
"cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)])\n", | |
"# NOTE: output_keep_prob is a fixed constant, so dropout is also applied whenever\n", | |
"# later cells evaluate this graph on validation/test data. TODO: use keep prob 1.0 there.\n", | |
"cell = tf.nn.rnn_cell.DropoutWrapper(cell,output_keep_prob=drop_out)\n", | |
"\n", | |
"#create placeholders for X (one [batch_size, 1] tensor per time step) and y\n", | |
"inputs = [tf.placeholder(tf.float32,shape=[batch_size,1]) for _ in range(seq_len)]\n", | |
"result = tf.placeholder(tf.float32, shape=[batch_size])\n", | |
"initial_state = cell.zero_state(batch_size, tf.float32)\n", | |
"\n", | |
"# With no loop_function, rnn_decoder simply runs the cell over each input in turn;\n", | |
"# only the output of the final time step is used for the prediction\n", | |
"outputs, states = tf.nn.seq2seq.rnn_decoder(inputs, initial_state, cell, scope ='rnnln')\n", | |
"outputs2 = outputs[-1]\n", | |
"\n", | |
"# Linear read-out from the last LSTM output to the predicted sum\n", | |
"W_o = tf.Variable(tf.random_normal([num_units,input_size], stddev=0.01))\n", | |
"b_o = tf.Variable(tf.random_normal([input_size], stddev=0.01))\n", | |
"\n", | |
"outputs3 = tf.matmul(outputs2, W_o) + b_o\n", | |
"\n", | |
"# Per-example squared error, a vector of length batch_size (gradients are taken of its sum).\n", | |
"# The '-' operator replaces tf.sub, which was removed in TensorFlow 1.0.\n", | |
"cost = tf.square(tf.reshape(outputs3, [-1]) - result)\n", | |
"\n", | |
"# RMSProp with learning_rate=0.005 and decay=0.2\n", | |
"train_op = tf.train.RMSPropOptimizer(0.005, 0.2).minimize(cost)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"### Generate Validation Data\n", | |
"#One fixed validation batch, split into a list with one [batch_size, 1] array per time step\n", | |
"tempX,y_val = gen_data(5,seq_len,batch_size)\n", | |
"X_val = [tempX[:,i,:] for i in range(seq_len)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"##Run this cell to see what the inputs look like: one validation sequence and its sum\n", | |
"print (tempX[1])\n", | |
"print (y_val[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"##Session: create it and initialize all model variables\n", | |
"sess = tf.Session()\n", | |
"sess.run(tf.initialize_all_variables())\n", | |
"#Costs logged during training, and the epochs at which they were logged\n", | |
"train_score, val_score, x_axis = [], [], []" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"num_epochs=1000\n", | |
"\n", | |
"#The validation batch never changes, so build its feed dictionary once\n", | |
"val_dict = {inputs[i]:X_val[i] for i in range(seq_len)}\n", | |
"val_dict.update({result: y_val})\n", | |
"\n", | |
"#k runs from 1 to num_epochs inclusive, so exactly num_epochs updates are made\n", | |
"for k in range(1,num_epochs+1):\n", | |
"\n", | |
"    #Generate fresh training data each epoch, split into one array per time step\n", | |
"    tempX,y = gen_data(5,seq_len,batch_size)\n", | |
"    X = [tempX[:,i,:] for i in range(seq_len)]\n", | |
"\n", | |
"    #Create the dictionary of inputs to feed into sess.run\n", | |
"    temp_dict = {inputs[i]:X[i] for i in range(seq_len)}\n", | |
"    temp_dict.update({result: y})\n", | |
"\n", | |
"    _,c_train = sess.run([train_op,cost],feed_dict=temp_dict) #perform an update on the parameters\n", | |
"\n", | |
"    #Every 100 epochs, log the summed squared error on this training batch and on the\n", | |
"    #fixed validation batch (validation is only evaluated when it is going to be logged)\n", | |
"    if (k%100==0):\n", | |
"        c_val = sess.run(cost,feed_dict = val_dict) #compute the cost on the validation set\n", | |
"        train_score.append(np.sum(c_train))\n", | |
"        val_score.append(np.sum(c_val))\n", | |
"        x_axis.append(k)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Final Train cost: 3086.54125977, on Epoch 999\n", | |
"Final Validation cost: 2445.63671875, on Epoch 999\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEACAYAAAC+gnFaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmYVNW19/HvYhQBRcCgzKioYEwYoqBG08HhYmJABQWN\nwBPQKDjeq+aCyY2Ywau+McbcRIwRo+AQGRQVEUGlHRkMgqCI4AAKAkEUEFHpptf7xz5FF031XN2n\nquv3eZ566vSuc6pWkbhX7eHsbe6OiIjkrnpxByAiIvFSIhARyXFKBCIiOU6JQEQkxykRiIjkOCUC\nEZEcV2YiMLP9zGyhmS01s7fMbHxUPt7M1pnZkuhxZtI148xstZmtNLMzksp7m9ny6LU7k8obm9mj\nUfkCM+tUA99TRERKUWYicPevgR+6ew+gB9DfzPoADvzR3XtGj2cAzKw7MAToDvQH7jIzi95uAjDK\n3bsCXc2sf1Q+CtgSld8B3JrerygiImUpt2vI3XdGh42AhoQkAGApTh8IPOLuBe6+BngP6GNmhwLN\n3X1RdN4k4OzoeADwQHQ8HTi1sl9CRESqrtxEYGb1zGwpsAmYk1SZX2lmb5rZRDNrEZW1BdYlXb4O\naJeifH1UTvT8MYC7FwLbzKxlVb+QiIhUTkVaBEVR11B7wq/7YwjdPF0I3UUbgNtrNEoREakxDSp6\nortvM7N5QH9331Pxm9m9wFPRn+uBDkmXtSe0BNZHxyXLE9d0BD4xswbAge7+WcnPNzMtiiQiUknu\nnqobfy/lzRpqnej2MbMmwOnAO2Z2SNJp5wDLo+MngaFm1sjMugBdgUXuvhHYbmZ9osHjYcATSdeM\niI4HA8+X8YUy+nHjjTfGHoPiVJyKU3EmHhVVXovgUOABM6tPSBqPuvssM5tkZj0IA8cfApdGFfUK\nM5sCrAAKgTFeHM0Y4H6gCTDL3WdH5ROByWa2GtgCDK1w9CIiUm1lJgJ3Xw70SlE+vIxrbgZuTlG+\nGDg2Rfk3wPkVCVZERNJPdxanUV5eXtwhVIjiTC/FmV6Ks/ZZZfqR4mRmni2xiohkAjPDqztYLCIi\ndZ8SgYhIjlMiEBHJcUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjlMiEBHJ\ncUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjsuuRLBjR9wRiIjUOdmVCGbN\nijsCEZE6J7sSwfTpcUcgIlLnlJkIzGw/M1toZkvN7C0zGx+VtzSzuWa2yszmmFmLpGvGmdlqM1tp\nZmcklfc2s+XRa3cmlTc2s0ej8gVm1qnUgJ59Fr76qjrfV0RESigzEbj718AP3b0H0APob2Z9gLHA\nXHc/Eng++hsz6w4MAboD/YG7zMyit5sAjHL3rkBXM+sflY8CtkTldwC3lhpQr14hGYiISNqU2zXk\n7jujw0ZAQ8CBAcADUfkDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui8yYlXZP8XtOBU0sNZtAgdQ+J\niKRZuYnAzOqZ2VJgEzAnqszbuPum6JRNQJvouC2wLunydUC7FOXro3Ki548B3L0Q2GZmLVMGc845\nMHMmfPNNBb6aiIhURIPyTnD3IqCHmR0IPG5m3y7xupuZ11SAycbfcw80bw4jR5J3ySXk5eXVxseK\niGSF/Px88vPzK31duYkgwd23mdk84D+ATWZ2iLtvjLp9/h2dth7okHRZe0JLYH10XLI8cU1H4BMz\nawAc6O6fpYrhV78aT4MWLWDZMlASEBHZS15e3l4/kG+66aYKXVferKHWiRlBZtYEOB14B3gSGBGd\nNgKYER0/CQw1s0Zm1gXoCixy943AdjPrEw0eDwOeSLom8V6DCYPPKc2cCZx7Ljz5JBQUVOgLiohI\n2cobIzgUeMHM3gQWEcYIZgG3AKeb2SqgX/Q3
7r4CmAKsAJ4Bxrh7ottoDHAvsBp4z91nR+UTgVZm\nthq4hmgGUioTJgAdO8Jhh8GLL1b6y4qIyL6suJ7ObGbmrVs78+fDEY/dBh98AHffHXdYIiIZy8xw\ndyv3vGxKBNdd55jBbZe+DyeeCJ98AvXrxx2aiEhGqmgiyKolJi69FO6/H75udzi0bQuvvBJ3SCIi\nWS+rEsERR0DPnjBtGrq5TEQkTbIqEQCMHh0NGicSQVFR3CGJiGS1rEsEZ50Fa9fCsoJu0KIFLFwY\nd0giIlkt6xJBgwZwySUlWgUiIlJlWTVrKBHr+vXw7W/DR7PeovkFZ8GHH4KVOzAuIpJT6uSsoYR2\n7aBfP3ho6THQsCG88UbcIYmIZK2sTAQQDRrfbfigwdE0IhERqYqsTQT9+oXNyuZ3HR7GCbKki0tE\nJNNkbSKoVw8uuwwm5B8Nu3bBW2/FHZKISFbKysHihC1b4PDD4b0Lf03rgw0quOSqiEguqNODxQmt\nWsHAgfAPG6lppCIiVZTViQDCoPHfnu1E0WdbYeXKuMMREck6WZ8I+vSB5s2N544bp1aBiEgVZH0i\nMIumkn42RNNIRUSqIKsHixN27ICOHZ1l9XrSftFjYQczEZEclxODxQnNmsGFFxp/7/Q7dQ+JiFRS\nnWgRQLiN4D/yvmZNl340fP21WoxMRCQz5VSLAMIidIcd3YgnVx4JH30UdzgiIlmjziQCgNFj6jGh\n+S/gscfiDkVEJGvUqUQwaBAs/+pwVj24KO5QRESyRp1KBI0bw88urs/f3joRNmyIOxwRkaxQZiIw\nsw5mNs/M3jazt8zsqqh8vJmtM7Ml0ePMpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qk6\nX+jSMQ2YxHC+evTJ6ryNiEjOKK9FUAD8p7sfA/QFLjezboADf3T3ntHjGQAz6w4MAboD/YG7zPZs\nHTYBGOXuXYGuZtY/Kh8FbInK7wBurc4X6tIFjjtmJ1Pu2VqdtxERyRllJgJ33+juS6PjHcA7QLvo\n5VRTkgYCj7h7gbuvAd4D+pjZoUBzd0903k8Czo6OBwAPRMfTgVOr+F32GD3uICas6gebN1f3rURE\n6rwKjxGYWWegJ7AgKrrSzN40s4lm1iIqawusS7psHSFxlCxfT3FCaQd8DODuhcA2M2tZua+xtx+d\n05hPGnVhyf+9Up23ERHJCQ0qcpKZNQOmAVe7+w4zmwD8Jnr5t8DthC6eGjV+/Pg9x3l5eeTl5aU8\nr359+PlPNjDhH4255zcpTxERqXPy8/PJz8+v9HXl3llsZg2BmcAz7v6nFK93Bp5y92PNbCyAu98S\nvTYbuBFYC8xz925R+QXAKe4+OjpnvLsvMLMGwAZ3PzjF55R5Z3FJG9//km5H7GLNGjiw00EVvk5E\npK5Iy53F0UDvRGBFchKI+vwTzgGWR8dPAkPNrJGZdQG6AovcfSOw3cz6RO85DHgi6ZoR0fFg4Ply\nv10FHHJ4U05vu4IH/2dVOt5ORKTOKq9r6CTgImCZmS2Jym4ALjCzHoTZQx8ClwK4+wozmwKsAAqB\nMUk/48cA9wNNgFnuPjsqnwhMNrPVwBZgaDq+GMDoETu58i9HMsbDctUiIrKvOrPoXCr++Va6t97E\nPU+35+T+TWsoMhGRzJRzi86lYge14LKj8pnwW00jFREpTZ1OBADDRzdl1usH8+9/xx2JiEhmqvOJ\n4KALz+RcHuO+u3fFHYqISEaq84mAVq0Y3WMBf/trAbt3xx2MiEjmqfuJADhu5LG0KtzEnDlxRyIi\nknlyIhFwzjmM/uoOJvxVTQIRkZJyIxG0acPQXqt49aXd2sVSRKSE3EgEQNMhZ3FRu3zuuSfuSERE\nMkudvqFsL+vXs6L7YE7d/zXWrjUaNUpfbCIimUg3lJXUrh3du8NRB3/OjBlxByMikjlyJxEADB7M\n6NZTmTAh
7kBERDJH7nQNAXz4IbuOO4mODdaTn28cfXR6YhMRyUTqGkqlSxcadW7LqFPXcvfdcQcj\nIpIZcisRAAwaxM/rT+TBB2HnzriDERGJX04mgk5z7+WEvs4//xl3MCIi8cu9RHDkkXDwwVx2ygoN\nGouIkIuJAGDQIPqvn8jmzfCvf8UdjIhIvHIzEQweTP3Hp3Hpz12tAhHJebmZCLp3h/33Z1SvJTz2\nGGzdGndAIiLxyc1EYAaDB/OtF/5J//4waVLcAYmIxCc3EwHAoEEwfTqjL3Puvhuy5L46EZG0y91E\n0KMHACc3X4oZvPhizPGIiMQkdxOBGQwahD02ncsuQ4PGIpKzykwEZtbBzOaZ2dtm9paZXRWVtzSz\nuWa2yszmmFmLpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qkmvmhKgwbBtGkMH+bMmQMb\nN9baJ4uIZIzyWgQFwH+6+zFAX+ByM+sGjAXmuvuRwPPR35hZd2AI0B3oD9xlZokFjyYAo9y9K9DV\nzPpH5aOALVH5HcCtaft25Tn+eNi5kwPXr2DwYJg4sdY+WUQkY5SZCNx9o7svjY53AO8A7YABwAPR\naQ8AZ0fHA4FH3L3A3dcA7wF9zOxQoLm7L4rOm5R0TfJ7TQdOre6XqjAzOPfcMGg8Gu65B3ZrW2MR\nyTEVHiMws85AT2Ah0MbdN0UvbQLaRMdtgXVJl60jJI6S5eujcqLnjwHcvRDYZmYtK/MlqmXwYJg2\njV694JBD4Jlnau2TRUQyQoOKnGRmzQi/1q929y+Ke3vA3d3MamXy5fjx4/cc5+XlkZeXV/03PfFE\n2LwZVq9m9OiuTJgAZ51V/bcVEalt+fn55OfnV/q6cjemMbOGwEzgGXf/U1S2Eshz941Rt888dz/a\nzMYCuPst0XmzgRuBtdE53aLyC4BT3H10dM54d19gZg2ADe5+cIo4qr8xTWnGjIGOHfnq6rF06ACv\nvw5dutTMR4mI1Ja0bEwTDfROBFYkkkDkSWBEdDwCmJFUPtTMGplZF6ArsMjdNwLbzaxP9J7DgCdS\nvNdgwuBz7Yq6h5o0gWHDwliBiEiuKLNFYGbfB14ClgGJE8cBi4ApQEdgDXC+u2+NrrkBGAkUErqS\nno3KewP3A02AWe6emIraGJhMGH/YAgyNBppLxlJzLYLCQjj0UHj9dd79pjOnnAIffQSNG9fMx4mI\n1IaKtghya8/islxyCRx9NFx7LaeeChdfDBdcUHMfJyJS07RncWVFaw8BjB6tO41FJHeoRZCwa1eY\nP7psGQVt2tOpE8ydC8ccU3MfKSJSk9QiqKxGjeAnP4HHH6dhw9A1dPfdcQclIlLzlAiSJXUPXXIJ\nPPQQ7NgRc0wiIjVMiSDZGWfA0qWwaRMdOsDJJ8Mjj8QdlIhIzVIiSLbffnDmmTAj3BaRGDTOkmEU\nEZEqUSIoKbq5DEIDYds2WLSonGtERLKYZg2V9OWX0LYtfPABtGrFbbfBihVw//01/9EiIumkWUNV\n1bQpnH46PBFWwPjZz0JP0WefxRyXiEgNUSJIJWn20MEHh9VIH3ignGtERLKUuoZS2b4d2reHjz+G\nAw/k1Vdh5EhYuTLsZSMikg3UNVQdBxwAeXnw1FNA2LKgUSN44YV4wxIRqQlKBKVJ6h4y0/pDIlJ3\nqWuoNJ9/Dp06wSefQLNmbN8e/nz77TCpSNJv1y5YsgRefhleeSWs8/T738cdlUj2UtdQdR10UOgT\nmjULCL1FQ4bAvffGHFcdsn07zJkD//M/8MMfQsuWcOmlsGZNuJ3j3nvhrbfijlKk7lOLoCx//zs8\n9xw8+igAb74ZZhB9+CE0qNBuz5Lsk0/CL/3EY9Uq6N0bvv/98DjhBGjRovj8O+8MieLpp+OLWSSb\naWOadNi8GY44AjZuhCZNgNBI+O//hoEDazeUbOMO775b3M3zyiuwdSucdF
Jxxd+7d9m7wO3aBd26\nhXzcr1/txS5SVygRpEu/fnDVVXD22QBMnhxWJZ09u/ZDyWQl+/dffRWaNSuu9L///VCp16tkZ+SU\nKXDrrfD665W/ViTXKRGky113wWuvwYMPAvD119ChAyxYAIcfXvvhZIrt28O/QaLif/310HhKrvjb\nt6/+57hD375w9dVw4YXVfz+RXKJEkC4bNkD37qF7KOrHuO668Ov0tttqP5y4bNgQKvxExV9e/346\nvfQSDB8ebujbb7+a+QyRukiJIJ1OPhnGjoUf/xiA1atDX/dHH9XNiinRv59c8Ve2fz/dBg6EU06B\na6+tvc8UyXZKBOn0pz/BsmVw3317is44I/xKveiieEJKp0T/fqLiT1f/fjqtXBny8bvvhmmmIlI+\nJYJ0+ugj6NUr9I80bAjA44/D7beHyjPbJPr3ExV/TfXvp9tll4UE9Yc/xB2JSHZIWyIws/uAHwP/\ndvdjo7LxwMXA5ui0G9z9mei1ccBIYDdwlbvPicp7A/cD+wGz3P3qqLwxMAnoBWwBhrj72hRxxJcI\nAPr0gd/9LixRDRQWQufO8MwzcOyx8YVVHvfQn79gAcyfH57fe6/2+vfTaePGcLfx4sXh315EypbO\nRHAysAOYlJQIbgS+cPc/lji3O/AwcBzQDngO6OrubmaLgCvcfZGZzQL+7O6zzWwM8G13H2NmQ4Bz\n3H1oijjiTQS33RY2q7n77j1FN90EmzaFiUWZYvv2sKNaotJfsCD8ij7hhPDo2xd69Kjd/v10uumm\nkNgeeijuSEQyX1q7hsysM/BUiUSww91vL3HeOKDI3W+N/p4NjAfWAi+4e7eofCiQ5+6XRefc6O4L\nzawBsMHdD04RQ7yJ4P33w91kn3wC9esDsH59aA2sXQvNm9d+SEVFoc88UenPnx/ueu7Zs7jS79u3\nbq2NtGMHHHlkWBi2d++4oxHJbBVNBNVZKOFKMxsO/Au41t23Am2BBUnnrCO0DAqi44T1UTnR88cA\n7l5oZtvMrKW7Z9aeYIcfHmrUV16BH/wAgHbtwmrVDz0U+q9r2tatsHBhcaW/cGFYEilR6f/85/Cd\n74Qls+uqZs3gxhvh+uvh+ee1P4RIOlQ1EUwAfhMd/xa4HRiVlojKMH78+D3HeXl55OXl1fRH7i2x\nsX2UCCAsT33ddWGxtHRWSkVFYa/k5L79jz4Kv4L79g2fe//9cMgh6fvMbDFqVJjI9cwz8KMfxR2N\nSObIz88nPz+/0tdVqWuotNfMbCyAu98SvTYbuJHQNTQvqWvoAuAUdx+d6D5y9wUZ3TUEYQ7jqaeG\nncuiuZRFRXDUUWEryxNPrPpbf/ZZ+IWfqPQXLQrbZPbtW/yL/zvf0WJ3CU8+CTfcEBYCjHrqRKSE\nGl2G2swOTfrzHGB5dPwkMNTMGplZF6ArsMjdNwLbzayPmRkwDHgi6ZoR0fFg4PmqxFQrjj46TK9Z\nuHBPUb16oVsoaQy5XLt3h9sS/vY3+NnPwtt27gz/7/+F1668Mty0tnp1WNtozJgwe1VJoNhPfgKt\nWoVWkYhUT0VmDT0C/ABoDWwi/MLPA3oADnwIXOrum6LzbyBMHy0Ernb3Z6PyxPTRJoTpo1dF5Y2B\nyUBPwvTRoe6+JkUc8bcIIHRQ79gRbiKIbNkS5uG/916onEr69NPiGTzz54d5+4ceuvdMnm9/W79s\nK2vRIjj33DBg3rRp3NGIZB7dUFZTli2DAQPC9JykQYERI8IMomuuCZupzJ9f3M2zaRMcf3xxpd+n\nT+qEIZU3ZEjoMvvlL+OORCTzKBHUFPcwKPDww/C97+0pXrBgz71mdOhQ3Ld/wglheQb92q8Z778f\nEuuKFfCtb8UdjUhmUSKoSePGhef//d+9ihNLNRx0UAwx5bBrrgl3ev/lL3FHIpJZlAhq0uLFMHRo\nuMVVE9nTr7AwrD194IEVumvs009Dq+
vVV8PNZiISaPP6mtSrV6isli8v/1ypmMLCcIfYZZeFO/Wu\nuw7694d33in30tatw+mJhpqIVI4SQVWYhekq06fHHUl2KyyEuXPDLdFt24bNoLt0CTvCvfFG6Hob\nPBi+/LLct7rqqtA199prtRC3SB2jrqGqmj8fLr4Y3n477kiyS0EBzJsHU6fCjBmh4j/vvFDhd+my\n97nu4UaLwsJwQ0U53XCTJoX7OV59VT12IqAxgppXVBSmBz33XOigltIVFIRun2nTQuV/+OHFlX95\n60nv3BmmYF1+eVjHowy7d4chhV//OjTYRHKdEkFtuOqqMGfxV7+KO5LMs2tXqPynToUnngijuOed\nB4MGQadOlXuvVavCPpmzZ5c7eDxnDlxxRWioRXsIieQsJYLa8OKLYe7ikiVxR5IZdu0KLaSpU8Ni\nQEcdVVz5d+xYvfeeOjWMISxeXO783P/4j3DP3+WXV+8jRbKdEkFt2L07DHK+9lro7shF33wTBnyn\nTg2bBHTrVlz5d+iQ3s+65pqwOdCMGWVuoLx0aZhwtGoVHHBAekMQySZKBLXlssvgsMPgF7+IO5La\n8803oQ9m6lSYOTPsH5mo/Nu1K//6qtq1KywBfs455f57jxgR8tDvfldz4YhkOiWC2jJ3bhgjSFqR\ntE76+mt49tkw4DtzZlhYKVH51+YWaB9/DMcdB48+ute+EKlO69EjLA1Vk7lJJJMpEdSWgoKwlOgb\nb1S/HzzTfP11GKCdOhVmzYLvfjdU/ueeG75zXJ59FkaODOMFZezMM3YsbN4MEyfWYmwiGUSJoDaN\nHBmWwLzmmrgjqb6vvtq78u/Zs7jyz6Tt0MaPD4P1c+eWulHDtm1hstJzz4UGjEiuUSKoTbNmwc03\nh/2Ms9HOnWHfx2nTwnOvXsWVf5s2cUeX2u7dcOaZYQXYm28u9bQ77wzDGU8/XYuxiWQIJYLa9M03\n4dfyihXxdplUxs6dIYFNnRq6Wr73vVD5n3NO9qznvHlzSFoTJsBZZ6U8ZdeuMJHp73+Hfv1qOT6R\nmCkR1LaLLgqbFo8ZUzPvX1QUEs4334S++6+/rthxqtfWrAk/k48/vrjyP3ifbaKzw6uvhpbLggX7\nLlERmTIFbr01rEVUxqxTkTpHiaC2zZgBv/89XHtt9Srp0o537YLGjWG//Yqfq3rcpg38+Mdh2c66\n4I474KGHQlJo3Hifl93D5jVXXw0//WkM8YnERImgtn31Vdh1fseOsivjqlbijRppJbXSuId1i9q0\ngbvuSnnKSy/B8OGwcmX45xTJBUoEklu2bQvjHDfdBBdemPKUgQPhlFNCo00kFygRSO5580047bQw\nrbR7931efuedkAjefRdatowhPpFaph3KJPd897tw222hm2jHjn1e7tYt3AhdxmxTkZykFoHUPaNG\nhTGbhx7aZ1xl48awNNLixeVvhSCS7dLWIjCz+8xsk5ktTypraWZzzWyVmc0xsxZJr40zs9VmttLM\nzkgq721my6PX7kwqb2xmj0blC8yskovVi5Twl7+EDQnuvnuflw45JIzp//KXMcQlkqEq0jX0D6B/\nibKxwFx3PxJ4PvobM+sODAG6R9fcZbbnJ9kEYJS7dwW6mlniPUcBW6LyO4Bbq/F9RKBJk3CX9I03\nhpsHSrjuurBb5uLFMcQmkoHKTQTu/jLweYniAcAD0fEDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui\n8yYlXZP8XtOBU6vwPUT21rVraBGcfz589tleLzVrFnLE9deHmaciua6qg8Vt3H1TdLwJSCxI0xZY\nl3TeOqBdivL1UTnR88cA7l4IbDMzzemQ6jv33PAYPjzcmZ1k1CjYsCEsrSSS61Iv21gJ7u5mViu/\nq8aPH7/nOC8vj7y8vNr4WMlmt9wCeXlhjYlx4/YUN2gQin7xi7C1Zf368YUoki75+fnk5+dX+roK\nzR
oys87AU+5+bPT3SiDP3TdG3T7z3P1oMxsL4O63ROfNBm4E1kbndIvKLwBOcffR0Tnj3X2BmTUA\nNrj7PgvfaNaQVNm6dWEzm4cfhh/+cE+xe9jbZsSI0EIQqWtq+j6CJ4ER0fEIYEZS+VAza2RmXYCu\nwCJ33whsN7M+0eDxMOCJFO81mDD4LJI+7dvDpElhoaENG/YUm8Ef/hDGC778Msb4RGJWbovAzB4B\nfgC0JowH/JpQiU8BOgJrgPPdfWt0/g3ASKAQuNrdn43KewP3A02AWe5+VVTeGJgM9AS2AEOjgeaS\ncahFINXzm9/A88+HR9JmNkOGhI1rfvWrGGMTqQFaYkKkpKIi+NGPwmbGt9yyp/j998PqpCtWZM9W\nDCIVoUQgksqnn0Lv3vB//wcDBuwpvuYaKCwM96KJ1BVKBCKlWbAgLEU6fz4cdhgQ8kO3bmFLgyOP\njDk+kTTRonMipenbN6wxcd55YeMfwh4911231wxTkZyhFoHkJvcwStyy5Z41ib76Co46Ch55BE46\nKeb4RNJALQKRspjBvfeGRYcmTwbCEkW//a2WnpDco0QgueuAA8LidP/1X/DWWwBcdBHs3AmPPx5z\nbCK1SIlActuxx8Ltt4fNbL74gvr1w942Y8dCQUHcwYnUDiUCkeHDwx6WF18M7pxxRti05p574g5M\npHZosFgEwuyhE0+EkSPhiitYuhT694dVq0IPkkg20n0EIpX1/vtwwgkwcyYcfzwjRkCHDvC738Ud\nmEjVKBGIVMWMGeE248WL+XhnK3r0gGXLoF278i8VyTRKBCJVdf31Yc/jmTMZe0M9Nm+GiRPjDkqk\n8pQIRKqqoAD69YP+/dl6+S856ih47rkwwUgkmygRiFTHJ5/A974Hkydz51unMmcOPP103EGJVI7u\nLBapjrZtwx3Hw4YxesB6Vq6EF16IOyiRmqEWgUhZfv97mD2bRy+bx21/bMDrr0M9/XySLKEWgUg6\njBsHzZtz/pJx1K8fFqQTqWvUIhApz5Yt0Ls3L108ieH3nsLKlbDffnEHJVI+tQhE0qVVK5gyhVP+\nPJjvHvElf/1r3AGJpJdaBCIV9de/8s5fnueUT6fz7rtGy5ZxByRSNrUIRNJtzBi69WjMua1f5uab\n4w5GJH3UIhCpjC++YGOvH3HMhrn8a/l+dOkSd0AipVOLQKQmNG/OITPu5sqiP/OrK7bGHY1IWlQr\nEZjZGjNbZmZLzGxRVNbSzOaa2Sozm2NmLZLOH2dmq81spZmdkVTe28yWR6/dWZ2YRGrcMcdw3Z87\nMu/ZXSwR0C9QAAAKxElEQVR+cUfc0YhUW3VbBA7kuXtPdz8+KhsLzHX3I4Hno78xs+7AEKA70B+4\ny8wSTZYJwCh37wp0NbP+1YxLpEY1u3goN544l+vPW4MXqctSsls6uoZK9j8NAB6Ijh8Azo6OBwKP\nuHuBu68B3gP6mNmhQHN3XxSdNynpGpGMNWrWIDZsb8pN332M+XctYeeOorhDEqmSdLQInjOzf5nZ\nJVFZG3ffFB1vAtpEx22BdUnXrgPapShfH5WLZLQGzfbj4Vkt+GT/I7ji2sa0PuAbvn3wJoYP3Mqd\nd8LLL8MXX8QdpUj5GlTz+pPcfYOZHQzMNbOVyS+6u5tZ2trN48eP33Ocl5dHXl5eut5apEp69juI\nexYeBO7sWryct+98jsVPfcIb8/vwyP7fZ/m/D6FDR6NXL+jVC3r3hp49oUWL8t9bpLLy8/PJz8+v\n9HVpmz5qZjcCO4BLCOMGG6Nun3nufrSZjQVw91ui82cDNwJro3O6ReUXAD9w98tKvL+mj0p22L07\nLFX64IMUzHiald85nzeOGcbiet/jjWUNefNNaNOGfZJD69ZxBy51TY3vR2Bm+wP13f0LM2sKzAFu\nAk4Dtrj7rVHl38Ldx0aDxQ8DxxO6fp4DjohaDQuBq4BFwNPAn919
donPUyKQ7PPll2H7ywcfhAUL\nYMAAdl84jFXtfsgbb9bnjTdg8WJYsiS0Enr33jtBtGlT/keIlKY2EkEX4PHozwbAQ+7+v2bWEpgC\ndATWAOe7+9bomhuAkUAhcLW7PxuV9wbuB5oAs9z9qhSfp0Qg2W3jRvjnP8M+Bxs3woUXwrBh8J3v\nUFQEH3wQkkIiObzxBjRpUpwUEs9t24KV+5925isshM8/D2v6JR7btkGfPnDkkXFHVzdohzKRTLZi\nRWglPPhgaAoMGxYSQ7vieRLusHZtcVJIJAizvZNDr17QqVO8yWHnzr0r9MTj009Tl2/ZEgbSDzww\nrOnXqlXoGmvaFF56Cb71LTjvvPBQUqg6JQKRbFBUFKYXTZ4Mjz0WavWLLoJBg6B5831Od4f16/du\nNSxeDN98s29yOPzwyieHoqJ9f6VX5AHFFXpFHy1aQP36+8aweze8+ipMnQrTpoXusfPPD0mha9cq\n/BvnMCUCkWzz1Vcwc2ZICi++CD/+cUgKZ5wBDcqe4LdhQxhnSG49bNsWBqETiaFx4/Ir9K1bQ/4p\nWWm3bl12pb7//jXzT7J7N7zySnFSOPTQ4qRwxBE185l1iRKBSDb79FN49NHQdfTBBzB0aOg+6t27\nwj/zN28OySGRGHbvLv9X+kEHlZtzYpNIClOmwPTpYawk0X2kpJCaEoFIXbF6dfF4QqNGoZXw059C\n585xRxab3btDj9rUqcVJIdFSOPzwuKPLHEoEInWNO8yfHxLClCnQvXtoJZx3Xk7foZZIClOmhGGW\ndu2KWwq5nhSUCETqsl27YNaskBTmzoXTTw9J4cwzQ6shR+3eHWYdJVoK7dsXtxQOOyzu6GqfEoFI\nrvj88zCSOnlymJZ6/vmh++iEE+rGDQdVlEgKiZZChw7FLYU6kxTc4euvw1zcFA8bOlSJQCTnfPgh\nPPxwSAqFhcXjCTk+77KwcO+WQseOxS2FWt9lrrBw30p7+/ZSK/NyHw0ahKleKR42daoSgUjOcg9z\nSSdPDnczd+kCRx+99+slz6/o39W5trz3gtCKqV8f6tVL/VzWa+U9169PodfnxQ86MHXJETy2tAud\nWu3g/OPXcl6fj+h8yNcV+1wIy4dUpeIuKIBmzfatuA84oNQKvcxHw4b7/hvu+adU15CIQKh4Xngh\n3GyQrGS3UWX+rs615b1XUVF47N6973Oqsoo+pygrLIQX1x/BlA+P4/GPetG56aec134+57V7jc77\nbSz92qKicBt0VSrwJk1qrctOiUBEpBIKCyE/P3QfPfZYaEQluo86dYo7utK5hzvLd+wIjZTk51NP\nVSIQEamSRFKYMgUefzwMLicGmquaFNzDzeOpKuzEc1mvlXVuw4ahgdKs2d7P+flKBCIi1VZQUNxS\nSCSFAQPCkh2VqbC//DJc06zZvhV2qufKnFPa3eDqGhIRSbOCApg3D2bPDt38lanImzZNvcheTVIi\nEBHJcRVNBNXdvF5ERLKcEoGISI5TIhARyXFKBCIiOU6JQEQkxykRiIjkOCUCEZEclzGJwMz6m9lK\nM1ttZv8ddzwiIrkiIxKBmdUH/gL0B7oDF5hZt3ijqrz8/Py4Q6gQxZleijO9FGfty4hEABwPvOfu\na9y9APgnMDDmmCotW/6PoTjTS3Gml+KsfZmSCNoBHyf9vS4qExGRGpYpiUCLCImIxCQjFp0zs77A\neHfvH/09Dihy91uTzok/UBGRLJM1q4+aWQPgXeBU4BNgEXCBu78Ta2AiIjmglO0Mape7F5rZFcCz\nQH1gopKAiEjtyIgWgYiIxCdTBotLlQ03mpnZfWa2ycyWxx1LWcysg5nNM7O3zewtM7sq7phSMbP9\nzGyhmS2N4hwfd0ylMbP6ZrbEzJ6KO5bSmNkaM1sWxbko7nhKY2YtzGyamb1jZiuiscOMYmZHRf+O\nice2DP7v6D+j/36Wm9nDZta4
1HMzuUUQ3Wj2LnAasB54nQwcOzCzk4EdwCR3PzbueEpjZocAh7j7\nUjNrBiwGzs60f08AM9vf3XdG40evAFe7+8K44yrJzP4L6A00d/cBcceTipl9CPR298/ijqUsZvYA\n8KK73xf9797U3bfFHVdpzKweoV463t0/Lu/82mRm7YCXgW7u/o2ZPQrMcvcHUp2f6S2CrLjRzN1f\nBj6PO47yuPtGd18aHe8A3gHaxhtVau6+MzpsBDQEimIMJyUzaw/8CLgXKHdmRswyOj4zOxA42d3v\ngzBumMlJIHIa8H6mJYEkDYD9o6S6PyFppZTpiUA3mtUQM+sM9AQy7lc2hF9bZrYU2ATMcffX444p\nhTuA68nAJFWCA8+Z2b/M7JK4gylFF2Czmf3DzN4ws7+b2f5xB1WOocDDcQeRiruvB24HPiLMxNzq\n7s+Vdn6mJ4LM7bfKYlG30DRCd8uOuONJxd2L3L0H0B7oY2bHxB1TMjM7C/i3uy8hw39tAye5e0/g\nTODyqCsz0zQAegF3uXsv4EtgbLwhlc7MGgE/AabGHUsqZnYQMADoTGj1NzOzn5Z2fqYngvVAh6S/\nOxBaBVJFZtYQmA486O4z4o6nPFH3wDzCgoSZ5ERgQNT//gjQz8wmxRxTSu6+IXreDDxO6HLNNOuA\ndUktv2mExJCpzgQWR/+mmeg04EN33+LuhcBjhP/PppTpieBfQFcz6xxl4CHAkzHHlLXMzICJwAp3\n/1Pc8ZTGzFqbWYvouAlwOmE8I2O4+w3u3sHduxC6CF5w9+Fxx1WSme1vZs2j46bAGUDGzW5z943A\nx2Z2ZFR0GvB2jCGV5wLCD4BMtRboa2ZNov/uTwNWlHZyRtxQVppsudHMzB4BfgC0MrOPgV+7+z9i\nDiuVk4CLgGVmtiQqG+fus2OMKZVDgQeiWWP1gEfdfVbMMZUnU7sx2wCPh7qABsBD7j4n3pBKdSXw\nUPSj733gZzHHk1KUUE8DMnW8BXdfZGbTgDeAwuj5ntLOz+jpoyIiUvMyvWtIRERqmBKBiEiOUyIQ\nEclxSgQiIjlOiUBEJMcpEYiI5DglAhGRHKdEICKS4/4/QGQBkLhV900AAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x1102cbe90>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"#Report the last logged costs; they are logged every 100 epochs, so x_axis holds the matching epoch\n", | |
"print(\"Final Train cost: {}, on Epoch {}\".format(train_score[-1],x_axis[-1]))\n", | |
"print(\"Final Validation cost: {}, on Epoch {}\".format(val_score[-1],x_axis[-1]))\n", | |
"plt.plot(x_axis, train_score, 'r-', label='train')\n", | |
"plt.plot(x_axis, val_score, 'b-', label='validation')\n", | |
"plt.xlabel('Epoch')\n", | |
"plt.ylabel('Summed squared error')\n", | |
"plt.legend()\n", | |
"plt.show()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"##Score the trained model on a fresh test batch it never saw during training\n", | |
"#A single batch is enough, so there is no epoch loop (which would also overwrite the training counter k)\n", | |
"#NOTE: the model's dropout is still applied here, so predictions are somewhat noisy\n", | |
"tempX,y = gen_data(5,seq_len,batch_size)\n", | |
"X = [tempX[:,i,:] for i in range(seq_len)]\n", | |
"\n", | |
"test_dict = {inputs[i]:X[i] for i in range(seq_len)}\n", | |
"test_dict.update({result: y})\n", | |
"outv, c_val = sess.run([outputs3,cost],feed_dict = test_dict)\n", | |
"val_score_v = [[c_val]]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([[8],\n", | |
" [2],\n", | |
" [8],\n", | |
" [8],\n", | |
" [9],\n", | |
" [6],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0],\n", | |
" [0]]), 41.0)" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"##Target: the 4th sequence of the test batch and its true sum\n", | |
"(tempX[3], y[3])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 44.25109482], dtype=float32)" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#Prediction: the model's estimate of the sum for that same sequence\n", | |
"outv[3, :]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
I found the documentation deep in the tensorflow ops code
It explains how the decoder operates. Does this help?
Got it. Thank you so much 👍
Hi Rajiv,
Can I ask what the purpose of the dropout layer is in a problem such as this? When training for something like addition, don't we need to know all of the inputs?
Thanks,
Jack
Hmm, it's a good question. This was one of my first RNNs and I just grabbed code from other projects. I think it would work like dropout generally: it would help against overfitting and give the network a better sense of how addition works. If you have the time and play around with the dropout, I would be curious whether it works like that.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Rajiv, thank you so much for your example.
May I know the reason for using seq2seq.rnn_decoder() in your code? I've tried searching the web, and many of the examples are related to language translation. And I really cannot find documentation that talks about the TensorFlow seq2seq class.
Thanks a lot.