RNN LSTM - The next number from 0 - 9
```python
import tensorflow as tf
from tensorflow.contrib import rnn
import random
import numpy as np
```
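Note: `tensorflow.contrib` exists only in TensorFlow 1.x (the notebook metadata records a Python 2.7 kernel), so this code will not import under TensorFlow 2.x.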
```python
vocab = [i for i in range(0, 10)]
vocab_size = len(vocab)
n_input = 1
# number of units in RNN cell
n_hidden = 64

# tf Graph input
x = tf.placeholder("float", [None, n_input, 1])
y = tf.placeholder("float", [None, vocab_size])

# RNN output node weights and biases
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, vocab_size]), name="w1")
}
biases = {
    'out': tf.Variable(tf.random_normal([vocab_size]), name="b1")
}
```
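As a quick sanity check on the placeholder shapes, here is a minimal sketch I've added (`sample_x` and `sample_y` are illustrative names, not from the original notebook):

```python
import numpy as np

# one training example: the digit 3, shaped [batch, n_input, 1]
sample_x = np.reshape([3], [-1, 1, 1])     # shape (1, 1, 1)

# its label: a one-hot row for the digit 4, shaped [1, vocab_size]
sample_y = np.zeros([1, 10], dtype=float)
sample_y[0, 4] = 1.0

print(sample_x.shape, sample_y.shape)      # (1, 1, 1) (1, 10)
```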
```python
def RNN(x, weights, biases):

    # reshape to [batch_size, n_input]
    x = tf.reshape(x, [-1, n_input])

    # split into an n_input-element list of [batch_size, 1] tensors,
    # one per timestep (e.g. a batch holding [3] becomes [[3]])
    x = tf.split(x, n_input, 1)

    # 2-layer LSTM, each layer has n_hidden units.
    # Average Accuracy= 95.20% at 50k iter
    rnn_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(n_hidden), rnn.BasicLSTMCell(n_hidden)])

    # 1-layer LSTM with n_hidden units, but with lower accuracy.
    # Average Accuracy= 90.60% at 50k iter
    # Uncomment the line below to test, and comment out the 2-layer rnn.MultiRNNCell above
    # rnn_cell = rnn.BasicLSTMCell(n_hidden)

    # generate prediction
    outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)

    # there are n_input outputs, but we only want the last one
    return tf.add(tf.matmul(outputs[-1], weights['out']), biases['out'], "network_out")
```
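Since `tensorflow.contrib` is gone in TensorFlow 2.x, here is a rough `tf.keras` equivalent of the 2-layer network above for readers on a current install. This is my own sketch of a port, not part of the original gist:

```python
import tensorflow as tf  # assumes TF 2.x

# hypothetical port: a length-1 sequence of scalars in, a 10-way logit vector out
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(64, return_sequences=True, input_shape=(1, 1)),
    tf.keras.layers.LSTM(64),
    tf.keras.layers.Dense(10),  # logits, like matmul(outputs[-1], w) + b above
])
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
```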
```python
pred = RNN(x, weights, biases)

# Loss and optimizer
learning_rate = 0.001
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost)

# Model evaluation
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()
```
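For reference, `tf.nn.softmax_cross_entropy_with_logits` takes unscaled logits and computes, per example, the cross-entropy between the one-hot label and the softmax of those logits. A minimal numpy restatement (illustration only, not the fused TF kernel):

```python
import numpy as np

def softmax_xent(logits, onehot):
    z = logits - logits.max()            # shift for numerical stability
    p = np.exp(z) / np.exp(z).sum()      # softmax
    return -np.sum(onehot * np.log(p))   # cross-entropy against the one-hot label

logits = np.array([0.1, 2.0, -1.3, 0.0, 0.5, 0.2, 0.0, 0.1, -0.4, 0.3])
label = np.eye(10)[1]                    # one-hot for class 1
print(softmax_xent(logits, label))       # small loss: class 1 has the largest logit
```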
```python
session = tf.Session()
session.run(init)
step = 0
offset = random.randint(0, n_input + 1)
end_offset = n_input + 1
acc_total = 0
loss_total = 0
training_iters = 5000
display_step = 1000
```
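The explicit names on the variables (`w1`, `b1`) and on the output op (`network_out`) suggest the graph was meant to be saved and reloaded. If that's the goal, a minimal `tf.train.Saver` sketch would look like this (the checkpoint path is a hypothetical example, not from the gist):

```python
# hypothetical checkpointing of the trained graph (TF 1.x API)
saver = tf.train.Saver()
save_path = saver.save(session, "./rnn_next_digit.ckpt")  # example path
# later, in a fresh session: saver.restore(session, save_path)
```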
```python
while step < training_iters:
    # wrap around once the window would run past the end of the vocab
    if offset > (len(vocab) - end_offset):
        offset = random.randint(0, n_input + 1)

    symbols_in_keys = [[vocab[i]] for i in range(offset, offset + n_input)]
    symbols_in_keys = np.reshape(np.array(symbols_in_keys), [-1, n_input, 1])

    symbols_out_onehot = np.zeros([vocab_size], dtype=float)
    symbols_out_onehot[offset + n_input] = 1.0
    symbols_out_onehot = np.reshape(symbols_out_onehot, [1, -1])

    # run one optimization step and collect the metrics
    _, acc, loss, onehot_pred = session.run([optimizer, accuracy, cost, pred],
                                            feed_dict={x: symbols_in_keys, y: symbols_out_onehot})
    loss_total += loss
    acc_total += acc

    # report running averages every display_step iterations
    if (step + 1) % display_step == 0:
        print("Iter= " + str(step + 1) + ", Average Loss= " +
              "{:.6f}".format(loss_total / display_step) + ", Average Accuracy= " +
              "{:.2f}%".format(100 * acc_total / display_step))
        acc_total = 0
        loss_total = 0
        symbols_in = [vocab[i] for i in range(offset, offset + n_input)]
        symbols_out = vocab[offset + n_input]
        # np.argmax avoids adding a new tf.argmax op to the graph on every report
        symbols_out_pred = vocab[int(np.argmax(onehot_pred, 1))]
        print("%s - [%s] vs [%s]" % (symbols_in, symbols_out, symbols_out_pred))
    offset += (n_input + 1)
    step += 1
```

Output:

```
Iter= 1000, Average Loss= 1.681965, Average Accuracy= 36.20%
[2] - [3] vs [3]
Iter= 2000, Average Loss= 0.598271, Average Accuracy= 81.00%
[5] - [6] vs [7]
Iter= 3000, Average Loss= 0.237745, Average Accuracy= 94.60%
[4] - [5] vs [5]
Iter= 4000, Average Loss= 0.089085, Average Accuracy= 99.30%
[2] - [3] vs [3]
Iter= 5000, Average Loss= 0.027395, Average Accuracy= 99.90%
[8] - [9] vs [9]
```
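With `n_input = 1`, the offset logic above only ever trains on adjacent digit pairs. A plain-Python enumeration of every window it can visit (illustration only):

```python
n_input, end_offset = 1, 2
vocab = list(range(10))
for offset in range(len(vocab) - end_offset + 1):   # offsets 0..8
    window = vocab[offset:offset + n_input]         # input sequence
    target = vocab[offset + n_input]                # the digit to predict
    print(window, "->", target)                     # [0] -> 1, ..., [8] -> 9
```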
```python
symbols_in = vocab[random.randint(0, n_input + 1)]
keys = np.reshape(np.array([symbols_in]), [-1, n_input, 1])
onehot_pred = session.run(pred, feed_dict={x: keys})
onehot_pred_index = int(np.argmax(onehot_pred, 1))
symbols_out_pred = vocab[onehot_pred_index]
print("%s -> %s" % (symbols_in, symbols_out_pred))
```

Output:

```
2 -> 3
```
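To spot-check the trained model on every digit that has a successor in the vocab, here is a short usage sketch reusing the `session`, `pred`, and `x` defined above:

```python
for d in range(9):                                  # 9 has no successor in vocab
    keys = np.reshape(np.array([d]), [-1, 1, 1])
    logits = session.run(pred, feed_dict={x: keys})
    print("%d -> %d" % (d, int(np.argmax(logits))))
```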
(Notebook kernel: Python 2, CPython 2.7.10.)
Done with tons of help from this article.