Created
November 8, 2018 16:13
-
-
Save wil3/38b9ff40c6012a229f711944b1e2e170 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, argparse | |
import tensorflow as tf | |
""" | |
This script converts a checkpoint to a pb file without needing to know
the names of the input and output nodes. This then allows you to use the
TensorFlow tool summarize_graph to identify potential input/output nodes.
Only the graph structure is written; variable values are not embedded
(the graph is not frozen).

Usage:
    python3 checkpoint_to_pb.py checkpoints/-1900000 graphs/fullgraph.pb

For example, to list candidate input/output nodes of the resulting file:
    bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
        --in_graph="fullgraph.pb"

Tested with tensorflow==1.12.0.
You may get an error if the checkpoint was created with a different version.
""" | |
def checkpoint_to_pb(checkpoint_path, pb_file, as_text=True):
    """Write the graph stored in a checkpoint's ``.meta`` file to a pb file.

    The meta graph is imported into a fresh graph. The checkpoint weights are
    then restored into a session, which checks that the checkpoint matches the
    graph. Only the graph *structure* (``GraphDef``) is written. Variable
    values live in the session and are not embedded, so the output is not a
    frozen graph. It is enough, however, for tools such as ``summarize_graph``
    to list candidate input/output nodes.

    Args:
        checkpoint_path: Path to the checkpoint excluding the extension
            (e.g. ``checkpoints/-1900000``); ``checkpoint_path + '.meta'``
            must exist.
        pb_file: Path of the file to write.
        as_text: Forwarded to ``tf.train.write_graph``. ``True`` (the default,
            matching the original behaviour) writes a text-format GraphDef;
            ``False`` writes the binary serialization conventionally used for
            ``.pb`` files.

    Returns:
        The path returned by ``tf.train.write_graph``.

    Raises:
        FileNotFoundError: If the ``.meta`` file does not exist.
    """
    meta_file = checkpoint_path + '.meta'
    if not os.path.isfile(meta_file):
        raise FileNotFoundError(
            "Could not find checkpoint meta file: {}".format(meta_file))

    # Entering the session makes its fresh graph the default graph, so the
    # meta graph is imported into it rather than into any pre-existing graph.
    with tf.Session(graph=tf.Graph()) as sess:
        saver = tf.train.import_meta_graph(meta_file, clear_devices=True)
        # import_meta_graph returns None when the meta graph contains no
        # variables; there is nothing to restore in that case.
        if saver is not None:
            saver.restore(sess, checkpoint_path)

        head, tail = os.path.split(pb_file)
        # A bare file name yields head == ''; name the current directory
        # explicitly instead of handing write_graph an empty logdir.
        written = tf.train.write_graph(sess.graph, head or '.', tail,
                                       as_text=as_text)

    print("PB file saved to {}".format(pb_file))
    return written
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument list to parse, excluding the program name. When None,
            argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint`` and ``pb`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Convert a TensorFlow checkpoint's graph to a pb file.")
    # Positional arguments are always required and already parsed as str,
    # so explicit ``type=str`` / ``default=None`` would have no effect.
    parser.add_argument(
        "checkpoint",
        help="Checkpoint path excluding extension to be converted to a pb file")
    parser.add_argument("pb", help="Output pb file")
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    checkpoint_to_pb(args.checkpoint, args.pb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment