Last active: June 22, 2016 20:59
-
-
Save wpm/b61e281ea380280c60cd6c872044e9ca to your computer and use it in GitHub Desktop.
Minimal TensorFlow Example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A minimal implementation of the MNIST handwritten digits classification task in TensorFlow.
This runs MNIST images through a single hidden layer and a softmax loss function.
It demonstrates in a single Python source file the basics of creating a model, training and evaluating on data sets, and
writing summaries that can be visualized by TensorBoard.
""" | |
from __future__ import division

import math

import tensorflow as tf
from six.moves import xrange as range
from tensorflow.examples.tutorials.mnist import input_data
# Input geometry: one 28 x 28 MNIST image flattened into a single vector of pixels.
PIXELS = 28 ** 2

# Model size: number of units in the hidden layer.
HIDDEN = 128

# Optimisation settings. BATCH_SIZE also fixes the first dimension of the input placeholders.
BATCH_SIZE = 50
LEARNING_RATE = 0.01

# Reporting: evaluate and write a summary every REPORT_INTERVAL global steps,
# into a directory that the script wipes and recreates on each run.
REPORT_INTERVAL = 100
SUMMARY_DIRECTORY = "summary"
def epoch(data, operations, batch_size=None, sess=None, inputs=None, labels=None):
    """
    Iterate one epoch of a data set in batches through the specified fetches in the graph.

    Runs data.num_examples // batch_size complete batches, so when num_examples is not a multiple of batch_size
    fewer than num_examples examples are drawn. The optional arguments default to the module-level BATCH_SIZE,
    session, x and y, so the original two-argument call keeps working; pass them explicitly to run this against
    another session or other placeholders.
    :param data: the data to iterate
    :type data: data set defining num_examples and next_batch
    :param operations: fetches to evaluate for each batch, in any form Session.run accepts
    :type operations: Tensor, Operation, or a list of them
    :param batch_size: examples per batch (default BATCH_SIZE); the script's placeholders have a fixed first
        dimension of BATCH_SIZE, so a different value needs matching placeholders
    :type batch_size: int
    :param sess: session that runs the fetches (default: the module-level session)
    :type sess: Session
    :param inputs: placeholder fed with element 0 of each batch (default: the module-level x)
    :param labels: placeholder fed with element 1 of each batch (default: the module-level y)
    :return: iteration over the fetch results for each batch
    :rtype: iterator
    :raises ValueError: if batch_size is not positive (raised when iteration starts)
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    # Resolve the defaults at call time: session, x and y are only bound later in the script.
    if sess is None:
        sess = session
    if inputs is None:
        inputs = x
    if labels is None:
        labels = y
    for _ in range(data.num_examples // batch_size):
        batch = data.next_batch(batch_size)
        yield sess.run(operations, feed_dict={inputs: batch[0], labels: batch[1]})
# MNIST labels are the ten digits 0-9, so the classifier needs one logit per digit.
CLASSES = 10


def _accuracy(data, correct_count):
    """
    Return the fraction of a data set's examples that the model classifies correctly.

    epoch() only runs complete batches of BATCH_SIZE, so the denominator is the number of examples actually
    scored rather than data.num_examples; the latter would understate accuracy whenever the set size is not a
    multiple of BATCH_SIZE.
    :param data: the data set to evaluate
    :type data: data set defining num_examples and next_batch
    :param correct_count: tensor holding the number of correct predictions in one batch
    :type correct_count: Tensor
    :return: accuracy in [0, 1]
    :rtype: float
    :raises ValueError: if the data set holds fewer than BATCH_SIZE examples
    """
    evaluated = (data.num_examples // BATCH_SIZE) * BATCH_SIZE
    if evaluated == 0:
        raise ValueError("data set has fewer than %d examples" % BATCH_SIZE)
    return sum(epoch(data, correct_count)) / evaluated


# Start every run with an empty summary directory so TensorBoard shows only this run.
if tf.gfile.Exists(SUMMARY_DIRECTORY):
    tf.gfile.DeleteRecursively(SUMMARY_DIRECTORY)
tf.gfile.MakeDirs(SUMMARY_DIRECTORY)

with tf.Graph().as_default():
    with tf.name_scope("Input"):
        # x, y and session stay module-level names because epoch() defaults to them.
        x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, PIXELS], name="input_image")
        y = tf.placeholder(tf.int64, shape=BATCH_SIZE, name="true_digit")
    with tf.name_scope("Hidden"):
        w = tf.Variable(
            tf.truncated_normal(
                [PIXELS, HIDDEN], stddev=1.0 / math.sqrt(float(PIXELS))
            ),
            name="weights"
        )
        b = tf.Variable(tf.zeros([HIDDEN]), name="biases")
        # The ReLU makes this a real hidden layer: without a nonlinearity, stacking it
        # under the output layer would collapse to a single linear map.
        hidden = tf.nn.relu(tf.matmul(x, w) + b)
    with tf.name_scope("Output"):
        w_out = tf.Variable(
            tf.truncated_normal(
                [HIDDEN, CLASSES], stddev=1.0 / math.sqrt(float(HIDDEN))
            ),
            name="weights"
        )
        b_out = tf.Variable(tf.zeros([CLASSES]), name="biases")
        # One logit per digit class; both the loss and in_top_k score against these columns.
        y_predicted = tf.matmul(hidden, w_out) + b_out
    with tf.name_scope("Loss"):
        # Keyword arguments: later TensorFlow releases reject positional logits/labels here.
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=y_predicted, labels=y, name="cross_entropy")
        loss = tf.reduce_mean(cross_entropy, name="mean_cross_entropy")
        tf.scalar_summary(loss.op.name, loss)
    with tf.name_scope("Train"):
        global_step = tf.Variable(0, name="global_step", trainable=False)
        training_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss, global_step=global_step)
    with tf.name_scope("Evaluate"):
        # Number of examples in the batch whose top-scoring logit is the true digit.
        correct = tf.reduce_sum(tf.cast(tf.nn.in_top_k(y_predicted, y, 1), tf.int32), name="correct")
    summary = tf.merge_all_summaries()
    mnist = input_data.read_data_sets("MNIST_data")
    with tf.Session() as session:
        train_writer = tf.train.SummaryWriter(SUMMARY_DIRECTORY, session.graph)
        try:
            session.run(tf.initialize_all_variables())
            # Train the model for one epoch, periodically evaluating on the validation set.
            for i, l, s, _ in epoch(mnist.train, [global_step, loss, summary, training_step]):
                if i % REPORT_INTERVAL == 0:
                    print("Iteration %d: Training loss %0.5f, Validation correct %0.5f" %
                          (i, l, _accuracy(mnist.validation, correct)))
                    train_writer.add_summary(s, global_step=i)
                    train_writer.flush()
        finally:
            # Release the event file even if training is interrupted.
            train_writer.close()
        # Run the model on test data.
        print("Test set correct: %0.5f" % _accuracy(mnist.test, correct))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment