-
-
Save jonathanjouty/6d335e7bd8c0dcc39c71fee7af540a96 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
A Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

'''
To classify images using a recurrent neural network, we consider every image
row as a sequence of pixels. Because MNIST image shape is 28*28px, we will then
handle 28 sequences of 28 steps for every sample.
'''

# Training parameters
learning_rate = 0.001
training_iters = 100000  # stop once step * batch_size reaches this many samples
batch_size = 128
display_step = 10        # print loss/accuracy every `display_step` training steps

# Network parameters
n_input = 28    # MNIST data input (img shape: 28*28) -- one image row per timestep
n_steps = 28    # timesteps (one per image row)
n_hidden = 128  # hidden layer num of features (LSTM state size)
n_classes = 10  # MNIST total classes (0-9 digits)

# tf Graph input: a batch of row-sequences and its one-hot labels
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Output-projection parameters: map the final LSTM output to class logits.
weights = {
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}
def RNN(x, weights, biases):
    """Run a single-layer LSTM over `x` and project the last output to logits.

    Args:
        x: float tensor of shape (batch, n_steps, n_input).
        weights: dict with key 'out' -> (n_hidden, n_classes) projection matrix.
        biases: dict with key 'out' -> (n_classes,) bias vector.

    Returns:
        A (batch, n_classes) tensor of unnormalized class logits.
    """
    # Define a lstm cell with tensorflow
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
    outputs, states = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)
    # outputs has shape (batch, n_steps, n_hidden); keep only the final
    # timestep. Negative indexing behaves identically to the original
    # `outputs[:, outputs.get_shape()[1] - 1, :]` but also works when the
    # sequence length is not statically known.
    last_output = outputs[:, -1, :]
    # Linear activation, using rnn inner loop last output
    return tf.matmul(last_output, weights['out']) + biases['out']
# Build the prediction graph.
pred = RNN(x, weights, biases)

# Define loss and optimizer: softmax cross-entropy between logits and the
# one-hot labels, averaged over the batch, minimized with Adam.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model: fraction of samples whose argmax prediction matches the label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph | |
with tf.Session() as sess: | |
sess.run(init) | |
step = 1 | |
# Keep training until reach max iterations | |
while step * batch_size < training_iters: | |
batch_x, batch_y = mnist.train.next_batch(batch_size) | |
# Reshape data to get 28 seq of 28 elements | |
batch_x = batch_x.reshape((batch_size, n_steps, n_input)) | |
# Run optimization op (backprop) | |
sess.run(optimizer, feed_dict={x: batch_x, y: batch_y}) | |
if step % display_step == 0: | |
# Calculate batch accuracy | |
acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y}) | |
# Calculate batch loss | |
loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y}) | |
print("Iter " + str(step*batch_size) + ", Minibatch Loss= " + \ | |
"{:.6f}".format(loss) + ", Training Accuracy= " + \ | |
"{:.5f}".format(acc)) | |
step += 1 | |
print("Optimization Finished!") | |
# Calculate accuracy for 128 mnist test images | |
test_len = 128 | |
test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input)) | |
test_label = mnist.test.labels[:test_len] | |
print("Testing Accuracy:", \ | |
sess.run(accuracy, feed_dict={x: test_data, y: test_label})) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment