Skip to content

Instantly share code, notes, and snippets.

@jzstark
Created January 20, 2019 16:00
Show Gist options

  • Save jzstark/aa2a596b6015bc8d7fc8bf58058392da to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import sys
import time
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# MNIST: 10 digit classes, 28x28 grayscale images.
NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE  # flattened image length fed to the network
def inference(images, hidden1_units, hidden2_units):
    """Build the MNIST model up to where it may be used for inference.

    Args:
      images: Images placeholder, from inputs().
      hidden1_units: Size of the first hidden layer.
      hidden2_units: Size of the second hidden layer.
    Returns:
      softmax_linear: Output tensor with the computed logits.
    """
    def _dense(inputs, fan_in, fan_out, scope, relu):
        # One fully connected layer: truncated-normal weights scaled by
        # 1/sqrt(fan_in), zero biases, optional ReLU on the output.
        with tf.name_scope(scope):
            weights = tf.Variable(
                tf.truncated_normal([fan_in, fan_out],
                                    stddev=1.0 / math.sqrt(float(fan_in))),
                name='weights')
            biases = tf.Variable(tf.zeros([fan_out]), name='biases')
            pre_activation = tf.matmul(inputs, weights) + biases
            return tf.nn.relu(pre_activation) if relu else pre_activation

    # Two ReLU hidden layers followed by a linear (softmax-ready) output.
    hidden1 = _dense(images, IMAGE_PIXELS, hidden1_units, 'hidden1', True)
    hidden2 = _dense(hidden1, hidden1_units, hidden2_units, 'hidden2', True)
    return _dense(hidden2, hidden2_units, NUM_CLASSES, 'softmax_linear', False)
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy of logits vs. labels."""
    # Labels arrive as int32; the loss op wants int64.
    int_labels = tf.to_int64(labels)
    return tf.losses.sparse_softmax_cross_entropy(labels=int_labels,
                                                  logits=logits)
def training(loss, learning_rate):
    """Set up the training op: SGD minimizing `loss`, plus a loss summary.

    Args:
      loss: Scalar loss tensor, from loss().
      learning_rate: Learning rate for gradient descent.
    Returns:
      train_op: The op to run for one training step.
    """
    # Record the loss so TensorBoard can plot it.
    tf.summary.scalar('loss', loss)
    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    # Non-trainable counter incremented once per applied gradient step.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    return sgd.minimize(loss, global_step=global_step)
def evaluation(logits, labels, name):
    """Return an op counting how many predictions hit the true label (top-1)."""
    hits = tf.nn.in_top_k(logits, labels, 1)
    # Sum of the boolean hits, cast to int32, exposed under `name` so the
    # restore code can look it up as '<name>:0'.
    return tf.reduce_sum(tf.cast(hits, tf.int32), name=name)
# --- Data, hyper-parameters, and paths ------------------------------------
# Where the MNIST data is downloaded/cached.
input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                              'tensorflow/mnist/input_data')
data_sets = read_data_sets(input_data_dir, False)
# BUG FIX: the original assignment ended with a stray trailing comma, which
# made log_dir a 1-tuple ('...path...',) instead of a string path.
log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                       'tensorflow/mnist/logs/fully_connected_feed')
batch_size = 100
learning_rate = 0.01
hidden1 = 128  # units in the first hidden layer
hidden2 = 32   # units in the second hidden layer
max_steps = 2000
# Checkpoint written by the (commented-out) training run and read back below.
checkpoint_file = os.path.join(os.getenv('HOME'), 'Tmp/tf_converter/model.ckpt')
### Utils
def placeholder_inputs(batch_size):
    """Create the image/label placeholders for a batch of `batch_size`.

    The explicit names let the restore code fetch these tensors by
    'images_placeholder:0' / 'labels_placeholder:0' after loading the graph.
    """
    images = tf.placeholder(tf.float32,
                            shape=(batch_size, IMAGE_PIXELS),
                            name='images_placeholder')
    labels = tf.placeholder(tf.int32,
                            shape=(batch_size),
                            name='labels_placeholder')
    return images, labels
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Pull the next batch from `data_set` and map it onto the placeholders."""
    batch_images, batch_labels = data_set.next_batch(batch_size, False)
    return {images_pl: batch_images, labels_pl: batch_labels}
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one full epoch of evaluation and print precision@1.

    Args:
      sess: Session the graph lives in.
      eval_correct: Tensor counting correct predictions in a batch.
      images_placeholder: Images placeholder from placeholder_inputs().
      labels_placeholder: Labels placeholder from placeholder_inputs().
      data_set: Dataset (e.g. data_sets.test) providing next_batch().
    """
    true_count = 0  # number of correct predictions seen so far
    # Only whole batches are evaluated; a trailing partial batch is dropped.
    steps_per_epoch = data_set.num_examples // batch_size
    num_examples = steps_per_epoch * batch_size
    # BUG FIX: the original used Python-2-only `xrange`, which crashes under
    # Python 3 despite the file's `__future__` imports; `range` works on both.
    for _ in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    precision = float(true_count) / num_examples
    print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
## Start training -- NOTE(review): the whole section below is deliberately
## disabled by being wrapped in a bare triple-quoted string. It built the
## graph, trained for max_steps, and saved the checkpoint that the restore
## code after it loads. Kept byte-identical; it still uses Python-2 `xrange`
## and would need `range` if ever re-enabled.
"""
with tf.Graph().as_default():
# summary = tf.summary.merge_all()
# summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
sess = tf.Session()
images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
logits = inference(images_placeholder, 128, 32)
loss = loss(logits, labels_placeholder)
train_op = training(loss, learning_rate)
eval_correct = evaluation(logits, labels_placeholder, 'eval_correct')
saver = tf.train.Saver()
init = tf.global_variables_initializer()
sess.run(init)
for step in xrange(max_steps):
start_time = time.time()
feed_dict = fill_feed_dict(
data_sets.train, images_placeholder, labels_placeholder)
_, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
duration = time.time() - start_time
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
# Update the events file.
# summary_str = sess.run(summary, feed_dict=feed_dict)
# summary_writer.add_summary(summary_str, step)
# summary_writer.flush()
if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
# saver.save(sess, checkpoint_file, global_step=step)
print('Test Data Eval:')
do_eval(sess, eval_correct, images_placeholder,
labels_placeholder, data_sets.test)
saver.save(sess, checkpoint_file)
"""
# Load the saved model back from its checkpoint and re-run evaluation.
meta_file = checkpoint_file + '.meta'
with tf.Graph().as_default():
    sess = tf.Session()
    # Rebuild the graph structure from the .meta file, then load the
    # trained variable values from the checkpoint itself.
    new_saver = tf.train.import_meta_graph(meta_file)
    new_saver.restore(sess, checkpoint_file)
    graph = tf.get_default_graph()
    # Recover the tensors we need via the names assigned at construction time.
    images_placeholder = graph.get_tensor_by_name('images_placeholder:0')
    labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
    eval_correct = graph.get_tensor_by_name('eval_correct:0')
    print('Test Data Eval after restoring model:')
    do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
            data_sets.test)
#====
#from tensorflow.python.tools import inspect_checkpoint as chkp
#chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment