Skip to content

Instantly share code, notes, and snippets.

@wil3
Created November 8, 2018 16:13
Show Gist options
  • Save wil3/38b9ff40c6012a229f711944b1e2e170 to your computer and use it in GitHub Desktop.
Save wil3/38b9ff40c6012a229f711944b1e2e170 to your computer and use it in GitHub Desktop.
import os, argparse
import tensorflow as tf
"""
This script converts a checkpoint to a pb file without needing to know
the names of the input and output nodes. You can then use the
TensorFlow tool summarize_graph to identify potential input/output nodes.
Only the graph structure is written (as a text-format GraphDef by default);
variable values are not embedded, i.e. the graph is not frozen.
Usage:
python3 checkpoint_to_pb.py checkpoints/-1900000 graphs/fullgraph.pb
Then inspect the result, for example:
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph="graphs/fullgraph.pb"
Tested with tensorflow==1.12.0
You may get an error if the checkpoint was created with a different version.
"""
def checkpoint_to_pb(checkpoint_path, pb_file, *, as_text=True):
    """Write the full graph stored in a checkpoint's meta graph to a file.

    The meta graph is imported into a fresh graph with device placements
    cleared. If it defines any variables, they are restored into a session,
    which fails early when the checkpoint data does not match the graph.
    The resulting graph definition is then written to ``pb_file``.

    Only the graph structure is written. Variable values are not embedded
    (the graph is not frozen).

    Args:
        checkpoint_path: Path to the checkpoint excluding the extension,
            e.g. ``checkpoints/-1900000``. The file
            ``checkpoint_path + '.meta'`` must exist.
        pb_file: Path to the output file. A bare file name is written to
            the current directory.
        as_text: Forwarded to ``tf.train.write_graph``.
            - ``True`` (the default, matching the original behaviour)
              writes a human-readable text protobuf.
            - ``False`` writes a binary ``GraphDef``, which is what the
              ``.pb`` extension conventionally denotes.

    Returns:
        The output path as returned by ``tf.train.write_graph``.

    Raises:
        FileNotFoundError: If the ``.meta`` file does not exist.
    """
    meta_file = checkpoint_path + '.meta'
    if not os.path.isfile(meta_file):
        raise FileNotFoundError(
            "Could not find checkpoint meta file: {}".format(meta_file))

    # A fresh Graph keeps anything already in the default graph out of the
    # exported GraphDef.
    with tf.Session(graph=tf.Graph()) as sess:
        # clear_devices=True drops device placements recorded when the
        # checkpoint was saved.
        saver = tf.train.import_meta_graph(meta_file, clear_devices=True)
        # import_meta_graph returns None when the meta graph has no
        # variables; there is then nothing to restore.
        if saver is not None:
            # Not needed to write the structure, but it validates that the
            # checkpoint data loads into this graph.
            saver.restore(sess, checkpoint_path)

        # os.path.split gives '' as the directory for a bare file name;
        # pass the current directory instead of an empty log dir.
        out_dir, out_name = os.path.split(pb_file)
        written = tf.train.write_graph(
            sess.graph, out_dir or os.curdir, out_name, as_text=as_text)

    print("PB file saved to {}".format(pb_file))
    return written
if __name__ == '__main__':
    # Command-line entry point: two required positional arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "checkpoint",
        type=str,
        default=None,
        help="Checkpoint path excluding extension to be converted to a pb file",
    )
    cli.add_argument(
        "pb",
        type=str,
        default=None,
        help="Output pb file",
    )
    options = cli.parse_args()
    source_checkpoint, target_pb = options.checkpoint, options.pb
    checkpoint_to_pb(source_checkpoint, target_pb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment