# Created April 9, 2019 06:07
# Save gist vishal-keshav/94217d4cc5fbd8d1a0790434161b4dc4 to your computer and use it in GitHub Desktop.
# Infer the output of a TensorFlow .pb (serialized GraphDef) model.
# This file contains hidden or bidirectional Unicode text that may be interpreted
# or compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
import numpy as np
import tensorflow as tf
def preprocess(img):
    """Prepare *img* for the model.

    Currently a placeholder: the input is handed back unchanged. Put any
    model-specific input preparation here.
    """
    processed = img
    return processed
def _tensor_name(name):
    """Return *name* in ``<op_name>:<output_index>`` form, appending ``:0`` if absent.

    ``Graph.get_tensor_by_name`` only accepts full tensor names; a bare op
    name such as ``'input'`` is rejected.
    """
    return name if ":" in name else name + ":0"


def inference_from_pb(img, pb_file="model.pb",
                      inputs=("input",), outputs=("output",)):
    """Run *img* through the graph stored in *pb_file* and return the first output.

    Args:
        img: Value fed to the first input tensor after ``preprocess``.
            NOTE(review): presumably a NumPy array matching the model's input
            shape/dtype — confirm against the model.
        pb_file: Path to a file containing a serialized ``GraphDef`` protobuf.
        inputs: Input names; only ``inputs[0]`` is fed. May be a bare op name
            (``'input'``) or a full tensor name (``'input:0'``).
        outputs: Output names; only ``outputs[0]`` is fetched. Same naming
            rules as *inputs*.

    Returns:
        The value of ``outputs[0]`` as returned by ``Session.run``.

    Raises:
        ValueError: if *inputs* or *outputs* is empty.
        Lookup errors from ``get_tensor_by_name`` propagate if a name is not
        present in the graph.
    """
    if not inputs or not outputs:
        raise ValueError("inputs and outputs must each name at least one tensor")

    # tf.compat.v1 keeps this TF1-style graph code working on TensorFlow 2.x;
    # older TF1 builds without tf.compat.v1 fall back to the top-level module.
    tf1 = getattr(getattr(tf, "compat", None), "v1", tf)

    img = preprocess(img)

    with tf1.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf1.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name="" keeps the node names exactly as stored in the .pb; the default
        # ("import") would prefix every node with "import/" and the caller's
        # names would not be found.
        tf1.import_graph_def(graph_def, name="")

    graph_input = graph.get_tensor_by_name(_tensor_name(inputs[0]))
    graph_output = graph.get_tensor_by_name(_tensor_name(outputs[0]))

    with tf1.Session(graph=graph) as sess:
        return sess.run(graph_output, feed_dict={graph_input: img})
def main():
    """Demo: run a small sample array through ``model.pb`` and print the result."""
    # NOTE(review): the sample's shape/dtype must match the model's input;
    # this 2x2 array is only a placeholder.
    img = np.array([[1, 2], [3, 4]])
    # Keyword arguments keep the call unambiguous regardless of parameter order.
    out = inference_from_pb(img=img, pb_file="model.pb")
    print(out)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.