Skip to content

Instantly share code, notes, and snippets.

@applenob
Created January 7, 2019 08:00
Show Gist options
  • Save applenob/f73971a0d77d156e0a5e8df7c56519bc to your computer and use it in GitHub Desktop.
Save applenob/f73971a0d77d156e0a5e8df7c56519bc to your computer and use it in GitHub Desktop.
Freeze a TensorFlow checkpoint file to .pb format.
"""
Freeze a TensorFlow checkpoint file to .pb format.
"""
import argparse
import os
import tensorflow as tf
import logging
from tensorflow.python.framework import graph_util
def freeze_graph(model_folder, output_node_names, pb_name="frozen_model.pb"):
    """Freeze the checkpoint found in ``model_folder`` into a single pb file.

    The checkpoint's meta graph is imported into the default graph, its
    variables are restored in a session and converted to constants, and the
    resulting graph definition is serialized to ``pb_name`` in the same
    directory as the checkpoint files.

    Args:
        model_folder: Directory passed to ``tf.train.get_checkpoint_state``
            to locate the checkpoint.
        output_node_names: Names of the graph nodes to keep as outputs;
            forwarded to ``graph_util.convert_variables_to_constants``.
        pb_name: File name of the frozen graph written next to the
            checkpoint.

    Raises:
        ValueError: If no checkpoint state can be found in ``model_folder``.
    """
    # Retrieve checkpoint path and pb path
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    # get_checkpoint_state yields None when the folder holds no checkpoint;
    # fail with a clear message instead of an AttributeError on None.
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in model folder: {!r}".format(model_folder))
    checkpoint_path = checkpoint.model_checkpoint_path
    # The pb file is written alongside the checkpoint files.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    pb_path = os.path.join(checkpoint_dir, pb_name)

    # Import the meta graph and retrieve a Saver,
    # stripping device information from the graph
    saver = tf.train.import_meta_graph(checkpoint_path + '.meta',
                                       clear_devices=True)

    # Retrieve the pb graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # Start a session and restore the graph weights
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        # Export variables to constant
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names
        )

    # Serialize and dump the output graph to the filesystem
    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    logging.info("%d ops in the final graph.", len(output_graph_def.node))
if __name__ == '__main__':
    # The root logger defaults to WARNING, which would hide the op-count
    # message that freeze_graph reports at INFO level.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(
        description="Freeze a TensorFlow checkpoint file to pb format.")
    # Required: without it, None would be handed to freeze_graph.
    parser.add_argument("--model_folder", type=str, required=True,
                        help="Model folder to export.")
    parser.add_argument("--output_names", type=str,
                        default="output/probs,attention_word_score,attention_sentence_score",
                        help="Output node names, separated by commas.")
    args = parser.parse_args()

    # Strip spaces and drop empty entries produced by stray commas
    # (e.g. a trailing comma) so only real node names are passed on.
    output_names = [
        name for name in args.output_names.replace(" ", "").split(",") if name
    ]
    if not output_names:
        parser.error("--output_names must contain at least one node name.")

    freeze_graph(args.model_folder, output_names)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment