Last active
April 26, 2023 18:37
-
-
Save applenob/977f7627345c4b83149752e2f1c88a50 to your computer and use it in GitHub Desktop.
Load a TensorFlow model from a frozen protobuf (.pb) file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
import tensorflow as tf | |
def get_session(graph=None):
    """Create a new TensorFlow session that grows GPU memory on demand.

    Args:
        graph: the ``tf.Graph`` the session should run. ``None`` (the
            default) keeps the previous behavior: ``tf.Session`` binds to the
            current default graph.

    Returns:
        A new ``tf.Session``. Nothing here closes it; the caller owns it.
    """
    config = tf.ConfigProto()
    # Allocate GPU memory incrementally rather than reserving it all up front.
    config.gpu_options.allow_growth = True
    return tf.Session(graph=graph, config=config)
def _strip_ref_ops(graph_def):
    """Rewrite reference-typed ops in ``graph_def`` in place so it can be imported.

    Importing some frozen graphs fails with
    "Input 0 of node X was passed float from Y:0 incompatible with expected
    float_ref." because ops such as ``RefSwitch`` / ``Assign*`` still expect
    variable references. Each of those ops is swapped for a value-typed op
    computing the same result. Graphs without those ops are left unchanged.
    """
    # Assign-family op -> value-typed op that does not write to a variable.
    value_ops = {"AssignSub": "Sub", "AssignAdd": "Add", "Assign": "Identity"}
    for node in graph_def.node:
        if node.op == "RefSwitch":
            node.op = "Switch"
            # Read inputs whose names contain "moving_" through their
            # "<name>/read" op so Switch receives a value instead of a ref.
            # NOTE(review): assumes such inputs are variables (e.g. batch-norm
            # moving statistics) that have a "/read" identity — confirm per model.
            for index, input_name in enumerate(list(node.input)):
                if "moving_" in input_name:
                    node.input[index] = input_name + "/read"
        elif node.op in value_ops:
            original_op = node.op
            node.op = value_ops[original_op]
            # Drop Assign-only attributes the replacement op would reject.
            if "use_locking" in node.attr:
                del node.attr["use_locking"]
            if original_op == "Assign":
                if "validate_shape" in node.attr:
                    del node.attr["validate_shape"]
                if len(node.input) == 2:
                    # Assign(ref, value): the ref (input 0) is a variable that
                    # may be uninitialized; Identity only needs the value.
                    node.input[0] = node.input[1]
                    del node.input[1]


def load_frozen_graph(frozen_graph_filename, name=""):
    """Load a frozen, serialized GraphDef into a brand-new ``tf.Graph``.

    The reference-op rewrite (see ``_strip_ref_ops``) was originally meant to
    be added only when importing fails with a ``float_ref`` error. It is
    applied unconditionally here.

    Args:
        frozen_graph_filename: path to a binary-serialized GraphDef
            (typically a ``.pb`` file).
        name: prefix passed to ``tf.import_graph_def``. Defaults to ``""`` so
            ops keep their original names (``"input"``, not
            ``"import/input"``). Pass ``"import"`` to get TensorFlow's own
            default prefixing.

    Returns:
        The new ``tf.Graph`` holding the imported ops.
    """
    # Read the protobuf from disk and parse it into a GraphDef.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    _strip_ref_ops(graph_def)

    with tf.Graph().as_default() as graph:
        # Leaving name unset makes TensorFlow prefix every op with "import/",
        # which breaks lookups such as graph.get_operation_by_name("input").
        tf.import_graph_def(graph_def, name=name)
    return graph
def load_graph_session_from_pb(pb_file, print_op=False):
    """Return a ``(graph, session)`` pair for the frozen model in ``pb_file``.

    If ``print_op`` is true, every operation name in the loaded graph is
    printed, one per line. This is handy for finding the input/output op names.
    The session is created via ``get_session()`` while the loaded graph is the
    default graph.
    """
    loaded_graph = load_frozen_graph(pb_file)
    if print_op:
        for operation_name in (op.name for op in loaded_graph.get_operations()):
            print(operation_name)
    with loaded_graph.as_default():
        session = get_session()
    return loaded_graph, session
# Load the model stored in "elmo.pb", then resolve the first output tensor of
# its "input" and "output" ops (looked up in that order).
# NOTE(review): ``input`` shadows the builtin, but predict_func reads these
# module-level names, so they are kept as-is.
graph, sess = load_graph_session_from_pb("elmo.pb")
input, output = (
    graph.get_operation_by_name(op_name).outputs[0]
    for op_name in ("input", "output")
)
def predict_func(sess, one_batch, input_tensor=None, output_tensor=None):
    """Run the graph on one batch and return the fetched output.

    Args:
        sess: session whose ``run`` method evaluates the graph.
        one_batch: value fed into the input tensor. Its required shape/dtype
            depends on the model's input op (TODO confirm for elmo.pb).
        input_tensor: tensor that receives ``one_batch``. Defaults to the
            module-level ``input`` tensor.
        output_tensor: tensor to evaluate. Defaults to the module-level
            ``output`` tensor.

    Returns:
        The result of ``sess.run`` for a one-element fetch list, i.e. a
        one-element list holding the evaluated output.
    """
    # Fall back to the module-level tensors so existing callers keep working.
    if input_tensor is None:
        input_tensor = input
    if output_tensor is None:
        output_tensor = output
    return sess.run([output_tensor], feed_dict={input_tensor: one_batch})
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment