Created
January 20, 2019 16:23
-
-
Save jzstark/9d3b8692f5162f403ca4c6dd6bb22be1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import tensorflow as tf | |
| from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets | |
# MNIST classification constants.
NUM_CLASSES = 10                         # digits 0-9
IMAGE_SIZE = 28                          # MNIST images are 28x28 pixels
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE   # length of a flattened image vector
def inference(images, hidden1_units, hidden2_units):
    """Build the forward pass of a two-hidden-layer MLP.

    Args:
        images: float32 tensor of flattened images, shape (batch, IMAGE_PIXELS).
        hidden1_units: width of the first hidden layer.
        hidden2_units: width of the second hidden layer.

    Returns:
        Unnormalized class logits, shape (batch, NUM_CLASSES).
    """
    def _dense(scope, inputs, fan_in, fan_out, activation):
        # One fully connected layer; weight/bias names and the 1/sqrt(fan_in)
        # initializer stddev match the original per-layer blocks.
        with tf.name_scope(scope):
            weights = tf.Variable(
                tf.truncated_normal([fan_in, fan_out],
                                    stddev=1.0 / math.sqrt(float(fan_in))),
                name='weights')
            biases = tf.Variable(tf.zeros([fan_out]), name='biases')
            pre_activation = tf.matmul(inputs, weights) + biases
            if activation is None:
                return pre_activation
            return activation(pre_activation)

    hidden1 = _dense('hidden1', images, IMAGE_PIXELS, hidden1_units, tf.nn.relu)
    hidden2 = _dense('hidden2', hidden1, hidden1_units, hidden2_units, tf.nn.relu)
    # Final layer is linear; softmax is applied later inside the loss.
    logits = _dense('softmax_linear', hidden2, hidden2_units, NUM_CLASSES, None)
    return logits
def loss(logits, labels):
    """Return the mean sparse softmax cross-entropy between logits and labels."""
    labels_int64 = tf.to_int64(labels)
    return tf.losses.sparse_softmax_cross_entropy(labels=labels_int64,
                                                  logits=logits)
def training(loss, learning_rate, name='train_op'):
    """Create a gradient-descent train op that also increments `global_step`.

    The op is given an explicit name so it can be recovered after the graph
    is re-imported from a metagraph.
    """
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    return optimizer.minimize(loss, global_step=global_step, name=name)
def evaluation(logits, labels, name='eval_correct'):
    """Return an int32 scalar tensor counting top-1 correct predictions."""
    hits = tf.nn.in_top_k(logits, labels, 1)
    hit_counts = tf.cast(hits, tf.int32)
    return tf.reduce_sum(hit_counts, name=name)
# --- Data locations, hyper-parameters, and checkpoint paths ---
input_data_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                              'tensorflow/mnist/input_data')
data_sets = read_data_sets(input_data_dir, False)
# BUG FIX: the original line ended with a stray trailing comma, which made
# log_dir a 1-tuple ('…path…',) instead of a plain string.
log_dir = os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                       'tensorflow/mnist/logs/fully_connected_feed')
batch_size = 100
learning_rate = 0.01
hidden1 = 128
hidden2 = 32
max_steps = 2000
# NOTE(review): os.getenv('HOME') can return None (e.g. on Windows), which
# would make os.path.join raise — confirm this script only targets Unix.
checkpoint_file = os.path.join(os.getenv('HOME'), 'Tmp/tf_converter/model2.ckpt')
meta_file = checkpoint_file + '.meta'
| ### Utils | |
def placeholder_inputs(batch_size):
    """Create the named image/label feed placeholders for a fixed batch size.

    The explicit names let the tensors be looked up again after the graph
    is restored from a metagraph.
    """
    images = tf.placeholder(tf.float32,
                            shape=(batch_size, IMAGE_PIXELS),
                            name='images_placeholder')
    labels = tf.placeholder(tf.int32,
                            shape=(batch_size),
                            name='labels_placeholder')
    return images, labels
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Fetch the next batch and map it onto the two feed placeholders.

    Uses the module-level `batch_size`; the second argument to next_batch
    disables fake data.
    """
    images, labels = data_set.next_batch(batch_size, False)
    return {images_pl: images, labels_pl: labels}
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one full epoch of evaluation and print precision@1.

    Args:
        sess: the session holding the trained variables.
        eval_correct: tensor counting correct predictions in a batch.
        images_placeholder: feed placeholder for images.
        labels_placeholder: feed placeholder for labels.
        data_set: dataset to evaluate (train/validation/test split).
    """
    true_count = 0  # Counts the number of correct predictions.
    # Partial trailing batches are dropped, so precision is computed over
    # steps_per_epoch * batch_size examples, not data_set.num_examples.
    steps_per_epoch = data_set.num_examples // batch_size
    num_examples = steps_per_epoch * batch_size
    # BUG FIX: xrange is Python-2-only; range works on both 2 and 3
    # (the file's __future__ imports indicate 2/3 compatibility is intended).
    for step in range(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
    precision = float(true_count) / num_examples
    print('Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
## Build the graph, initialise its variables, and save checkpoint + metagraph.
with tf.Graph().as_default():
    sess = tf.Session()
    images_placeholder, labels_placeholder = placeholder_inputs(batch_size)
    logits = inference(images_placeholder, 128, 32)
    # NOTE(review): this rebinds the module-level name `loss` from the
    # function above to the loss tensor. It works because the function is
    # never called again, but the shadowing is fragile.
    loss = loss(logits, labels_placeholder)
    train_op = training(loss, learning_rate)
    eval_correct = evaluation(logits, labels_placeholder)
    # Create the initializer and saver only after every variable exists.
    init = tf.global_variables_initializer()
    saver = tf.train.Saver()
    sess.run(init)
    saver.save(sess, checkpoint_file)
## Start training: rebuild the graph from the saved metagraph and run SGD.
with tf.Graph().as_default():
    sess = tf.Session()
    saver = tf.train.import_meta_graph(meta_file)
    graph = tf.get_default_graph()
    # NOTE(review): variables are re-initialised from scratch here rather
    # than restored with saver.restore(sess, checkpoint_file) — confirm
    # starting from fresh weights (not the saved ones) is intended.
    init = tf.global_variables_initializer()
    sess.run(init)
    # Recover tensors/ops by the names and collections set when the graph
    # was originally built.
    images_placeholder = graph.get_tensor_by_name('images_placeholder:0')
    labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
    eval_correct = graph.get_tensor_by_name('eval_correct:0')
    train_op = tf.get_collection('train_op')[0]
    loss = tf.get_collection('losses')[0]
    # BUG FIX: xrange is Python-2-only; use range so the loop also runs on
    # Python 3 (matching the file's __future__ imports).
    for step in range(max_steps):
        start_time = time.time()
        feed_dict = fill_feed_dict(
            data_sets.train, images_placeholder, labels_placeholder)
        _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
        duration = time.time() - start_time
        if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        if (step + 1) % 1000 == 0 or (step + 1) == max_steps:
            print('Test Data Eval:')
            do_eval(sess, eval_correct, images_placeholder,
                    labels_placeholder, data_sets.test)
    # saver.save(sess, checkpoint_file)
#====
# from tensorflow.python.tools import inspect_checkpoint as chkp
# chkp.print_tensors_in_checkpoint_file("model.ckpt", tensor_name='', all_tensors=True)
# Question: backward (gradient) nodes may not be required in the graph
# before training... right?
# https://www.tensorflow.org/api_guides/python/meta_graph
# Re-import the saved metagraph and dump it for inspection.
sess = tf.InteractiveSession()
new_saver = tf.train.import_meta_graph(meta_file)
exported = new_saver.export_meta_graph()
print(exported)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment