Skip to content

Instantly share code, notes, and snippets.

@rocking5566
Created April 10, 2019 09:21
Show Gist options
  • Save rocking5566/edccdcd37ba8025603d49931df874c95 to your computer and use it in GitHub Desktop.
Save rocking5566/edccdcd37ba8025603d49931df874c95 to your computer and use it in GitHub Desktop.
Get TensorFlow .pb tensors
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from PIL import Image
import numpy as np
import tensorflow as tf
# Parsed argparse namespace; assigned only in the ``__main__`` block below.
FLAGS = None

# Hard-coded input locations used by the script.
image_path = '/workspace/example/imagenet/test_data/husky.jpg'  # image fed to the graph
label_file = '/workspace/example/imagenet/test_data/labels.txt'  # not read by the code in this file
model_path = 'mobilenet_v1_1.0_224_quant.pb'  # serialized GraphDef loaded by create_graph()
def load_labels(filename):
    """Read a text file and return its lines with surrounding whitespace stripped.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one stripped string per line of the file (blank lines
        become ``''``); an empty file yields ``[]``.
    """
    # ``with`` guarantees the handle is closed, even if iteration raises.
    with open(filename, 'r') as handle:
        return [line.strip() for line in handle]
def create_graph(model_path):
    """Import the serialized GraphDef stored at *model_path* into the default graph.

    The graph is imported with ``name=''``, so no extra name-scope prefix is
    added to the imported node names. Returns None.
    """
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(model_path, 'rb') as pb_file:
        serialized = pb_file.read()
    graph_def.ParseFromString(serialized)
    tf.import_graph_def(graph_def, name='')
# Print at most this many values from any one tensor slice (the original
# script printed the first 301 values whenever a slice had more than 300).
_PREVIEW_LIMIT = 301

# Line printed after every section of the dump.
_SEPARATOR = '========================================================================'

# Node-name substrings identifying fake-quantized standard-conv weights.
_QUANT_CONV_WEIGHT_PATTERNS = (
    'pointwise/weights_quant/FakeQuantWithMinMaxVars',
    'Conv2d_0/weights_quant/FakeQuantWithMinMaxVars',
    'Conv2d_1c_1x1/weights_quant/FakeQuantWithMinMaxVars',
)
# Node-name suffixes identifying folded standard-conv weights.
_FOLDED_CONV_WEIGHT_SUFFIXES = (
    'pointwise/mul_fold',
    'Conv2d_0/mul_fold',
    'Conv2d_1c_1x1/weights',
)
# Node-name suffixes identifying activation-range variables.
_ACT_RANGE_SUFFIXES = ('act_quant/min/read', 'act_quant/max/read')

# Index expressions applied to each dumped tensor. The layout names (HWIO,
# HWO1, NHWC) are the original author's annotations and are not checked here.
_HEAD = slice(None, _PREVIEW_LIMIT)
_CONV_INDEX = (0, 0, 0, _HEAD)        # HWIO: first tap, first input channel
_DEPTHWISE_INDEX = (0, 0, _HEAD, 0)   # HWO1
_BIAS_INDEX = (_HEAD,)
_ACTIVATION_INDEX = (0, 0, 0, _HEAD)  # NHWC: first pixel of first batch item


def _classify_node(node_name):
    """Decide whether and how a graph node is dumped.

    Returns None for nodes that are not dumped; otherwise a ``(title, index)``
    pair. ``title`` is the heading printed before the tensor and ``index`` is
    the slice printed. A ``(None, None)`` result marks an activation-range
    variable, whose full value is printed without heading or shape. Rules are
    tried in a fixed order and the first match wins.
    """
    if any(pattern in node_name for pattern in _QUANT_CONV_WEIGHT_PATTERNS):
        return 'Conv weight', _CONV_INDEX
    if 'depthwise/weights_quant/FakeQuantWithMinMaxVars' in node_name:
        return 'Depthwise Conv weight', _DEPTHWISE_INDEX
    if 'BatchNorm_Fold/bias' in node_name or 'biases/read' in node_name:
        return 'Bias', _BIAS_INDEX
    if 'act_quant/FakeQuantWithMinMaxVars' in node_name:
        return 'Activation', _ACTIVATION_INDEX
    if node_name.endswith(_FOLDED_CONV_WEIGHT_SUFFIXES):
        return 'Conv weight', _CONV_INDEX
    if node_name.endswith('_depthwise/mul_fold'):
        return 'Depthwise Conv weight', _DEPTHWISE_INDEX
    if node_name.endswith(_ACT_RANGE_SUFFIXES):
        return None, None
    return None


def _print_dump(node_name, spec, value):
    """Print one dump section for *value* as described by *spec*."""
    title, index = spec
    if title is None:
        # Activation-range variable: print the whole value.
        print(node_name)
        print(value)
    else:
        print(title)
        print(node_name)
        print(value.shape)
        print(value[index])
    print(_SEPARATOR)


def _load_image(image_path, height, width):
    """Return the image as a batch of shape (1, height, width, 3) scaled to [-1, 1]."""
    with Image.open(image_path) as raw:
        # convert('RGB') so grayscale/palette/RGBA/CMYK files still give 3 channels.
        resized = raw.convert('RGB').resize((width, height))
    batch = np.expand_dims(np.asarray(resized), axis=0)  # add N dim
    # Map uint8 pixels 0..255 onto -1.0..1.0.
    return (batch - 127.5) / 127.5


def run_inference_on_image(image_path, height=224, width=224):
    """Load the model graph, feed one image, and print selected tensors.

    The graph is loaded from the module-level ``model_path``. The function
    prints the fed input, then walks every node of the default graph and
    prints the value of each node matched by ``_classify_node``: conv and
    depthwise weights, biases, activations and activation min/max ranges.

    Args:
        image_path: Path of the image to feed to the ``input:0`` tensor.
        height: Height the image is resized to (default 224).
        width: Width the image is resized to (default 224).
    """
    # Creates graph from saved GraphDef.
    create_graph(model_path)
    img = _load_image(image_path, height, width)

    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('input:0')
    with tf.Session() as sess:
        # NHWC
        input_value = sess.run(x, feed_dict={x: img})
        print('Input')
        print(input_value.shape)
        print(input_value[0, 0, 0:3, :])
        print(_SEPARATOR)

        for node in graph.as_graph_def().node:
            spec = _classify_node(node.name)
            if spec is None:
                continue
            tensor = graph.get_tensor_by_name(node.name + ':0')
            value = sess.run(tensor, feed_dict={x: img})
            _print_dump(node.name, spec, value)
def main(_):
    """Entry point passed to ``tf.app.run``; the argv argument is ignored."""
    # Dump tensors for the module-level default image.
    run_inference_on_image(image_path)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Dump selected weight/bias/activation tensors from a '
                    'TensorFlow .pb graph.'
    )
    parser.add_argument(
        '--image_path',
        default=image_path,
        help='Image to feed through the graph (default: %(default)s).'
    )
    parser.add_argument(
        '--model_path',
        default=model_path,
        help='Serialized GraphDef (.pb) to load (default: %(default)s).'
    )
    # Kept so existing command lines stay valid; nothing in this script reads it.
    parser.add_argument(
        '--num_top_predictions',
        type=int,
        default=5,
        help='Unused; accepted only for backward compatibility.'
    )
    FLAGS, unparsed = parser.parse_known_args()
    # main() reads image_path and run_inference_on_image() reads model_path as
    # module-level names at call time, so rebinding them here takes effect.
    image_path = FLAGS.image_path
    model_path = FLAGS.model_path
    # Unrecognized arguments are forwarded to tf.app.run's own flag parsing.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment