Skip to content

Instantly share code, notes, and snippets.

@aazout
Created November 13, 2016 16:01
Show Gist options
  • Save aazout/3634fca90db6d4523aee9c3e40b05681 to your computer and use it in GitHub Desktop.
Save aazout/3634fca90db6d4523aee9c3e40b05681 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
#import tf_attached_input
import tf_attached as attached
import tf_attached_input as attached_input
#import math, time
#from datetime import datetime
import numpy as np
import os
from tf_attached_freezegraph import freeze_graph
import PIL
from PIL import Image
# Command-line flags for this script: --checkpoint_dir chooses the directory
# that holds the trained checkpoints to restore and freeze.
tf.app.flags.DEFINE_string(
    'checkpoint_dir',
    '/data/attached',
    """Directory where to read model checkpoints.""")
FLAGS = tf.app.flags.FLAGS
def _build_inference_graph():
    """Build the single-image inference graph in a new tf.Graph.

    Returns:
      A tuple ``(graph, image, probabilities, saver)``:
        graph: the tf.Graph holding every op below.
        image: the ``input_image`` float32 placeholder of shape [None, 3072];
          the reshape to [32, 32, 3] means exactly one image is fed per run.
        probabilities: softmax over the model output, op name ``logits``
          (the node name the freeze step exports).
        saver: a Saver over the moving-average (shadow) variables, so the
          restored weights are the averaged ones. Being the first Saver in
          this graph it gets the default name scope ``save``.
    """
    graph = tf.Graph()
    with graph.as_default():
        image = tf.placeholder(tf.float32, [None, 32 * 32 * 3], name="input_image")
        # Flat pixel vector -> height x width x channels for a single image.
        net = tf.reshape(image, shape=tf.pack([32, 32, 3]))
        net = tf.cast(net, tf.float32)
        # TODO: central_crop is disabled because the op is not supported on iOS.
        # net = tf.image.central_crop(net, central_fraction=0.875)
        # Resize back to 32x32. With the crop above disabled this is an
        # identity resize (32x32 -> 32x32).
        net = tf.expand_dims(net, 0)
        net = tf.image.resize_bilinear(net, [32, 32], align_corners=False)
        net = tf.squeeze(net, [0])
        net.set_shape([32, 32, 3])
        net = tf.image.resize_image_with_crop_or_pad(net, 24, 24)
        net = tf.image.per_image_whitening(net)
        # The model expects a batch dimension; add a batch of one.
        net = tf.expand_dims(net, 0)
        logits = attached.inference(net, batch_size=1)
        probabilities = tf.nn.softmax(logits, name="logits")
        # Restore the moving-average version of the learned variables.
        variable_averages = tf.train.ExponentialMovingAverage(
            attached.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())
    return graph, image, probabilities, saver


def _load_test_image(path):
    """Load ``path`` as a (1, 3072) array of 32x32 RGB pixel values."""
    with Image.open(path) as img:
        # Force three channels so grayscale / RGBA / palette images still
        # flatten to 32*32*3 values and match the placeholder width.
        pixels = np.array(img.convert("RGB").resize((32, 32)))
    pixels = pixels.ravel()
    print(pixels)
    return np.reshape(pixels, (1, pixels.shape[0]))


def _freeze_saved_graph(graph_def, checkpoint_path):
    """Write ``graph_def`` as text and freeze it with ``checkpoint_path``.

    Both .pb files are written to FLAGS.data_dir.
    NOTE(review): ``data_dir`` is not defined in this file; presumably it is
    defined by an imported module (e.g. tf_attached) — confirm.
    """
    output_dir = FLAGS.data_dir
    input_graph_name = "input_attached_graph.pb"
    # write_graph emits text format by default, hence input_binary=False below.
    tf.train.write_graph(graph_def, output_dir, input_graph_name)
    # Run the freeze under a fresh default graph so any ops freeze_graph
    # imports cannot collide with names from the inference graph.
    with tf.Graph().as_default():
        freeze_graph(os.path.join(output_dir, input_graph_name),
                     "",                  # input saver def path: none
                     False,               # input_binary
                     checkpoint_path,
                     "logits",            # output node names
                     "save/restore_all",  # restore op of the default Saver
                     "save/Const:0",      # Saver's checkpoint-filename tensor
                     os.path.join(output_dir, "output_attached_graph.pb"),
                     False,               # clear_devices
                     "")                  # last positional arg (presumably initializer nodes) — confirm


def save_graph(image_path="/data/attached/documents_crops/2_1.jpg"):
    """Build the inference graph, sanity-check it on one image, then freeze it.

    Restores the latest checkpoint found in FLAGS.checkpoint_dir, prints the
    softmax output for ``image_path`` and writes the text graph plus the
    frozen graph.

    Args:
      image_path: image used for the sanity-check inference run.

    Raises:
      IOError: if no checkpoint is found in FLAGS.checkpoint_dir.
    """
    print("Creating Attached inference graph...")
    graph, image, probabilities, saver = _build_inference_graph()
    print("Reading file...")
    pixels = _load_test_image(image_path)

    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Without a checkpoint the variables are uninitialised and there is
        # nothing to freeze; fail loudly instead of a later NameError.
        raise IOError("No checkpoint found in %s" % FLAGS.checkpoint_dir)
    checkpoint_path = ckpt.model_checkpoint_path
    # Assuming the path looks like .../model.ckpt-<global_step>.
    global_step = os.path.basename(checkpoint_path).split('-')[-1]

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint_path)
        print("Logits:", sess.run(probabilities, feed_dict={image: pixels}))

    print("Saving and freezing the graph for", 'model.ckpt' + "-" + str(global_step))
    # Freeze exactly the checkpoint that was restored above.
    _freeze_saved_graph(graph.as_graph_def(), checkpoint_path)
def main(argv=None):  # pylint: disable=unused-argument
    """Entry point invoked by ``tf.app.run``; delegates to ``save_graph``.

    ``argv`` is accepted for the ``tf.app.run`` calling convention but is
    not used.
    """
    del argv  # unused
    save_graph()


if __name__ == '__main__':
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment