Last active
September 4, 2019 19:40
-
-
Save tonyreina/80763eecdc660e5b358308e9932fa03c to your computer and use it in GitHub Desktop.
Load TensorFlow protobuf
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

import tensorflow as tf

# Command-line interface: which frozen graph to load and which tensors to
# expose as the SavedModel's input/output signature.
parser = argparse.ArgumentParser(
    description="Loads TensorFlow protobuf and converts it to saved model",
    add_help=True, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--filename", required=True,
                    help="the name and path of the binary TensorFlow protobuf (.pb) file")
parser.add_argument("--input_layer_name", default="import/shuffled_queue:0",
                    help="the name of the input layer - Use Netron or TensorBoard to view the graph")
# The help text previously said "input layer" here (copy-paste error).
parser.add_argument("--output_layer_name", default="import/Rank_1/packed:0",
                    help="the name of the output layer - Use Netron or TensorBoard to view the graph")
args = parser.parse_args()
def printTensors(graph):
    """Write the name of every operation in ``graph`` to stdout, one per line.

    Useful for discovering which tensor names to pass on the command line.
    Only ``graph.get_operations()`` and each operation's ``name`` are used.
    """
    operation_names = (operation.name for operation in graph.get_operations())
    for name in operation_names:
        print(name)
def loadProtobuf(filename, protected_nodes=None):
    """
    Load a binary (serialized) TensorFlow GraphDef protobuf file and strip
    training-only nodes from it.

    Args:
        filename: path to the binary protobuf file, read via ``tf.gfile.GFile``.
        protected_nodes: optional list of node names that
            ``tf.graph_util.remove_training_nodes`` must keep. Defaults to
            ``None``, which protects nothing, as the original call did.

    Returns:
        The parsed ``tf.GraphDef`` with training-only nodes removed.
    """
    with tf.gfile.GFile(filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # remove_training_nodes returns a NEW GraphDef instead of mutating its
    # argument; the original discarded that return value, so nothing was
    # actually stripped.
    # NOTE(review): this splices out nodes such as Identity ops; if the chosen
    # input/output tensors are such nodes, pass their names in protected_nodes.
    graph_def = tf.graph_util.remove_training_nodes(
        graph_def, protected_nodes=protected_nodes)
    return graph_def
# Load the frozen graph, look up the requested tensors, and export them as a
# SavedModel with a single input/output signature.
graph_def = loadProtobuf(args.filename)

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)
    # printTensors(graph)  # uncomment to list every op name in the graph

    x = graph.get_tensor_by_name(args.input_layer_name)
    y = graph.get_tensor_by_name(args.output_layer_name)

    # The session is now a context manager so it is closed once the export
    # finishes (the original never closed it). simple_save stays inside the
    # graph's as_default() scope because its saver is built against the
    # default graph.
    with tf.Session(graph=graph) as sess:
        print("Loaded graph {}".format(args.filename))
        tf.saved_model.simple_save(sess,
                                   "saved_model_directory",
                                   inputs={args.input_layer_name: x},
                                   outputs={args.output_layer_name: y})
        print("Saved model to 'saved_model_directory'")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment