Forked from rajshah4/RNN_Simple_Addition_Public.ipynb
Created
September 4, 2016 13:12
-
-
Save vgoklani/1d8e6923b0fc1a0dc07579f2bbaecf53 to your computer and use it in GitHub Desktop.
RNN_Addition_1stgrade
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "**Teaching a computer to add (using memorization)**\n", | |
| "The goal here is to train a Recurrent Neural Network (an LSTM) to add up a short sequence of digits. For more background, see my [blog post](http://projects.rajivshah.com/blog/2016/04/05/rnn_addition/). This code was partially derived from [yankev/tensorflow_example](https://github.com/yankev/tensorflow_example)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "#Import basic libraries\n", | |
| "import numpy as np\n", | |
| "import tensorflow as tf\n", | |
| "#from tensorflow.models.rnn import rnn_cell\n", | |
| "#from tensorflow.models.rnn import rnn\n", | |
| "#from tensorflow.models.rnn import seq2seq\n", | |
| "from numpy import sum\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "%matplotlib inline " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Hyper-parameters shared by the data generator, the model and the training loop\n", | |
| "num_units = 50   # number of hidden units (state size) in each LSTM layer\n", | |
| "input_size = 1   # width of the regression output (W_o / b_o); each time step also carries one digit\n", | |
| "batch_size = 50  # sequences per training / validation batch\n", | |
| "seq_len = 15     # number of time steps the RNN is unrolled for\n", | |
| "drop_out = 0.6   # NOTE: a *keep* probability for DropoutWrapper, not a drop rate" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": { | |
| "collapsed": true | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Creates our random sequences\n", | |
| "def gen_data(min_length=5, max_length=15, n_batch=50):\n", | |
| "    \"\"\"Generate a batch of zero-padded digit sequences and their sums.\n", | |
| "\n", | |
| "    Each row begins with a random number of digits (0-9); that length is drawn\n", | |
| "    uniformly from [min_length, max_length] inclusive and the remaining\n", | |
| "    positions are set to zero.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        X: int array of shape (n_batch, max_length, 1) holding the digits.\n", | |
| "        y: float64 array of shape (n_batch,) holding each row's digit sum.\n", | |
| "    \"\"\"\n", | |
| "    X = np.random.randint(10, size=(n_batch, max_length, 1))\n", | |
| "    # np.random.randint excludes its upper bound, so add 1 to let a sequence\n", | |
| "    # fill all max_length positions (and to allow min_length == max_length).\n", | |
| "    lengths = np.random.randint(min_length, max_length + 1, size=n_batch)\n", | |
| "    for n, length in enumerate(lengths):\n", | |
| "        X[n, length:, 0] = 0\n", | |
| "    # Padding is zero, so summing the whole row gives the target value.\n", | |
| "    y = X[:, :, 0].sum(axis=1).astype(np.float64)\n", | |
| "    return (X, y)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "### Model Construction\n", | |
| "num_layers = 2\n", | |
| "# Build a separate LSTM cell object per layer: reusing one object via\n", | |
| "# [cell] * num_layers may share weights or raise in newer TensorFlow releases.\n", | |
| "layers = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)]\n", | |
| "cell = tf.nn.rnn_cell.MultiRNNCell(layers)\n", | |
| "# keep_prob defaults to drop_out (the original training behaviour); feed\n", | |
| "# {keep_prob: 1.0} in sess.run to disable dropout when evaluating.\n", | |
| "keep_prob = tf.placeholder_with_default(drop_out, shape=[])\n", | |
| "cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=keep_prob)\n", | |
| "\n", | |
| "# Create placeholders for X (one [batch_size, 1] tensor per time step) and y\n", | |
| "inputs = [tf.placeholder(tf.float32, shape=[batch_size, 1]) for _ in range(seq_len)]\n", | |
| "result = tf.placeholder(tf.float32, shape=[batch_size])\n", | |
| "initial_state = cell.zero_state(batch_size, tf.float32)\n", | |
| "\n", | |
| "# rnn_decoder (no loop_function) is used here to unroll the cell over the per-step inputs\n", | |
| "outputs, states = tf.nn.seq2seq.rnn_decoder(inputs, initial_state, cell, scope='rnnln')\n", | |
| "# Only the output after the last time step is used to predict the sum\n", | |
| "outputs2 = outputs[-1]\n", | |
| "\n", | |
| "# Linear read-out from the last LSTM output to the predicted sum\n", | |
| "W_o = tf.Variable(tf.random_normal([num_units, input_size], stddev=0.01))\n", | |
| "b_o = tf.Variable(tf.random_normal([input_size], stddev=0.01))\n", | |
| "\n", | |
| "outputs3 = tf.matmul(outputs2, W_o) + b_o\n", | |
| "\n", | |
| "# Per-example squared error, shape [batch_size]; minimize() differentiates the\n", | |
| "# implicit sum, so training targets the summed squared error over the batch.\n", | |
| "cost = tf.square(tf.reshape(outputs3, [-1]) - result)\n", | |
| "\n", | |
| "# RMSProp with learning_rate=0.005 and decay=0.2 (positional arguments)\n", | |
| "train_op = tf.train.RMSPropOptimizer(0.005, 0.2).minimize(cost)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "### Generate Validation Data\n", | |
| "# One fixed validation batch, reused at every evaluation step.\n", | |
| "tempX, y_val = gen_data(5, seq_len, batch_size)\n", | |
| "# Split it into one [batch_size, 1] array per time step to match `inputs`.\n", | |
| "X_val = [tempX[:, i, :] for i in range(seq_len)]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "## Run this cell to see what one validation input and its target look like\n", | |
| "print(tempX[1])\n", | |
| "print(y_val[1])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "## Session: reset the learning-curve logs, then start a TensorFlow session\n", | |
| "## and initialise every variable in the graph built above.\n", | |
| "train_score, val_score, x_axis = [], [], []\n", | |
| "sess = tf.Session()\n", | |
| "sess.run(tf.initialize_all_variables())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Each iteration draws a fresh random batch and takes one optimiser step,\n", | |
| "# so num_epochs counts training steps rather than full passes over a dataset.\n", | |
| "num_epochs = 1000\n", | |
| "log_every = 100  # record train / validation cost every this many steps\n", | |
| "\n", | |
| "# The validation feed never changes, so build it once outside the loop.\n", | |
| "val_dict = {inputs[i]: X_val[i] for i in range(seq_len)}\n", | |
| "val_dict[result] = y_val\n", | |
| "\n", | |
| "for k in range(1, num_epochs + 1):  # + 1 so exactly num_epochs steps run\n", | |
| "    # Generate a new batch and split it into one [batch_size, 1] array per time step\n", | |
| "    tempX, y = gen_data(5, seq_len, batch_size)\n", | |
| "    X = [tempX[:, i, :] for i in range(seq_len)]\n", | |
| "\n", | |
| "    # Create the dictionary of inputs to feed into sess.run\n", | |
| "    temp_dict = {inputs[i]: X[i] for i in range(seq_len)}\n", | |
| "    temp_dict[result] = y\n", | |
| "\n", | |
| "    # Perform one update on the parameters; c_train is the per-example squared error\n", | |
| "    _, c_train = sess.run([train_op, cost], feed_dict=temp_dict)\n", | |
| "\n", | |
| "    if k % log_every == 0:\n", | |
| "        # Evaluate on the fixed validation batch only when the result is recorded.\n", | |
| "        # NOTE(review): the DropoutWrapper is still active here, so this cost is noisy.\n", | |
| "        c_val = sess.run(cost, feed_dict=val_dict)\n", | |
| "        # np.sum explicitly, rather than relying on `from numpy import sum`\n", | |
| "        train_score.append(np.sum(c_train))\n", | |
| "        val_score.append(np.sum(c_val))\n", | |
| "        x_axis.append(k)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Final Train cost: 3086.54125977, on Epoch 999\n", | |
| "Final Validation cost: 2445.63671875, on Epoch 999\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEACAYAAAC+gnFaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3XmYVNW19/HvYhQBRcCgzKioYEwYoqBG08HhYmJABQWN\nwBPQKDjeq+aCyY2Ywau+McbcRIwRo+AQGRQVEUGlHRkMgqCI4AAKAkEUEFHpptf7xz5FF031XN2n\nquv3eZ566vSuc6pWkbhX7eHsbe6OiIjkrnpxByAiIvFSIhARyXFKBCIiOU6JQEQkxykRiIjkOCUC\nEZEcV2YiMLP9zGyhmS01s7fMbHxUPt7M1pnZkuhxZtI148xstZmtNLMzksp7m9ny6LU7k8obm9mj\nUfkCM+tUA99TRERKUWYicPevgR+6ew+gB9DfzPoADvzR3XtGj2cAzKw7MAToDvQH7jIzi95uAjDK\n3bsCXc2sf1Q+CtgSld8B3JrerygiImUpt2vI3XdGh42AhoQkAGApTh8IPOLuBe6+BngP6GNmhwLN\n3X1RdN4k4OzoeADwQHQ8HTi1sl9CRESqrtxEYGb1zGwpsAmYk1SZX2lmb5rZRDNrEZW1BdYlXb4O\naJeifH1UTvT8MYC7FwLbzKxlVb+QiIhUTkVaBEVR11B7wq/7YwjdPF0I3UUbgNtrNEoREakxDSp6\nortvM7N5QH9331Pxm9m9wFPRn+uBDkmXtSe0BNZHxyXLE9d0BD4xswbAge7+WcnPNzMtiiQiUknu\nnqobfy/lzRpqnej2MbMmwOnAO2Z2SNJp5wDLo+MngaFm1sjMugBdgUXuvhHYbmZ9osHjYcATSdeM\niI4HA8+X8YUy+nHjjTfGHoPiVJyKU3EmHhVVXovgUOABM6tPSBqPuvssM5tkZj0IA8cfApdGFfUK\nM5sCrAAKgTFeHM0Y4H6gCTDL3WdH5ROByWa2GtgCDK1w9CIiUm1lJgJ3Xw70SlE+vIxrbgZuTlG+\nGDg2Rfk3wPkVCVZERNJPdxanUV5eXtwhVIjiTC/FmV6Ks/ZZZfqR4mRmni2xiohkAjPDqztYLCIi\ndZ8SgYhIjlMiEBHJcUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjlMiEBHJ\ncUoEIiI5TolARCTHKRGIiOQ4JQIRkRynRCAikuOUCEREcpwSgYhIjsuuRLBjR9wRiIjUOdmVCGbN\nijsCEZE6J7sSwfTpcUcgIlLnlJkIzGw/M1toZkvN7C0zGx+VtzSzuWa2yszmmFmLpGvGmdlqM1tp\nZmcklfc2s+XRa3cmlTc2s0ej8gVm1qnUgJ59Fr76qjrfV0RESigzEbj718AP3b0H0APob2Z9gLHA\nXHc/Eng++hsz6w4MAboD/YG7zMyit5sAjHL3rkBXM+sflY8CtkTldwC3lhpQr14hGYiISNqU2zXk\n7jujw0ZAQ8CBAcADUfkDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui8yYlXZP8XtOBU0sNZtAgdQ+J\niKRZuYnAzOqZ2VJgEzAnqszbuPum6JRNQJvouC2wLunydUC7FOXro3Ki548B3L0Q2GZmLVMGc845\nMHMmfPNNBb6aiIhURIPyTnD3IqCHmR0IPG5m3y7xupuZ11SAycbfcw80bw4jR5J3ySXk5eXVxseK\niGSF/Px88vPzK31duYkgwd23mdk84D+ATWZ2iLtvjLp9/h2dth7okHRZe0JLYH10XLI8cU1H4BMz\nawAc6O6fpYrhV78aT4MWLWDZMlASEBHZS15e3l4/kG+66aYKXVferKHWiRlBZtYEOB14B3gSGBGd\nNgKYER0/CQw1s0Zm1gXoCixy943AdjPrEw0eDwOeSLom8V6DCYPPKc2cCZx7Ljz5JBQUVOgLiohI\n2cobIzgUeMHM3gQWEcYIZgG3AKeb2SqgX/
Q37r4CmAKsAJ4Bxrh7ottoDHAvsBp4z91nR+UTgVZm\nthq4hmgGUioTJgAdO8Jhh8GLL1b6y4qIyL6suJ7ObGbmrVs78+fDEY/dBh98AHffHXdYIiIZy8xw\ndyv3vGxKBNdd55jBbZe+DyeeCJ98AvXrxx2aiEhGqmgiyKolJi69FO6/H75udzi0bQuvvBJ3SCIi\nWS+rEsERR0DPnjBtGrq5TEQkTbIqEQCMHh0NGicSQVFR3CGJiGS1rEsEZ50Fa9fCsoJu0KIFLFwY\nd0giIlkt6xJBgwZwySUlWgUiIlJlWTVrKBHr+vXw7W/DR7PeovkFZ8GHH4KVOzAuIpJT6uSsoYR2\n7aBfP3ho6THQsCG88UbcIYmIZK2sTAQQDRrfbfigwdE0IhERqYqsTQT9+oXNyuZ3HR7GCbKki0tE\nJNNkbSKoVw8uuwwm5B8Nu3bBW2/FHZKISFbKysHihC1b4PDD4b0Lf03rgw0quOSqiEguqNODxQmt\nWsHAgfAPG6lppCIiVZTViQDCoPHfnu1E0WdbYeXKuMMREck6WZ8I+vSB5s2N544bp1aBiEgVZH0i\nMIumkn42RNNIRUSqIKsHixN27ICOHZ1l9XrSftFjYQczEZEclxODxQnNmsGFFxp/7/Q7dQ+JiFRS\nnWgRQLiN4D/yvmZNl340fP21WoxMRCQz5VSLAMIidIcd3YgnVx4JH30UdzgiIlmjziQCgNFj6jGh\n+S/gscfiDkVEJGvUqUQwaBAs/+pwVj24KO5QRESyRp1KBI0bw88urs/f3joRNmyIOxwRkaxQZiIw\nsw5mNs/M3jazt8zsqqh8vJmtM7Ml0ePMpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qk6\nX+jSMQ2YxHC+evTJ6ryNiEjOKK9FUAD8p7sfA/QFLjezboADf3T3ntHjGQAz6w4MAboD/YG7zPZs\nHTYBGOXuXYGuZtY/Kh8FbInK7wBurc4X6tIFjjtmJ1Pu2VqdtxERyRllJgJ33+juS6PjHcA7QLvo\n5VRTkgYCj7h7gbuvAd4D+pjZoUBzd0903k8Czo6OBwAPRMfTgVOr+F32GD3uICas6gebN1f3rURE\n6rwKjxGYWWegJ7AgKrrSzN40s4lm1iIqawusS7psHSFxlCxfT3FCaQd8DODuhcA2M2tZua+xtx+d\n05hPGnVhyf+9Up23ERHJCQ0qcpKZNQOmAVe7+w4zmwD8Jnr5t8DthC6eGjV+/Pg9x3l5eeTl5aU8\nr359+PlPNjDhH4255zcpTxERqXPy8/PJz8+v9HXl3llsZg2BmcAz7v6nFK93Bp5y92PNbCyAu98S\nvTYbuBFYC8xz925R+QXAKe4+OjpnvLsvMLMGwAZ3PzjF55R5Z3FJG9//km5H7GLNGjiw00EVvk5E\npK5Iy53F0UDvRGBFchKI+vwTzgGWR8dPAkPNrJGZdQG6AovcfSOw3cz6RO85DHgi6ZoR0fFg4Ply\nv10FHHJ4U05vu4IH/2dVOt5ORKTOKq9r6CTgImCZmS2Jym4ALjCzHoTZQx8ClwK4+wozmwKsAAqB\nMUk/48cA9wNNgFnuPjsqnwhMNrPVwBZgaDq+GMDoETu58i9HMsbDctUiIrKvOrPoXCr++Va6t97E\nPU+35+T+TWsoMhGRzJRzi86lYge14LKj8pnwW00jFREpTZ1OBADDRzdl1usH8+9/xx2JiEhmqvOJ\n4KALz+RcHuO+u3fFHYqISEaq84mAVq0Y3WMBf/trAbt3xx2MiEjmqfuJADhu5LG0KtzEnDlxRyIi\nknlyIhFwzjmM/uoOJvxVTQIRkZJyIxG0acPQXqt49aXd2sVSRKSE3EgEQNMhZ3FRu3zuuSfuSERE\nMkudvqFsL+vXs6L7YE7d/zXWrjUaNUpfbCIimUg3lJXUrh3du8NRB3/OjBlxByMikjlyJxEADB7M\n6NZTmT
Ah7kBERDJH7nQNAXz4IbuOO4mODdaTn28cfXR6YhMRyUTqGkqlSxcadW7LqFPXcvfdcQcj\nIpIZcisRAAwaxM/rT+TBB2HnzriDERGJX04mgk5z7+WEvs4//xl3MCIi8cu9RHDkkXDwwVx2ygoN\nGouIkIuJAGDQIPqvn8jmzfCvf8UdjIhIvHIzEQweTP3Hp3Hpz12tAhHJebmZCLp3h/33Z1SvJTz2\nGGzdGndAIiLxyc1EYAaDB/OtF/5J//4waVLcAYmIxCc3EwHAoEEwfTqjL3Puvhuy5L46EZG0y91E\n0KMHACc3X4oZvPhizPGIiMQkdxOBGQwahD02ncsuQ4PGIpKzykwEZtbBzOaZ2dtm9paZXRWVtzSz\nuWa2yszmmFmLpGvGmdlqM1tpZmcklfc2s+XRa3cmlTc2s0ej8gVm1qkmvmhKgwbBtGkMH+bMmQMb\nN9baJ4uIZIzyWgQFwH+6+zFAX+ByM+sGjAXmuvuRwPPR35hZd2AI0B3oD9xlZokFjyYAo9y9K9DV\nzPpH5aOALVH5HcCtaft25Tn+eNi5kwPXr2DwYJg4sdY+WUQkY5SZCNx9o7svjY53AO8A7YABwAPR\naQ8AZ0fHA4FH3L3A3dcA7wF9zOxQoLm7L4rOm5R0TfJ7TQdOre6XqjAzOPfcMGg8Gu65B3ZrW2MR\nyTEVHiMws85AT2Ah0MbdN0UvbQLaRMdtgXVJl60jJI6S5eujcqLnjwHcvRDYZmYtK/MlqmXwYJg2\njV694JBD4Jlnau2TRUQyQoOKnGRmzQi/1q929y+Ke3vA3d3MamXy5fjx4/cc5+XlkZeXV/03PfFE\n2LwZVq9m9OiuTJgAZ51V/bcVEalt+fn55OfnV/q6cjemMbOGwEzgGXf/U1S2Eshz941Rt888dz/a\nzMYCuPst0XmzgRuBtdE53aLyC4BT3H10dM54d19gZg2ADe5+cIo4qr8xTWnGjIGOHfnq6rF06ACv\nvw5dutTMR4mI1Ja0bEwTDfROBFYkkkDkSWBEdDwCmJFUPtTMGplZF6ArsMjdNwLbzaxP9J7DgCdS\nvNdgwuBz7Yq6h5o0gWHDwliBiEiuKLNFYGbfB14ClgGJE8cBi4ApQEdgDXC+u2+NrrkBGAkUErqS\nno3KewP3A02AWe6emIraGJhMGH/YAgyNBppLxlJzLYLCQjj0UHj9dd79pjOnnAIffQSNG9fMx4mI\n1IaKtghya8/islxyCRx9NFx7LaeeChdfDBdcUHMfJyJS07RncWVFaw8BjB6tO41FJHeoRZCwa1eY\nP7psGQVt2tOpE8ydC8ccU3MfKSJSk9QiqKxGjeAnP4HHH6dhw9A1dPfdcQclIlLzlAiSJXUPXXIJ\nPPQQ7NgRc0wiIjVMiSDZGWfA0qWwaRMdOsDJJ8Mjj8QdlIhIzVIiSLbffnDmmTAj3BaRGDTOkmEU\nEZEqUSIoKbq5DEIDYds2WLSonGtERLKYZg2V9OWX0LYtfPABtGrFbbfBihVw//01/9EiIumkWUNV\n1bQpnH46PBFWwPjZz0JP0WefxRyXiEgNUSJIJWn20MEHh9VIH3ignGtERLKUuoZS2b4d2reHjz+G\nAw/k1Vdh5EhYuTLsZSMikg3UNVQdBxwAeXnw1FNA2LKgUSN44YV4wxIRqQlKBKVJ6h4y0/pDIlJ3\nqWuoNJ9/Dp06wSefQLNmbN8e/nz77TCpSNJv1y5YsgRefhleeSWs8/T738cdlUj2UtdQdR10UOgT\nmjULCL1FQ4bAvffGHFcdsn07zJkD//M/8MMfQsuWcOmlsGZNuJ3j3nvhrbfijlKk7lOLoCx//zs8\n9xw8+igAb74ZZhB9+CE0qNBuz5Lsk0/CL/3EY9Uq6N0bvv/98DjhBGjRovj8O+8MieLpp+OLWSSb\naWOadNi8GY44AjZuhCZNgNBI+O//hoEDazeUbOMO775b3M3zyiuwdSuc
dFJxxd+7d9m7wO3aBd26\nhXzcr1/txS5SVygRpEu/fnDVVXD22QBMnhxWJZ09u/ZDyWQl+/dffRWaNSuu9L///VCp16tkZ+SU\nKXDrrfD665W/ViTXKRGky113wWuvwYMPAvD119ChAyxYAIcfXvvhZIrt28O/QaLif/310HhKrvjb\nt6/+57hD375w9dVw4YXVfz+RXKJEkC4bNkD37qF7KOrHuO668Ov0tttqP5y4bNgQKvxExV9e/346\nvfQSDB8ebujbb7+a+QyRukiJIJ1OPhnGjoUf/xiA1atDX/dHH9XNiinRv59c8Ve2fz/dBg6EU06B\na6+tvc8UyXZKBOn0pz/BsmVw3317is44I/xKveiieEJKp0T/fqLiT1f/fjqtXBny8bvvhmmmIlI+\nJYJ0+ugj6NUr9I80bAjA44/D7beHyjPbJPr3ExV/TfXvp9tll4UE9Yc/xB2JSHZIWyIws/uAHwP/\ndvdjo7LxwMXA5ui0G9z9mei1ccBIYDdwlbvPicp7A/cD+wGz3P3qqLwxMAnoBWwBhrj72hRxxJcI\nAPr0gd/9LixRDRQWQufO8MwzcOyx8YVVHvfQn79gAcyfH57fe6/2+vfTaePGcLfx4sXh315EypbO\nRHAysAOYlJQIbgS+cPc/lji3O/AwcBzQDngO6OrubmaLgCvcfZGZzQL+7O6zzWwM8G13H2NmQ4Bz\n3H1oijjiTQS33RY2q7n77j1FN90EmzaFiUWZYvv2sKNaotJfsCD8ij7hhPDo2xd69Kjd/v10uumm\nkNgeeijuSEQyX1q7hsysM/BUiUSww91vL3HeOKDI3W+N/p4NjAfWAi+4e7eofCiQ5+6XRefc6O4L\nzawBsMHdD04RQ7yJ4P33w91kn3wC9esDsH59aA2sXQvNm9d+SEVFoc88UenPnx/ueu7Zs7jS79u3\nbq2NtGMHHHlkWBi2d++4oxHJbBVNBNVZKOFKMxsO/Au41t23Am2BBUnnrCO0DAqi44T1UTnR88cA\n7l5oZtvMrKW7Z9aeYIcfHmrUV16BH/wAgHbtwmrVDz0U+q9r2tatsHBhcaW/cGFYEilR6f/85/Cd\n74Qls+uqZs3gxhvh+uvh+ee1P4RIOlQ1EUwAfhMd/xa4HRiVlojKMH78+D3HeXl55OXl1fRH7i2x\nsX2UCCAsT33ddWGxtHRWSkVFYa/k5L79jz4Kv4L79g2fe//9cMgh6fvMbDFqVJjI9cwz8KMfxR2N\nSObIz88nPz+/0tdVqWuotNfMbCyAu98SvTYbuJHQNTQvqWvoAuAUdx+d6D5y9wUZ3TUEYQ7jqaeG\nncuiuZRFRXDUUWEryxNPrPpbf/ZZ+IWfqPQXLQrbZPbtW/yL/zvf0WJ3CU8+CTfcEBYCjHrqRKSE\nGl2G2swOTfrzHGB5dPwkMNTMGplZF6ArsMjdNwLbzayPmRkwDHgi6ZoR0fFg4PmqxFQrjj46TK9Z\nuHBPUb16oVsoaQy5XLt3h9sS/vY3+NnPwtt27gz/7/+F1668Mty0tnp1WNtozJgwe1VJoNhPfgKt\nWoVWkYhUT0VmDT0C/ABoDWwi/MLPA3oADnwIXOrum6LzbyBMHy0Ernb3Z6PyxPTRJoTpo1dF5Y2B\nyUBPwvTRoe6+JkUc8bcIIHRQ79gRbiKIbNkS5uG/916onEr69NPiGTzz54d5+4ceuvdMnm9/W79s\nK2vRIjj33DBg3rRp3NGIZB7dUFZTli2DAQPC9JykQYERI8IMomuuCZupzJ9f3M2zaRMcf3xxpd+n\nT+qEIZU3ZEjoMvvlL+OORCTzKBHUFPcwKPDww/C97+0pXrBgz71mdOhQ3Ld/wglheQb92q8Z778f\nEuuKFfCtb8UdjUhmUSKoSePGhef//d+9ihNLNRx0UAwx5bBrrgl3ev/lL3FHIpJZlAhq0uLFMHRo\nuMVVE9nTr7AwrD194IEVumvs009D
q+vVV8PNZiISaPP6mtSrV6isli8v/1ypmMLCcIfYZZeFO/Wu\nuw7694d33in30tatw+mJhpqIVI4SQVWYhekq06fHHUl2KyyEuXPDLdFt24bNoLt0CTvCvfFG6Hob\nPBi+/LLct7rqqtA199prtRC3SB2jrqGqmj8fLr4Y3n477kiyS0EBzJsHU6fCjBmh4j/vvFDhd+my\n97nu4UaLwsJwQ0U53XCTJoX7OV59VT12IqAxgppXVBSmBz33XOigltIVFIRun2nTQuV/+OHFlX95\n60nv3BmmYF1+eVjHowy7d4chhV//OjTYRHKdEkFtuOqqMGfxV7+KO5LMs2tXqPynToUnngijuOed\nB4MGQadOlXuvVavCPpmzZ5c7eDxnDlxxRWioRXsIieQsJYLa8OKLYe7ikiVxR5IZdu0KLaSpU8Ni\nQEcdVVz5d+xYvfeeOjWMISxeXO783P/4j3DP3+WXV+8jRbKdEkFt2L07DHK+9lro7shF33wTBnyn\nTg2bBHTrVlz5d+iQ3s+65pqwOdCMGWVuoLx0aZhwtGoVHHBAekMQySZKBLXlssvgsMPgF7+IO5La\n8803oQ9m6lSYOTPsH5mo/Nu1K//6qtq1KywBfs455f57jxgR8tDvfldz4YhkOiWC2jJ3bhgjSFqR\ntE76+mt49tkw4DtzZlhYKVH51+YWaB9/DMcdB48+ute+EKlO69EjLA1Vk7lJJJMpEdSWgoKwlOgb\nb1S/HzzTfP11GKCdOhVmzYLvfjdU/ueeG75zXJ59FkaODOMFZezMM3YsbN4MEyfWYmwiGUSJoDaN\nHBmWwLzmmrgjqb6vvtq78u/Zs7jyz6Tt0MaPD4P1c+eWulHDtm1hstJzz4UGjEiuUSKoTbNmwc03\nh/2Ms9HOnWHfx2nTwnOvXsWVf5s2cUeX2u7dcOaZYQXYm28u9bQ77wzDGU8/XYuxiWQIJYLa9M03\n4dfyihXxdplUxs6dIYFNnRq6Wr73vVD5n3NO9qznvHlzSFoTJsBZZ6U8ZdeuMJHp73+Hfv1qOT6R\nmCkR1LaLLgqbFo8ZUzPvX1QUEs4334S++6+/rthxqtfWrAk/k48/vrjyP3ifbaKzw6uvhpbLggX7\nLlERmTIFbr01rEVUxqxTkTpHiaC2zZgBv/89XHtt9Srp0o537YLGjWG//Yqfq3rcpg38+Mdh2c66\n4I474KGHQlJo3Hifl93D5jVXXw0//WkM8YnERImgtn31Vdh1fseOsivjqlbijRppJbXSuId1i9q0\ngbvuSnnKSy/B8OGwcmX45xTJBUoEklu2bQvjHDfdBBdemPKUgQPhlFNCo00kFygRSO5580047bQw\nrbR7931efuedkAjefRdatowhPpFaph3KJPd897tw222hm2jHjn1e7tYt3AhdxmxTkZykFoHUPaNG\nhTGbhx7aZ1xl48awNNLixeVvhSCS7dLWIjCz+8xsk5ktTypraWZzzWyVmc0xsxZJr40zs9VmttLM\nzkgq721my6PX7kwqb2xmj0blC8yskovVi5Twl7+EDQnuvnuflw45JIzp//KXMcQlkqEq0jX0D6B/\nibKxwFx3PxJ4PvobM+sODAG6R9fcZbbnJ9kEYJS7dwW6mlniPUcBW6LyO4Bbq/F9RKBJk3CX9I03\nhpsHSrjuurBb5uLFMcQmkoHKTQTu/jLweYniAcAD0fEDwNnR8UDgEXcvcPc1wHtAHzM7FGju7oui\n8yYlXZP8XtOBU6vwPUT21rVraBGcfz589tleLzVrFnLE9deHmaciua6qg8Vt3H1TdLwJSCxI0xZY\nl3TeOqBdivL1UTnR88cA7l4IbDMzzemQ6jv33PAYPjzcmZ1k1CjYsCEsrSSS61Iv21gJ7u5mViu/\nq8aPH7/nOC8vj7y8vNr4WMlmt9wCeXlhjYlx4/YUN2gQin7xi7C1Zf368YUoki75+fnk5+dX+roK\n
zRoys87AU+5+bPT3SiDP3TdG3T7z3P1oMxsL4O63ROfNBm4E1kbndIvKLwBOcffR0Tnj3X2BmTUA\nNrj7PgvfaNaQVNm6dWEzm4cfhh/+cE+xe9jbZsSI0EIQqWtq+j6CJ4ER0fEIYEZS+VAza2RmXYCu\nwCJ33whsN7M+0eDxMOCJFO81mDD4LJI+7dvDpElhoaENG/YUm8Ef/hDGC778Msb4RGJWbovAzB4B\nfgC0JowH/JpQiU8BOgJrgPPdfWt0/g3ASKAQuNrdn43KewP3A02AWe5+VVTeGJgM9AS2AEOjgeaS\ncahFINXzm9/A88+HR9JmNkOGhI1rfvWrGGMTqQFaYkKkpKIi+NGPwmbGt9yyp/j998PqpCtWZM9W\nDCIVoUQgksqnn0Lv3vB//wcDBuwpvuYaKCwM96KJ1BVKBCKlWbAgLEU6fz4cdhgQ8kO3bmFLgyOP\njDk+kTTRonMipenbN6wxcd55YeMfwh4911231wxTkZyhFoHkJvcwStyy5Z41ib76Co46Ch55BE46\nKeb4RNJALQKRspjBvfeGRYcmTwbCEkW//a2WnpDco0QgueuAA8LidP/1X/DWWwBcdBHs3AmPPx5z\nbCK1SIlActuxx8Ltt4fNbL74gvr1w942Y8dCQUHcwYnUDiUCkeHDwx6WF18M7pxxRti05p574g5M\npHZosFgEwuyhE0+EkSPhiitYuhT694dVq0IPkkg20n0EIpX1/vtwwgkwcyYcfzwjRkCHDvC738Ud\nmEjVKBGIVMWMGeE248WL+XhnK3r0gGXLoF278i8VyTRKBCJVdf31Yc/jmTMZe0M9Nm+GiRPjDkqk\n8pQIRKqqoAD69YP+/dl6+S856ih47rkwwUgkmygRiFTHJ5/A974Hkydz51unMmcOPP103EGJVI7u\nLBapjrZtwx3Hw4YxesB6Vq6EF16IOyiRmqEWgUhZfv97mD2bRy+bx21/bMDrr0M9/XySLKEWgUg6\njBsHzZtz/pJx1K8fFqQTqWvUIhApz5Yt0Ls3L108ieH3nsLKlbDffnEHJVI+tQhE0qVVK5gyhVP+\nPJjvHvElf/1r3AGJpJdaBCIV9de/8s5fnueUT6fz7rtGy5ZxByRSNrUIRNJtzBi69WjMua1f5uab\n4w5GJH3UIhCpjC++YGOvH3HMhrn8a/l+dOkSd0AipVOLQKQmNG/OITPu5sqiP/OrK7bGHY1IWlQr\nEZjZGjNbZmZLzGxRVNbSzOaa2Sozm2NmLZLOH2dmq81spZmdkVTe28yWR6/dWZ2YRGrcMcdw3Z87\nMu/ZXSwR0C9QAAAKxElEQVR+cUfc0YhUW3VbBA7kuXtPdz8+KhsLzHX3I4Hno78xs+7AEKA70B+4\ny8wSTZYJwCh37wp0NbP+1YxLpEY1u3goN544l+vPW4MXqctSsls6uoZK9j8NAB6Ijh8Azo6OBwKP\nuHuBu68B3gP6mNmhQHN3XxSdNynpGpGMNWrWIDZsb8pN332M+XctYeeOorhDEqmSdLQInjOzf5nZ\nJVFZG3ffFB1vAtpEx22BdUnXrgPapShfH5WLZLQGzfbj4Vkt+GT/I7ji2sa0PuAbvn3wJoYP3Mqd\nd8LLL8MXX8QdpUj5GlTz+pPcfYOZHQzMNbOVyS+6u5tZ2trN48eP33Ocl5dHXl5eut5apEp69juI\nexYeBO7sWryct+98jsVPfcIb8/vwyP7fZ/m/D6FDR6NXL+jVC3r3hp49oUWL8t9bpLLy8/PJz8+v\n9HVpmz5qZjcCO4BLCOMGG6Nun3nufrSZjQVw91ui82cDNwJro3O6ReUXAD9w98tKvL+mj0p22L07\nLFX64IMUzHiald85nzeOGcbiet/jjWUNefNNaNOGfZJD69ZxBy51TY3vR2Bm+wP13f0LM2sKzAFu\nAk4Dtrj7rVHl38Ldx0aDxQ8DxxO6fp4DjohaDQuBq4BFwNPAn9
19donPUyKQ7PPll2H7ywcfhAUL\nYMAAdl84jFXtfsgbb9bnjTdg8WJYsiS0Enr33jtBtGlT/keIlKY2EkEX4PHozwbAQ+7+v2bWEpgC\ndATWAOe7+9bomhuAkUAhcLW7PxuV9wbuB5oAs9z9qhSfp0Qg2W3jRvjnP8M+Bxs3woUXwrBh8J3v\nUFQEH3wQkkIiObzxBjRpUpwUEs9t24KV+5925isshM8/D2v6JR7btkGfPnDkkXFHVzdohzKRTLZi\nRWglPPhgaAoMGxYSQ7vieRLusHZtcVJIJAizvZNDr17QqVO8yWHnzr0r9MTj009Tl2/ZEgbSDzww\nrOnXqlXoGmvaFF56Cb71LTjvvPBQUqg6JQKRbFBUFKYXTZ4Mjz0WavWLLoJBg6B5831Od4f16/du\nNSxeDN98s29yOPzwyieHoqJ9f6VX5AHFFXpFHy1aQP36+8aweze8+ipMnQrTpoXusfPPD0mha9cq\n/BvnMCUCkWzz1Vcwc2ZICi++CD/+cUgKZ5wBDcqe4LdhQxhnSG49bNsWBqETiaFx4/Ir9K1bQ/4p\nWWm3bl12pb7//jXzT7J7N7zySnFSOPTQ4qRwxBE185l1iRKBSDb79FN49NHQdfTBBzB0aOg+6t27\nwj/zN28OySGRGHbvLv9X+kEHlZtzYpNIClOmwPTpYawk0X2kpJCaEoFIXbF6dfF4QqNGoZXw059C\n585xRxab3btDj9rUqcVJIdFSOPzwuKPLHEoEInWNO8yfHxLClCnQvXtoJZx3Xk7foZZIClOmhGGW\ndu2KWwq5nhSUCETqsl27YNaskBTmzoXTTw9J4cwzQ6shR+3eHWYdJVoK7dsXtxQOOyzu6GqfEoFI\nrvj88zCSOnlymJZ6/vmh++iEE+rGDQdVlEgKiZZChw7FLYU6kxTc4euvw1zcFA8bOlSJQCTnfPgh\nPPxwSAqFhcXjCTk+77KwcO+WQseOxS2FWt9lrrBw30p7+/ZSK/NyHw0ahKleKR42daoSgUjOcg9z\nSSdPDnczd+kCRx+99+slz6/o39W5trz3gtCKqV8f6tVL/VzWa+U9169PodfnxQ86MHXJETy2tAud\nWu3g/OPXcl6fj+h8yNcV+1wIy4dUpeIuKIBmzfatuA84oNQKvcxHw4b7/hvu+adU15CIQKh4Xngh\n3GyQrGS3UWX+rs615b1XUVF47N6973Oqsoo+pygrLIQX1x/BlA+P4/GPetG56aec134+57V7jc77\nbSz92qKicBt0VSrwJk1qrctOiUBEpBIKCyE/P3QfPfZYaEQluo86dYo7utK5hzvLd+wIjZTk51NP\nVSIQEamSRFKYMgUefzwMLicGmquaFNzDzeOpKuzEc1mvlXVuw4ahgdKs2d7P+flKBCIi1VZQUNxS\nSCSFAQPCkh2VqbC//DJc06zZvhV2qufKnFPa3eDqGhIRSbOCApg3D2bPDt38lanImzZNvcheTVIi\nEBHJcRVNBNXdvF5ERLKcEoGISI5TIhARyXFKBCIiOU6JQEQkxykRiIjkOCUCEZEclzGJwMz6m9lK\nM1ttZv8ddzwiIrkiIxKBmdUH/gL0B7oDF5hZt3ijqrz8/Py4Q6gQxZleijO9FGfty4hEABwPvOfu\na9y9APgnMDDmmCotW/6PoTjTS3Gml+KsfZmSCNoBHyf9vS4qExGRGpYpiUCLCImIxCQjFp0zs77A\neHfvH/09Dihy91uTzok/UBGRLJM1q4+aWQPgXeBU4BNgEXCBu78Ta2AiIjmglO0Mape7F5rZFcCz\nQH1gopKAiEjtyIgWgYiIxCdTBotLlQ03mpnZfWa2ycyWxx1LWcysg5nNM7O3zewtM7sq7phSMbP9\nzGyhmS2N4hwfd0ylMbP6ZrbEzJ6KO5bSmNkaM1sWxbko7nhKY2YtzGyamb1jZiuiscOMYmZHRf+O\nice2DP7v6D+j/36Wm9nDZt
a41HMzuUUQ3Wj2LnAasB54nQwcOzCzk4EdwCR3PzbueEpjZocAh7j7\nUjNrBiwGzs60f08AM9vf3XdG40evAFe7+8K44yrJzP4L6A00d/cBcceTipl9CPR298/ijqUsZvYA\n8KK73xf9797U3bfFHVdpzKweoV463t0/Lu/82mRm7YCXgW7u/o2ZPQrMcvcHUp2f6S2CrLjRzN1f\nBj6PO47yuPtGd18aHe8A3gHaxhtVau6+MzpsBDQEimIMJyUzaw/8CLgXKHdmRswyOj4zOxA42d3v\ngzBumMlJIHIa8H6mJYEkDYD9o6S6PyFppZTpiUA3mtUQM+sM9AQy7lc2hF9bZrYU2ATMcffX444p\nhTuA68nAJFWCA8+Z2b/M7JK4gylFF2Czmf3DzN4ws7+b2f5xB1WOocDDcQeRiruvB24HPiLMxNzq\n7s+Vdn6mJ4LM7bfKYlG30DRCd8uOuONJxd2L3L0H0B7oY2bHxB1TMjM7C/i3uy8hw39tAye5e0/g\nTODyqCsz0zQAegF3uXsv4EtgbLwhlc7MGgE/AabGHUsqZnYQMADoTGj1NzOzn5Z2fqYngvVAh6S/\nOxBaBVJFZtYQmA486O4z4o6nPFH3wDzCgoSZ5ERgQNT//gjQz8wmxRxTSu6+IXreDDxO6HLNNOuA\ndUktv2mExJCpzgQWR/+mmeg04EN33+LuhcBjhP/PppTpieBfQFcz6xxl4CHAkzHHlLXMzICJwAp3\n/1Pc8ZTGzFqbWYvouAlwOmE8I2O4+w3u3sHduxC6CF5w9+Fxx1WSme1vZs2j46bAGUDGzW5z943A\nx2Z2ZFR0GvB2jCGV5wLCD4BMtRboa2ZNov/uTwNWlHZyRtxQVppsudHMzB4BfgC0MrOPgV+7+z9i\nDiuVk4CLgGVmtiQqG+fus2OMKZVDgQeiWWP1gEfdfVbMMZUnU7sx2wCPh7qABsBD7j4n3pBKdSXw\nUPSj733gZzHHk1KUUE8DMnW8BXdfZGbTgDeAwuj5ntLOz+jpoyIiUvMyvWtIRERqmBKBiEiOUyIQ\nEclxSgQiIjlOiUBEJMcpEYiI5DglAhGRHKdEICKS4/4/QGQBkLhV900AAAAASUVORK5CYII=\n", | |
| "text/plain": [ | |
| "<matplotlib.figure.Figure at 0x1102cbe90>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Report the most recently logged costs (summed squared error over a batch)\n", | |
| "print(\"Final Train cost: {}, on Epoch {}\".format(train_score[-1], k))\n", | |
| "print(\"Final Validation cost: {}, on Epoch {}\".format(val_score[-1], k))\n", | |
| "# Plot the learning curves against the training step at which each was logged\n", | |
| "plt.plot(x_axis, train_score, 'r-', label='train')\n", | |
| "plt.plot(x_axis, val_score, 'b-', label='validation')\n", | |
| "plt.xlabel('training step')\n", | |
| "plt.ylabel('summed squared error')\n", | |
| "plt.legend()\n", | |
| "plt.show()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "## Evaluate the trained model on a fresh, unseen batch.\n", | |
| "## tempX, y and outv keep the last batch's inputs, targets and predictions\n", | |
| "## for the inspection cells below. Distinct loop names are used so the\n", | |
| "## training globals (num_epochs, k) reported in the summary cell are not clobbered.\n", | |
| "val_score_v = []\n", | |
| "num_test_batches = 1\n", | |
| "\n", | |
| "for batch_idx in range(num_test_batches):\n", | |
| "    tempX, y = gen_data(5, seq_len, batch_size)\n", | |
| "    X = [tempX[:, i, :] for i in range(seq_len)]\n", | |
| "\n", | |
| "    val_dict = {inputs[i]: X[i] for i in range(seq_len)}\n", | |
| "    val_dict[result] = y\n", | |
| "    # NOTE(review): the DropoutWrapper is still active here, so predictions are noisy.\n", | |
| "    outv, c_val = sess.run([outputs3, cost], feed_dict=val_dict)\n", | |
| "    val_score_v.append([c_val])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(array([[8],\n", | |
| " [2],\n", | |
| " [8],\n", | |
| " [8],\n", | |
| " [9],\n", | |
| " [6],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0],\n", | |
| " [0]]), 41.0)" | |
| ] | |
| }, | |
| "execution_count": 12, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "## Target: the digits of example 3 in the last generated batch and their true sum\n", | |
| "(tempX[3], y[3])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "metadata": { | |
| "collapsed": false | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "array([ 44.25109482], dtype=float32)" | |
| ] | |
| }, | |
| "execution_count": 13, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "# Prediction: the model's estimated sum for the same example (outputs3 row 3)\n", | |
| "outv[3, :]" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 2", | |
| "language": "python", | |
| "name": "python2" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 2 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython2", | |
| "version": "2.7.11" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment