Skip to content

Instantly share code, notes, and snippets.

@Echooff3
Created September 1, 2017 16:10
Show Gist options
  • Save Echooff3/a8ce307c3f007bc3c770461e339fa65f to your computer and use it in GitHub Desktop.
Save Echooff3/a8ce307c3f007bc3c770461e339fa65f to your computer and use it in GitHub Desktop.
RNN LSTM - The next number from 0 - 9
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.contrib import rnn\n",
"import random\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"vocab = [i for i in range(0,10)]\n",
"vocab_size = len(vocab)\n",
"n_input = 1\n",
"# number of units in RNN cell\n",
"n_hidden = 64\n",
"\n",
"# tf Graph input\n",
"x = tf.placeholder(\"float\", [None, n_input, 1])\n",
"y = tf.placeholder(\"float\", [None, vocab_size])\n",
"\n",
"# RNN output node weights and biases\n",
"weights = {\n",
" 'out': tf.Variable(tf.random_normal([n_hidden, vocab_size]),name=\"w1\")\n",
"}\n",
"biases = {\n",
" 'out': tf.Variable(tf.random_normal([vocab_size]),name=\"b1\")\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def RNN(x, weights, biases):\n",
"\n",
" # reshape to [1, n_input]\n",
" x = tf.reshape(x, [-1, n_input])\n",
"\n",
" # Generate a n_input-element sequence of inputs\n",
" # (eg. [had] [a] [general] -> [20] [6] [33])\n",
" x = tf.split(x,n_input,1)\n",
"\n",
" # 2-layer LSTM, each layer has n_hidden units.\n",
" # Average Accuracy= 95.20% at 50k iter\n",
" rnn_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(n_hidden),rnn.BasicLSTMCell(n_hidden)])\n",
"\n",
" # 1-layer LSTM with n_hidden units but with lower accuracy.\n",
" # Average Accuracy= 90.60% 50k iter\n",
" # Uncomment line below to test but comment out the 2-layer rnn.MultiRNNCell above\n",
" # rnn_cell = rnn.BasicLSTMCell(n_hidden)\n",
"\n",
" # generate prediction\n",
" outputs, states = rnn.static_rnn(rnn_cell, x, dtype=tf.float32)\n",
"\n",
" # there are n_input outputs but\n",
" # we only want the last output\n",
" return tf.add(tf.matmul(outputs[-1], weights['out']), biases['out'], \"network_out\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"pred = RNN(x, weights, biases)\n",
"\n",
"# Loss and optimizer\n",
"learning_rate = 0.001\n",
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))\n",
"optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate).minimize(cost)\n",
"\n",
"# Model evaluation\n",
"correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))\n",
"\n",
"# Initializing the variables\n",
"init = tf.global_variables_initializer()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"session = tf.Session()\n",
"session.run(init)\n",
"step = 0\n",
"offset = random.randint(0,n_input+1)\n",
"end_offset = n_input + 1\n",
"acc_total = 0\n",
"loss_total = 0\n",
"training_iters = 5000\n",
"display_step = 1000"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iter= 1000, Average Loss= 1.681965, Average Accuracy= 36.20%\n",
"[2] - [3] vs [3]\n",
"Iter= 2000, Average Loss= 0.598271, Average Accuracy= 81.00%\n",
"[5] - [6] vs [7]\n",
"Iter= 3000, Average Loss= 0.237745, Average Accuracy= 94.60%\n",
"[4] - [5] vs [5]\n",
"Iter= 4000, Average Loss= 0.089085, Average Accuracy= 99.30%\n",
"[2] - [3] vs [3]\n",
"Iter= 5000, Average Loss= 0.027395, Average Accuracy= 99.90%\n",
"[8] - [9] vs [9]\n"
]
}
],
"source": [
"while step < training_iters:\n",
" # training\n",
" if offset > (len(vocab)-end_offset):\n",
" offset = random.randint(0, n_input+1)\n",
"\n",
" symbols_in_keys = [[vocab[i]] for i in range(offset, offset+n_input)]\n",
" symbols_in_keys = np.reshape(np.array(symbols_in_keys), [-1, n_input, 1])\n",
"\n",
" symbols_in_keys\n",
"\n",
" symbols_out_onehot = np.zeros([vocab_size], dtype=float)\n",
" symbols_out_onehot[offset+n_input] = 1.0\n",
" symbols_out_onehot = np.reshape(symbols_out_onehot,[1,-1])\n",
"\n",
" symbols_out_onehot\n",
"\n",
" #move the graph\n",
" _, acc, loss, onehot_pred = session.run([optimizer, accuracy, cost, pred], \\\n",
" feed_dict={x: symbols_in_keys, y: symbols_out_onehot})\n",
" loss_total += loss\n",
" acc_total += acc\n",
"\n",
" #results\n",
" if (step+1) % display_step == 0:\n",
" print(\"Iter= \" + str(step+1) + \", Average Loss= \" + \\\n",
" \"{:.6f}\".format(loss_total/display_step) + \", Average Accuracy= \" + \\\n",
" \"{:.2f}%\".format(100*acc_total/display_step))\n",
" acc_total = 0\n",
" loss_total = 0\n",
" symbols_in = [vocab[i] for i in range(offset, offset + n_input)]\n",
" symbols_out = vocab[offset + n_input]\n",
" symbols_out_pred = vocab[int(tf.argmax(onehot_pred, 1).eval(session=session))]\n",
" print(\"%s - [%s] vs [%s]\" % (symbols_in,symbols_out,symbols_out_pred))\n",
" offset += (n_input+1)\n",
" step += 1"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 -> 3\n"
]
}
],
"source": [
"symbols_in = vocab[random.randint(0,n_input+1)]\n",
"keys = np.reshape(np.array([symbols_in]), [-1, n_input, 1])\n",
"onehot_pred = session.run(pred, feed_dict={x: keys})\n",
"onehot_pred_index = int(tf.argmax(onehot_pred, 1).eval(session=session))\n",
"symbols_out_pred = vocab[onehot_pred_index]\n",
"print(\"%s -> %s\" % (symbols_in,symbols_out_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@Echooff3
Copy link
Author

Echooff3 commented Sep 1, 2017

Done with tons of help from this article.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment