Last active
April 19, 2019 02:10
-
-
Save apivovarov/2ee1723f4fdcf9e711b0e9c3fcc342f0 to your computer and use it in GitHub Desktop.
TF resnet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import tensorflow as tf | |
| import numpy as np | |
| import time | |
| from imagenet_preprocessing import preprocess_image | |
| from imagenet1000 import imagenet_classes | |
def ms():
    """Return a millisecond timestamp as an int, for measuring durations.

    Only the difference between two calls is meaningful: the value comes from
    ``time.perf_counter()``, whose reference point is undefined. A monotonic
    clock is used so a measured duration cannot go negative or jump when the
    system wall clock is adjusted, which ``time.time()`` allows.
    """
    return int(round(time.perf_counter() * 1000))


# Model to benchmark. A path ending in ".pb" is loaded as a frozen GraphDef;
# anything else is treated as a SavedModel export directory.
# Alternative models kept for reference:
# model_path = "resnet_v1_fp32_savedmodel_NCHW/1538686577"
# model_path = "resnet_v1_fp32_savedmodel_NHWC/1538686669"
# model_path = '20180601_resnet_v1_imagenet_savedmodel/1527888778'
model_path = "resnet_v1_fp32_savedmodel_NHWC.pb"

# Name scope under which a frozen graph's nodes are imported.
PREFIX = 'aimport'
def eval_image(im_path):
    """Read one image file, preprocess it, and return the result as an array.

    The file contents are read with ``tf.read_file``. They are passed to
    ``preprocess_image`` with ``is_training=False``, a 224x224 output size and
    3 channels. The tensor is materialized with ``Tensor.eval()``, so a default
    session must be active when this is called. Each call adds new ops to the
    current default graph. The array's shape and the path are printed as a
    progress trace.
    """
    raw_bytes = tf.read_file(im_path)
    processed = preprocess_image(
        image_buffer=raw_bytes,
        bbox=None,
        output_height=224,
        output_width=224,
        num_channels=3,
        is_training=False,
    )
    pixels = processed.eval()
    print(pixels.shape, im_path)
    return pixels
def get_tag_set(saved_model_dir):
    """Return the tag set of the first MetaGraph in a SavedModel directory.

    Multiple MetaGraphs are not supported. If the export contains several,
    only the first tag set reported by the reader is returned.
    """
    from tensorflow.contrib.saved_model.python.saved_model import reader as sm_reader

    all_tag_sets = sm_reader.get_saved_model_tag_sets(saved_model_dir)
    first_tag_set = all_tag_sets[0]
    return first_tag_set
def get_input_and_output_names_from_meta(meta_graph_def):
    """Collect the input and output tensor names declared by all SignatureDefs.

    Args:
        meta_graph_def: A MetaGraphDef, such as the one returned by
            ``tf.saved_model.loader.load``.

    Returns:
        ``(input_names, output_names)``: two lists of unique tensor names,
        each sorted. Sorting keeps the order identical from run to run. A
        set's iteration order over strings changes with hash randomization,
        so a caller that picks the first input would otherwise get a
        different tensor on different runs when several inputs exist.
    """
    input_names = set()
    output_names = set()
    for sig_def in meta_graph_def.signature_def.values():
        input_names.update(t.name for t in sig_def.inputs.values())
        output_names.update(t.name for t in sig_def.outputs.values())
    return sorted(input_names), sorted(output_names)
def get_input_and_output_names(graph, prefix=None):
    """Guess the input and output tensor names of an imported frozen graph.

    Only operations whose name starts with ``"<prefix>/"`` are considered.

    * Inputs: ``Placeholder`` ops with no inputs and exactly one output.
    * Outputs: the outputs of single-output ops that no operation anywhere in
      the graph consumes as an input.

    Args:
        graph: The graph to inspect (anything with ``get_operations()``).
        prefix: Name scope the graph was imported under. Defaults to the
            module-level ``PREFIX``.

    Returns:
        ``(input_names, output_names)``. Input names follow graph-operation
        order. Output names are a sorted list, so their order is the same on
        every run.
    """
    if prefix is None:
        prefix = PREFIX
    scope = "{}/".format(prefix)
    operations = list(graph.get_operations())

    input_tensor_names = []
    candidate_outputs = set()
    for op in operations:
        if not op.name.startswith(scope):
            continue
        # Only single-output ops are considered, both as inputs and outputs.
        if len(op.outputs) != 1:
            continue
        tensor_name = op.outputs[0].name
        if op.type == 'Placeholder' and len(op.inputs) == 0:
            input_tensor_names.append(tensor_name)
        candidate_outputs.add(tensor_name)

    # A tensor that feeds any op is an intermediate value, not a graph output.
    consumed = {t.name for op in operations for t in op.inputs}
    output_tensor_names = sorted(candidate_outputs - consumed)
    return input_tensor_names, output_tensor_names
def print_imagenet_labels(res, classes=None):
    """Print the label for each predicted class id in ``res``, one per line.

    Args:
        res: Class ids, one per image. This is either a 1-D sequence/array or
            a 2-D array with a single column, shape ``(N, 1)``. Values are
            flattened and converted to ``int`` before lookup.
        classes: Lookup table from class id to label. Defaults to
            ``imagenet_classes``.
    """
    if classes is None:
        classes = imagenet_classes
    # ravel() makes an (N, 1) column yield scalar ids. Indexing the label
    # table with a length-1 array row raises TypeError.
    for idx, cl_id in enumerate(np.asarray(res).ravel()):
        cl_id = int(cl_id)
        print("id: {}, class: {}, label: {}".format(idx, cl_id, classes[cl_id]))
def load_frozen_model(frozen_model_file):
    """Load a binary-serialized ``GraphDef`` file into a new graph.

    All nodes are imported with ``name=PREFIX``, so they live under the
    ``PREFIX`` name scope of the returned graph.

    Args:
        frozen_model_file: Path to the serialized ``GraphDef`` (.pb) file.

    Returns:
        The newly created ``tf.Graph`` that contains the imported model.
    """
    from tensorflow.python.platform import gfile

    print("Loading frozen model: {} ....".format(frozen_model_file))
    # gfile.GFile replaces the deprecated gfile.FastGFile and has the same
    # read interface.
    with gfile.GFile(frozen_model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    print(type(graph_def))
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=PREFIX)
    return graph
def load_saved_model(saved_model_dir, sess=None):
    """Load a SavedModel into a session and return its ``MetaGraphDef``.

    Args:
        saved_model_dir: SavedModel export directory.
        sess: Session to load the model into. Defaults to the current default
            session, for example the one opened by an enclosing
            ``with tf.Session(...) as sess:`` block.

    Returns:
        The ``MetaGraphDef`` returned by ``tf.saved_model.loader.load``.

    Raises:
        ValueError: If ``sess`` is omitted and no default session is active.
    """
    if sess is None:
        # Use the default session rather than relying on a module-level
        # global named ``sess``.
        sess = tf.get_default_session()
        if sess is None:
            raise ValueError(
                "load_saved_model() needs a session: pass sess= or call it "
                "inside a 'with tf.Session(...)' block")
    tags = get_tag_set(saved_model_dir)
    print("tags: {}".format(tags))
    print("Loading saved model: {} ....".format(saved_model_dir))
    meta_graph_def = tf.saved_model.loader.load(sess, tags, saved_model_dir)
    print(type(meta_graph_def))
    return meta_graph_def
def get_input(input_shape, image_paths=None):
    """Build an input batch with sample images placed in the leading slots.

    The batch starts as ``np.random.random(input_shape)``, which is uniform
    filler in [0, 1). Slot ``i`` is then overwritten with the preprocessed
    image ``image_paths[i]``. Slots beyond the supplied images keep the random
    filler.

    Args:
        input_shape: Batch shape, for example ``(64, 224, 224, 3)``. One slot
            (``input_shape[1:]``) must accept the array returned by
            ``eval_image``.
        image_paths: Image files to place in the batch. Defaults to the nine
            sample images this script ships with.

    Returns:
        A NumPy array of shape ``input_shape``.

    Raises:
        ValueError: If there are more images than batch slots.
    """
    if image_paths is None:
        image_paths = (
            'chi.jpg', 'cat.jpg', 'canoe.jpg', 'goldfish.jpg', 'shark.jpg',
            'shark2.jpg', 'croc.jpg', 'lizz.jpg', 'kite.jpg',
        )
    image_paths = list(image_paths)
    if len(image_paths) > input_shape[0]:
        raise ValueError(
            "{} images do not fit in a batch of {}".format(
                len(image_paths), input_shape[0]))
    # Use a fresh graph with its session as the default. eval_image() adds
    # its ops to the default graph and evaluates them with Tensor.eval().
    with tf.Session(graph=tf.Graph()):
        inp = np.random.random(input_shape)
        for slot, path in enumerate(image_paths):
            inp[slot] = eval_image(path)
    return inp
def save_logdir(sess, logdir):
    """Write ``sess.graph`` to ``logdir`` as a TensorBoard event file.

    Also prints the ``tensorboard`` command that opens the directory.
    """
    file_writer = tf.summary.FileWriter(logdir, sess.graph)
    # close() flushes pending events. If the writer is left open, the graph
    # may never reach disk when the process exits before a periodic flush.
    file_writer.close()
    print("tensorboard --logdir {}".format(logdir))
# Script body: build the input batch, load the model, run inference once,
# and report the timing and predicted labels.
inp = get_input((64, 224, 224, 3))

# A ".pb" path is a frozen GraphDef ('F'); anything else is a SavedModel ('S').
m_type = 'F' if model_path.endswith(".pb") else 'S'

if m_type == 'F':
    # A frozen graph is fully built at load time, so its inputs and outputs
    # can be discovered before any session exists.
    graph = load_frozen_model(model_path)
    input_tensor_names, output_tensor_names = get_input_and_output_names(graph)
else:
    # A SavedModel is loaded into this empty graph inside the session below.
    graph = tf.Graph()

# Do not rename ``sess``: load_saved_model() may read this module-level name.
with tf.Session(graph=graph) as sess:
    if m_type == 'S':
        meta_graph_def = load_saved_model(model_path)
        input_tensor_names, output_tensor_names = (
            get_input_and_output_names_from_meta(meta_graph_def))
        graph = tf.get_default_graph()

    print("input_tensor_names: {}".format(input_tensor_names))
    print("output_tensor_names: {}".format(output_tensor_names))

    input_tensor = graph.get_tensor_by_name(input_tensor_names[0])
    print(input_tensor)
    input_shape = input_tensor.shape
    print("input_shape: {}".format(input_shape))

    output_tensors = []
    for tensor_name in output_tensor_names:
        fetched = graph.get_tensor_by_name(tensor_name)
        print(fetched)
        output_tensors.append(fetched)

    print("sess.run...")
    start_ms = ms()
    res = sess.run(output_tensors, feed_dict={input_tensor: inp})
    end_ms = ms()
    print("sess.run done")
    print("duration {0:,} ms".format(end_ms - start_ms))
    print(res)

    # Only 1-D outputs, or 2-D outputs with a single column, are passed on
    # as class ids.
    for out in res:
        if len(out.shape) == 1 or out.shape[1] == 1:
            print_imagenet_labels(out)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment