Created
June 26, 2018 11:49
-
-
Save vishalghor/102394110cd74cc3d76aebf8cf8a367b to your computer and use it in GitHub Desktop.
TensorFlow MobileNet_v1 inference script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
from __future__ import absolute_import | |
from __future__ import print_function | |
import argparse | |
import sys | |
import PIL | |
from PIL import Image | |
import tensorflow as tf | |
import numpy as np | |
from numpy import array | |
''' | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
'--image', required=True, type=str, help='Absolute path to image file.') | |
parser.add_argument( | |
'--num_top_predictions', | |
type=int, | |
default=5, | |
help='Display this many predictions.') | |
parser.add_argument( | |
'--graph', | |
required=True, | |
type=str, | |
help='Absolute path to graph file (.pb)') | |
parser.add_argument( | |
'--labels', | |
required=True, | |
type=str, | |
help='Absolute path to labels file (.txt)') | |
parser.add_argument( | |
'--output_layer', | |
type=str, | |
default='final_result:0', | |
help='Name of the result operation') | |
parser.add_argument( | |
'--input_layer', | |
type=str, | |
default='DecodeJpeg/contents:0', | |
help='Name of the input operation') | |
''' | |
def load_image(filename, input_size=224):
    """Read an image file and return it as a normalized input batch.

    The image is converted to RGB and, when its size differs, resized to
    ``input_size`` x ``input_size`` so the final reshape cannot fail (or
    silently scramble pixels) for grayscale, RGBA, or differently sized
    images. An image that is already RGB and of the right size yields
    the same values as before.

    Args:
        filename: Path to the image file to classify.
        input_size: Height and width, in pixels, of the returned image.
            Defaults to 224.

    Returns:
        A float32 array of shape ``(1, input_size, input_size, 3)`` whose
        values are ``(pixel - 128.0) / 128.0``.
    """
    # The context manager closes the underlying file once pixels are read.
    with Image.open(filename) as im:
        rgb = im.convert('RGB')
        if rgb.size != (input_size, input_size):
            rgb = rgb.resize((input_size, input_size), Image.BILINEAR)
        pixels = np.asarray(rgb, dtype="float32")
    normalized = (pixels - 128.0) / 128.0
    # Add the leading batch dimension.
    return normalized.reshape(1, input_size, input_size, 3)
def load_labels(filename):
    """Read in labels, one label per line.

    Args:
        filename: Path to a text file with one label per line.

    Returns:
        A list of label strings with trailing whitespace (including the
        newline) stripped, in file order.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to garbage collection.
    with tf.gfile.GFile(filename) as label_file:
        return [line.rstrip() for line in label_file]
def load_graph(filename):
    """Unpersists graph from file as default graph.

    Args:
        filename: Path to a serialized ``GraphDef`` (.pb) file.
    """
    # Read from the path the caller supplied; a hard-coded absolute path
    # here previously made the ``filename`` argument a no-op.
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name-scope prefix.
    tf.import_graph_def(graph_def, name='')
def run_graph(image_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
    """Run the default graph on ``image_data`` and print the top labels.

    Args:
        image_data: Value fed to the tensor named ``input_layer_name``.
        labels: Sequence of label strings indexed by class id.
        input_layer_name: Name of the tensor to feed, e.g. ``'input:0'``.
        output_layer_name: Name of the tensor to fetch, e.g.
            ``'final_result:0'``.
        num_top_predictions: Maximum number of labels to print; values
            of zero or less print nothing.

    Returns:
        Always 0.
    """
    with tf.Session() as sess:
        # Only the tensors named by the caller are looked up; no
        # hard-coded tensor names are required to exist in the graph.
        output_tensor = sess.graph.get_tensor_by_name(output_layer_name)
        # The fetched result has one row per input image; the
        # single-element unpacking keeps that one row and raises if the
        # batch does not hold exactly one image.
        predictions, = sess.run(output_tensor, {input_layer_name: image_data})

    # Sort to show labels in order of confidence. Reversing before slicing
    # (rather than slicing with [-k:]) makes k == 0 yield no entries
    # instead of every class.
    top_k = predictions.argsort()[::-1][:max(num_top_predictions, 0)]
    for node_id in top_k:
        human_string = labels[node_id]
        score = predictions[node_id]
        print('%s (score = %.5f)' % (human_string, score))
    return 0
def main(argv):
    """Run inference on one hard-coded image and print its top labels."""
    # Argument and file validation is currently disabled; kept for reference:
    #   if argv[1:]:
    #       raise ValueError('Unused Command Line Args: %s' % argv[1:])
    #   if not tf.gfile.Exists(FLAGS.image):
    #       tf.logging.fatal('image file does not exist %s', FLAGS.image)
    #   if not tf.gfile.Exists(FLAGS.labels):
    #       tf.logging.fatal('labels file does not exist %s', FLAGS.labels)
    #   if not tf.gfile.Exists(FLAGS.graph):
    #       tf.logging.fatal('graph file does not exist %s', FLAGS.graph)

    # Inputs are read in the same order as before: image, labels, graph.
    image_batch = load_image('path-to-image/9_56.JPG')
    label_names = load_labels('path-to/labels.txt')
    # load_graph imports the model into the default graph, which is what
    # the session created inside run_graph operates on.
    load_graph('path-to/output_graph.pb')

    run_graph(
        image_batch,
        label_names,
        'input:0',         # tensor that receives the image batch
        'final_result:0',  # tensor holding the per-class scores
        5,                 # number of top predictions to print
    )
if __name__ == "__main__":
    # Command-line flag parsing is currently disabled:
    # FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This script is an update to the label_image script provided for inference at tensorflow/tensorflow/image_retraining/label_image.py.