# Created November 13, 2016 16:01
# Save aazout/3634fca90db6d4523aee9c3e40b05681 to your computer and use it in GitHub Desktop.
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
#import tf_attached_input | |
import tf_attached as attached | |
import tf_attached_input as attached_input | |
#import math, time | |
#from datetime import datetime | |
import numpy as np | |
import os | |
from tf_attached_freezegraph import freeze_graph | |
import PIL | |
from PIL import Image | |
FLAGS = tf.app.flags.FLAGS

# Directory holding the trained checkpoints; save_graph() restores the
# checkpoint that tf.train.get_checkpoint_state() reports for this directory.
tf.app.flags.DEFINE_string(
    'checkpoint_dir',
    '/data/attached',
    """Directory where to read model checkpoints.""")
def save_graph(image_path="/data/attached/documents_crops/2_1.jpg"):
  """Build the Attached inference graph, restore a checkpoint and freeze it.

  Steps:
    1. Build an inference-only graph. Its input is a float32 placeholder named
       "input_image" holding one flattened 32*32*3 image. Its output is the
       softmax op named "logits".
    2. Restore the exponential-moving-average version of the variables from
       the checkpoint reported by tf.train.get_checkpoint_state() for
       FLAGS.checkpoint_dir.
    3. Run one sanity-check inference on `image_path` and print the result.
    4. Write the GraphDef to FLAGS.data_dir as "input_attached_graph.pb".
       Then freeze it together with the restored checkpoint into
       "output_attached_graph.pb" in the same directory.

  Args:
    image_path: Image file used for the sanity-check inference. It is
      converted to RGB and resized to 32x32. Defaults to the sample crop the
      script always used.

  Returns:
    None. If no checkpoint is found, a message is printed and nothing is
    written.
  """
  print("Creating Attached inference graph...")
  with tf.Graph().as_default() as g:
    # Input contract of the exported graph: a flattened 32*32*3 image.
    image = tf.placeholder(tf.float32, [None, 32 * 32 * 3], name="input_image")
    # The fixed [32, 32, 3] reshape means the feed must contain exactly one image.
    reshaped_image = tf.reshape(image, shape=tf.pack([32, 32, 3]))
    reshaped_image = tf.cast(reshaped_image, tf.float32)
    # TODO: tf.image.central_crop(reshaped_image, central_fraction=0.875) was
    # removed because the op is not supported on iOS; find a replacement.
    # resize_bilinear works on batches, so add a batch dim and drop it again.
    reshaped_image = tf.expand_dims(reshaped_image, 0)
    reshaped_image = tf.image.resize_bilinear(reshaped_image, [32, 32],
                                              align_corners=False)
    reshaped_image = tf.squeeze(reshaped_image, [0])
    reshaped_image.set_shape([32, 32, 3])
    reshaped_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 24, 24)
    float_image = tf.image.per_image_whitening(reshaped_image)
    # The model's inference() takes a batch; this is a batch of one.
    float_image = tf.expand_dims(float_image, 0)
    logits = attached.inference(float_image, batch_size=1)
    normalized_logits = tf.nn.softmax(logits, name="logits")

    # Restore the moving-average (shadow) variables rather than the raw ones.
    variable_averages = tf.train.ExponentialMovingAverage(
        attached.MOVING_AVERAGE_DECAY)
    saver = tf.train.Saver(variable_averages.variables_to_restore())

    print("Reading file...")
    # Open the file explicitly so the handle is closed. convert("RGB") forces
    # the pixel data to load and guarantees 3 channels, so grayscale or RGBA
    # inputs still flatten to 32*32*3 values.
    with open(image_path, "rb") as image_file:
      pil_image = Image.open(image_file).convert("RGB").resize((32, 32))
    im = np.array(pil_image).ravel()
    print(im)
    im = np.reshape(im, (1, im.shape[0]))

    with tf.Session() as sess:
      ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
      if not (ckpt and ckpt.model_checkpoint_path):
        # Without a checkpoint the variables are uninitialised and there is
        # nothing to freeze. Previously this fell through to a failing
        # sess.run and a NameError on global_step.
        print("No checkpoint file found in", FLAGS.checkpoint_dir)
        return

      checkpoint_path = ckpt.model_checkpoint_path
      saver.restore(sess, checkpoint_path)
      # Assuming checkpoint_path looks like .../model.ckpt-1234, extract the
      # global step from its suffix.
      global_step = checkpoint_path.split('/')[-1].split('-')[-1]

      probabilities = sess.run(normalized_logits, feed_dict={image: im})
      print("Logits:", probabilities)
      print("Saving and freezing the graph for",
            'model.ckpt' + "-" + str(global_step))

      # NOTE(review): FLAGS.data_dir is not defined in this file. Presumably
      # tf_attached defines it (as cifar10.py does); confirm.
      input_graph_name = "input_attached_graph.pb"
      tf.train.write_graph(g.as_graph_def(), FLAGS.data_dir, input_graph_name)

      input_graph_path = os.path.join(FLAGS.data_dir, input_graph_name)
      input_saver_def_path = ""
      # write_graph above uses its default text output, so the graph is not binary.
      input_binary = False
      # Freeze exactly the checkpoint that was restored. Rebuilding the path
      # under FLAGS.data_dir pointed at the wrong directory whenever it
      # differed from FLAGS.checkpoint_dir.
      input_checkpoint_path = checkpoint_path
      output_node_names = "logits"
      # Ops created by the tf.train.Saver above (default name scope "save").
      restore_op_name = "save/restore_all"
      filename_tensor_name = "save/Const:0"
      output_graph_path = os.path.join(FLAGS.data_dir, "output_attached_graph.pb")
      clear_devices = False
      freeze_graph(input_graph_path,
                   input_saver_def_path,
                   input_binary,
                   input_checkpoint_path,
                   output_node_names,
                   restore_op_name,
                   filename_tensor_name,
                   output_graph_path,
                   clear_devices,
                   "")  # presumably initializer_nodes; confirm in tf_attached_freezegraph
def main(argv=None):  # pylint: disable=unused-argument
  """Script entry point: build, restore and freeze the Attached graph."""
  del argv  # Unused; configuration is read from FLAGS.
  save_graph()
if __name__ == '__main__':
  # tf.app.run parses the command-line flags into FLAGS and then calls main;
  # passing it explicitly is the same function it would look up on __main__.
  tf.app.run(main=main)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.