
@applenob
Last active January 7, 2019 08:25
Save a TensorFlow model to a .pb file.
# coding=utf-8
"""Save a TensorFlow model to a .pb file (TensorFlow 1.x graph-mode API)."""
import tensorflow as tf
from tensorflow.python.framework import graph_util

# Build the model, then train it or load weights from somewhere else.
# `sess` must be the tf.Session that holds the trained variable values.
# ...

graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
print("Node num before freeze: ", len(input_graph_def.node))

# Use TF's built-in helper to replace every variable with a constant
# holding its current value in `sess`. Only the listed nodes, plus the
# nodes they depend on, are kept in the frozen graph.
output_graph_def = graph_util.convert_variables_to_constants(
    sess,
    input_graph_def,
    ["input", "output"]  # names of the model's input and output ops
)
print("Node num after freeze: ", len(output_graph_def.node))

# Save the frozen graph to a pb file.
pb_file = "pbfile.pb"
with tf.gfile.GFile(pb_file, "wb") as f:
    f.write(output_graph_def.SerializeToString())
print(f"Saved to {pb_file}")
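Why is the node count printed after freezing smaller than before? The freezing helper does two things: it prunes the graph down to the nodes that the listed outputs depend on, and it swaps each variable for a constant holding its current value. The toy sketch below models both steps in pure Python. The dict-based graph format, the `freeze` function, and node names such as `w/Assign` and `train_op` are hypothetical illustrations, not TensorFlow's API; real graphs are `GraphDef` protobufs.

```python
"""Toy model of graph freezing (pure Python, no TensorFlow).

Hypothetical format: each node is a dict with an "op" string and a list
of "inputs" (names of other nodes). Real TF graphs are GraphDef protobufs.
"""
from typing import Any, Dict, List

Graph = Dict[str, Dict[str, Any]]


def freeze(graph: Graph, values: Dict[str, float],
           output_names: List[str]) -> Graph:
    """Keep only nodes reachable (via inputs) from output_names and
    replace every kept Variable node with a Const holding its value."""
    # 1. Walk backwards from the requested outputs to collect the subgraph.
    keep = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        stack.extend(graph[name]["inputs"])

    # 2. Copy the kept nodes, turning each variable into a constant.
    frozen: Graph = {}
    for name in keep:
        node = graph[name]
        if node["op"] == "Variable":
            frozen[name] = {"op": "Const", "inputs": [], "value": values[name]}
        else:
            frozen[name] = {"op": node["op"], "inputs": list(node["inputs"])}
    return frozen


# A tiny "trained" model: output = input x w, plus training-only nodes.
graph: Graph = {
    "input": {"op": "Placeholder", "inputs": []},
    "w": {"op": "Variable", "inputs": []},
    "w/init": {"op": "Const", "inputs": []},
    "w/Assign": {"op": "Assign", "inputs": ["w", "w/init"]},
    "w/read": {"op": "Identity", "inputs": ["w"]},
    "output": {"op": "MatMul", "inputs": ["input", "w/read"]},
    "label": {"op": "Placeholder", "inputs": []},
    "loss": {"op": "Sub", "inputs": ["output", "label"]},
    "train_op": {"op": "ApplyGradientDescent", "inputs": ["w", "loss"]},
}

frozen = freeze(graph, {"w": 0.5}, ["input", "output"])
print(len(graph), "->", len(frozen))  # prints "9 -> 4"
```

The initializer, loss, and optimizer nodes disappear because `output` does not depend on them, and `w` now carries its value inline. This is why a frozen .pb can be loaded for inference without restoring a checkpoint.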