Last active
May 3, 2018 20:19
-
-
Save DanWahlin/2b0186897e8e5ab7be17c0d8ca86b569 to your computer and use it in GitHub Desktop.
Machine Learning Image Classifier Python Script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Originally created by Lin JungHsuan: https://medium.com/@linjunghsuan/create-a-simple-image-classifier-using-tensorflow-a7061635984a
#
# Classifies a single image with a retrained TensorFlow graph and prints
# every label together with its score, highest score first.
#
# Usage: python <this_script> <image_path>
#
# Reads ./tf_files/retrained_labels.txt and ./tf_files/retrained_graph.pb,
# both resolved relative to the current working directory. Written against
# the TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session).
import sys

import tensorflow as tf

# Fail with a usage message rather than a bare IndexError when the image
# path argument is missing.
if len(sys.argv) < 2:
    sys.exit('usage: %s <image_path>' % sys.argv[0])
image_path = sys.argv[1]

# Read in the image_data as raw bytes; the with-block guarantees the file
# handle is closed (the previous one-liner never closed it).
with tf.gfile.GFile(image_path, 'rb') as image_file:
    image_data = image_file.read()

# Loads label file, one label per line, stripping trailing whitespace
# (including the line ending). The list index of a label is used below as
# the index into the score vector.
with tf.gfile.GFile('./tf_files/retrained_labels.txt') as label_file:
    label_lines = [line.rstrip() for line in label_file]

# Unpersists graph from file and imports it into the default graph.
# name='' keeps the node names unprefixed so the tensor lookups below can
# use the names as stored in the .pb file.
with tf.gfile.GFile('./tf_files/retrained_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

# Feed the image_data as input to the graph and get first prediction
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})

    # Only the first row of the result is reported.
    scores = predictions[0]

    # argsort() is ascending, so reverse it to list labels from the highest
    # score to the lowest.
    top_k = scores.argsort()[::-1]
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = scores[node_id]
        print('%s (score = %.5f)' % (human_string, score))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment