Created: January 20, 2019, 16:00
-
-
Save jzstark/aa2a596b6015bc8d7fc8bf58058392da to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import tensorflow as tf | |
| from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets | |
# MNIST dataset geometry.
NUM_CLASSES = 10  # digits 0-9
IMAGE_SIZE = 28  # input images are 28x28 grayscale
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE  # flattened input dimension (784)
def inference(images, hidden1_units, hidden2_units):
    """Build the MNIST model up to where it may be used for inference.

    Args:
      images: Images placeholder, from inputs(); shape (batch, IMAGE_PIXELS).
      hidden1_units: Size of the first hidden layer.
      hidden2_units: Size of the second hidden layer.

    Returns:
      softmax_linear: Output tensor with the computed logits,
        shape (batch, NUM_CLASSES).
    """
    def _fc_layer(scope, inputs, in_units, out_units, activation=None):
        """Build one fully connected layer inside its own name scope."""
        with tf.name_scope(scope):
            # Initial weights scaled by 1/sqrt(fan-in) keeps activation
            # magnitudes comparable across layers.
            weights = tf.Variable(
                tf.truncated_normal([in_units, out_units],
                                    stddev=1.0 / math.sqrt(float(in_units))),
                name='weights')
            biases = tf.Variable(tf.zeros([out_units]), name='biases')
            pre_activation = tf.matmul(inputs, weights) + biases
            return activation(pre_activation) if activation else pre_activation

    # Two ReLU hidden layers followed by a linear output layer; the softmax
    # itself is applied later, inside the loss.
    hidden1 = _fc_layer('hidden1', images, IMAGE_PIXELS, hidden1_units,
                        activation=tf.nn.relu)
    hidden2 = _fc_layer('hidden2', hidden1, hidden1_units, hidden2_units,
                        activation=tf.nn.relu)
    logits = _fc_layer('softmax_linear', hidden2, hidden2_units, NUM_CLASSES)
    return logits
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy between logits and labels."""
    # The loss op requires int64 labels; cast without mutating the argument name.
    int64_labels = tf.to_int64(labels)
    return tf.losses.sparse_softmax_cross_entropy(
        labels=int64_labels, logits=logits)
def training(loss, learning_rate):
    """Set up the training op: one SGD step that also bumps the global step.

    Args:
      loss: Loss tensor, from loss().
      learning_rate: Step size for gradient descent.

    Returns:
      train_op: The Op to run for one training step.
    """
    # Snapshot the loss so TensorBoard can plot it.
    tf.summary.scalar('loss', loss)
    # Non-trainable counter tracking how many steps have been run.
    step_counter = tf.Variable(0, name='global_step', trainable=False)
    sgd = tf.train.GradientDescentOptimizer(learning_rate)
    # A single minimize() call applies the gradients and increments the counter.
    return sgd.minimize(loss, global_step=step_counter)
def evaluation(logits, labels, name):
    """Count how many rows of logits rank the true label in the top 1.

    Args:
      logits: Logits tensor, shape (batch, NUM_CLASSES).
      labels: Integer labels tensor, shape (batch,).
      name: Name to give the returned count tensor.

    Returns:
      A scalar int32 tensor: the number of correct top-1 predictions.
    """
    hits = tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.int32)
    return tf.reduce_sum(hits, name=name)
# Runtime configuration.  Kept at module level because the helper functions
# below read batch_size as a global.
input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                              'tensorflow/mnist/input_data')
# Second positional argument is presumably fake_data=False — TODO confirm
# against the read_data_sets signature.
data_sets = read_data_sets(input_data_dir, False)
# BUG FIX: the original line ended with a stray trailing comma, which made
# log_dir a 1-element tuple instead of a string.
log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                       'tensorflow/mnist/logs/fully_connected_feed')
batch_size = 100       # examples per training/eval step
learning_rate = 0.01   # SGD step size
hidden1 = 128          # units in the first hidden layer
hidden2 = 32           # units in the second hidden layer
max_steps = 2000       # total training steps (used by the disabled loop below)
checkpoint_file = os.path.join(os.getenv('HOME'), 'Tmp/tf_converter/model.ckpt')
| ### Utils | |
def placeholder_inputs(batch_size):
    """Create the image and label placeholders for one batch.

    Args:
      batch_size: Number of examples per batch.

    Returns:
      Tuple of (images_placeholder, labels_placeholder).
    """
    images = tf.placeholder(tf.float32,
                            shape=(batch_size, IMAGE_PIXELS),
                            name='images_placeholder')
    labels = tf.placeholder(tf.int32,
                            shape=(batch_size),
                            name='labels_placeholder')
    return images, labels
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Map the two placeholders to the next batch from data_set.

    NOTE: reads the module-level batch_size global.
    """
    # Second argument is presumably fake_data=False — verify against
    # data_set.next_batch's signature.
    image_batch, label_batch = data_set.next_batch(batch_size, False)
    return {images_pl: image_batch, labels_pl: label_batch}
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one epoch of evaluation and print precision@1.

    Args:
      sess: The session in which the model has been trained.
      eval_correct: Tensor counting correct predictions (from evaluation()).
      images_placeholder: The images placeholder.
      labels_placeholder: The labels placeholder.
      data_set: The set of images and labels to evaluate on.
    """
    true_count = 0  # running total of correct predictions
    # Only whole batches are evaluated; a trailing partial batch is dropped
    # because the placeholders are built with a fixed batch_size (module global).
    steps_per_epoch = data_set.num_examples // batch_size
    num_examples = steps_per_epoch * batch_size
    # BUG FIX: original used xrange, which is Python-2-only; range works on both.
    for _ in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    # Guard against an empty data set to avoid ZeroDivisionError.
    precision = float(true_count) / num_examples if num_examples else 0.0
    print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
## Start training
# NOTE(review): the training loop below is deliberately disabled — it is
# wrapped in a module-level triple-quoted string, so it never executes.
# It is kept as reference for how the checkpoint at checkpoint_file was
# produced.  (It still uses Python-2 xrange; fix that if re-enabling.)
"""
with tf.Graph().as_default():
    # summary = tf.summary.merge_all()
    # summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    sess = tf.Session()
    images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
    logits = inference(images_placeholder, 128, 32)
    loss = loss(logits, labels_placeholder)
    train_op = training(loss, learning_rate)
    eval_correct = evaluation(logits, labels_placeholder, 'eval_correct')
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    sess.run(init)
    for step in xrange(max_steps):
        start_time = time.time()
        feed_dict = fill_feed_dict(
            data_sets.train, images_placeholder, labels_placeholder)
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
        duration = time.time() - start_time
        if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
            # Update the events file.
            # summary_str = sess.run(summary, feed_dict=feed_dict)
            # summary_writer.add_summary(summary_str, step)
            # summary_writer.flush()
        if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
            # saver.save(sess, checkpoint_file, global_step=step)
            print('Test Data Eval:')
            do_eval(sess, eval_correct, images_placeholder,
                    labels_placeholder, data_sets.test)
    saver.save(sess, checkpoint_file)
"""
# Load and save
# Restore the trained model from its checkpoint + meta graph, then re-run the
# test-set evaluation to confirm the restored graph behaves the same.
meta_file = checkpoint_file + '.meta'
with tf.Graph().as_default():
    # import_meta_graph recreates the saved graph structure; restore() then
    # loads the checkpointed variable values into the session.
    new_saver = tf.train.import_meta_graph(meta_file)
    # FIX: use the session as a context manager so its resources are released
    # (the original created tf.Session() and never closed it).
    with tf.Session() as sess:
        new_saver.restore(sess, checkpoint_file)
        graph = tf.get_default_graph()
        # Look tensors up by the names assigned when the graph was built.
        images_placeholder = graph.get_tensor_by_name('images_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        eval_correct = graph.get_tensor_by_name('eval_correct:0')
        print('Test Data Eval after restoring model:')
        do_eval(sess, eval_correct, images_placeholder, labels_placeholder,
                data_sets.test)
#====
# To inspect raw checkpoint contents:
#from tensorflow.python.tools import inspect_checkpoint as chkp
#chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.