Skip to content

Instantly share code, notes, and snippets.

@yujuwon
Created April 27, 2016 08:40
Show Gist options
  • Select an option

  • Save yujuwon/b9bd415439484c9778cbb233f456ccfd to your computer and use it in GitHub Desktop.

Select an option

Save yujuwon/b9bd415439484c9778cbb233f456ccfd to your computer and use it in GitHub Desktop.
import tensorflow as tf
import sys
sys.path.append("/root/work/deep/code/01_mnist_beginning")
import input_data
import numpy
import math
import time
# Width of the final (logits) layer, i.e. how many classes get a score.
NUM_CLASSES = 10

# Edge length used to derive the flattened per-image feature count below.
# Presumably the 28x28 MNIST digit size, judging by the "MNIST_data/" path
# used by run_training() -- confirm against the data loader.
IMAGE_SIZE = 28

# Number of features in one flattened image (28 * 28 = 784); this is the
# second dimension of the images placeholder and the first layer's fan-in.
IMAGE_PIXELS = IMAGE_SIZE ** 2
def _layer_params(fan_in, fan_out, weights_name):
    """Create the weight matrix and bias vector for one fully connected layer.

    Weights are drawn from a truncated normal with stddev 1/sqrt(fan_in);
    biases start at zero. Must be called inside the layer's name scope.

    Args:
        fan_in: number of input units of the layer.
        fan_out: number of output units of the layer.
        weights_name: name given to the weight variable.

    Returns:
        A (weights, biases) pair of tf.Variable objects.
    """
    init_stddev = 1.0 / math.sqrt(float(fan_in))
    initial_weights = tf.truncated_normal([fan_in, fan_out], stddev=init_stddev)
    weights = tf.Variable(initial_weights, name=weights_name)
    biases = tf.Variable(tf.zeros([fan_out]), name='biases')
    return weights, biases


def inference(images, hidden1_units, hidden2_units):
    """Build the forward pass: two ReLU hidden layers and a linear output layer.

    Args:
        images: input tensor that is matrix-multiplied with an
            [IMAGE_PIXELS, hidden1_units] weight matrix.
        hidden1_units: number of units in the first hidden layer.
        hidden2_units: number of units in the second hidden layer.

    Returns:
        The logits tensor produced by the final linear layer, which has
        NUM_CLASSES output units.
    """
    # The hidden layers name their weight variable 'weight' while the output
    # layer uses 'weights'; the names are kept as-is so variable names in the
    # graph (and in any saved checkpoint) do not change.
    with tf.name_scope('hidden1'):
        w1, b1 = _layer_params(IMAGE_PIXELS, hidden1_units, 'weight')
        hidden1 = tf.nn.relu(tf.matmul(images, w1) + b1)

    with tf.name_scope('hidden2'):
        w2, b2 = _layer_params(hidden1_units, hidden2_units, 'weight')
        hidden2 = tf.nn.relu(tf.matmul(hidden1, w2) + b2)

    with tf.name_scope('softmax_linear'):
        w_out, b_out = _layer_params(hidden2_units, NUM_CLASSES, 'weights')
        # No activation here: raw logits are returned to the caller.
        return tf.matmul(hidden2, w_out) + b_out
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy of logits against labels.

    Args:
        logits: tensor of unnormalized class scores.
        labels: tensor of class indices; it is cast to int64 before use.

    Returns:
        A tensor named 'xentropy_mean' holding the mean of the per-example
        cross-entropy values.
    """
    int_labels = tf.to_int64(labels)
    # Pass logits/labels by keyword: the positional order of this op differs
    # between TensorFlow releases (1.0+ rejects positional use), whereas the
    # keyword names are the same in both.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=int_labels, name='xentropy')
    # Return directly rather than binding a local called `loss`, which would
    # shadow this function's own name.
    return tf.reduce_mean(cross_entropy, name='xentropy_mean')
def training(loss, learning_rate):
    """Create the op that performs one gradient-descent update on `loss`.

    Also registers a scalar summary for the loss, tagged with the loss op's
    own name.

    Args:
        loss: loss tensor to minimize.
        learning_rate: step size handed to the gradient descent optimizer.

    Returns:
        The op returned by the optimizer's minimize(); running it also
        advances the 'global_step' variable created here.
    """
    summary_tag = loss.op.name
    tf.scalar_summary(summary_tag, loss)

    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    # Non-trainable counter handed to minimize() as its global_step.
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    return sgd.minimize(loss, global_step=step_counter)
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one pass of evaluation over `data_set` and print the precision.

    Only whole batches of 100 examples are evaluated; any remainder is
    skipped, and the printed example count reflects that.

    Args:
        sess: session used to run `eval_correct`.
        eval_correct: op whose result is added to the running correct count.
        images_placeholder: placeholder fed with each batch of images.
        labels_placeholder: placeholder fed with each batch of labels.
        data_set: object exposing `num_examples` and usable with
            fill_feed_dict().
    """
    # Must equal the number of examples fill_feed_dict() pulls per call.
    batch_size = 100

    # Floor division keeps this an int on both Python 2 and Python 3.
    steps_per_epoch = data_set.num_examples // batch_size
    num_examples = steps_per_epoch * batch_size

    true_count = 0
    for _ in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)

    # A data set smaller than one batch yields zero evaluated examples;
    # report 0.0 instead of dividing by zero.
    precision = true_count / float(num_examples) if num_examples else 0.0
    print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
def evaluation(logits, labels):
    """Build an op that counts examples whose label is the top-1 prediction.

    Args:
        logits: tensor of class scores.
        labels: tensor of target class indices.

    Returns:
        An int32 tensor: the sum of the per-example in_top_k (k=1) results.
    """
    top1_hits = tf.nn.in_top_k(logits, labels, 1)
    # Booleans become 0/1 so they can be summed into a count.
    hits_as_int = tf.cast(top1_hits, tf.int32)
    return tf.reduce_sum(hits_as_int)
def placeholder_inputs(batch_size):
    """Create the placeholders that images and labels are fed through.

    Args:
        batch_size: size of the first dimension of both placeholders.

    Returns:
        A pair (images_placeholder, labels_placeholder): a float32
        placeholder of shape (batch_size, IMAGE_PIXELS) and an int32
        placeholder of shape (batch_size,).
    """
    images_placeholder = tf.placeholder(tf.float32,
                                        shape=(batch_size, IMAGE_PIXELS))
    # The trailing comma matters: `(batch_size)` is just an int, not a
    # one-element shape tuple.
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
    return images_placeholder, labels_placeholder
def fill_feed_dict(data_set, images_pl, labels_pl, batch_size=100):
    """Fetch the next batch from `data_set` and map it onto the placeholders.

    Args:
        data_set: object whose `next_batch(batch_size, False)` returns an
            (images, labels) pair.
        images_pl: feed-dict key that receives the images of the batch.
        labels_pl: feed-dict key that receives the labels of the batch.
        batch_size: number of examples to request; defaults to 100, the
            value that used to be hard-coded here.

    Returns:
        A dict {images_pl: images, labels_pl: labels} suitable for passing
        as `feed_dict`.
    """
    # The second positional argument is passed as False exactly as before;
    # presumably it is next_batch's fake-data flag -- confirm in input_data.
    images_feed, labels_feed = data_set.next_batch(batch_size, False)
    return {
        images_pl: images_feed,
        labels_pl: labels_feed,
    }
def run_training():
    """Build the graph, train for 2000 steps, and periodically evaluate.

    Every 100 steps the current loss is printed and a summary is written to
    "./logs". Every 1000 steps, and on the final step, a checkpoint is saved
    and the model is evaluated on the validation and test sets.
    """
    max_steps = 2000
    data_sets = input_data.read_data_sets("MNIST_data/", False)

    with tf.Graph().as_default():
        # 100 must match the number of examples fill_feed_dict() feeds per
        # step, because the placeholders have a fixed batch dimension.
        images_placeholder, labels_placeholder = placeholder_inputs(100)
        logits = inference(images_placeholder, 128, 32)
        loss_function = loss(logits, labels_placeholder)
        train_op = training(loss_function, 0.01)
        eval_correct = evaluation(logits, labels_placeholder)
        summary_op = tf.merge_all_summaries()
        saver = tf.train.Saver()

        # The context manager closes the session even if training raises.
        with tf.Session() as sess:
            init = tf.initialize_all_variables()
            sess.run(init)
            summary_writer = tf.train.SummaryWriter("./logs", sess.graph)
            try:
                for step in range(max_steps):
                    start_time = time.time()
                    feed_dict = fill_feed_dict(data_sets.train,
                                               images_placeholder,
                                               labels_placeholder)
                    _, loss_value = sess.run([train_op, loss_function],
                                             feed_dict=feed_dict)
                    duration = time.time() - start_time

                    if step % 100 == 0:
                        print('Step %d: loss = %.2f (%.3f sec)' %
                              (step, loss_value, duration))
                        summary_str = sess.run(summary_op, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
                        # Note: the checkpoint path prefix is the same
                        # "./logs" string used as the summary directory.
                        saver.save(sess, "./logs", global_step=step)
                        # This pass runs on data_sets.validation, so label it
                        # as validation (it used to print "Training").
                        print('Validation Data Eval:')
                        do_eval(sess,
                                eval_correct,
                                images_placeholder,
                                labels_placeholder,
                                data_sets.validation)
                        print('Test Data Eval:')
                        do_eval(sess,
                                eval_correct,
                                images_placeholder,
                                labels_placeholder,
                                data_sets.test)
            finally:
                # Release the event file handle once training ends or fails.
                summary_writer.close()
# Script entry point: start training as soon as this file is executed.
# There is no `__main__` guard, so importing this module also runs training.
run_training()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment