Created
January 7, 2019 08:00
-
-
Save applenob/f73971a0d77d156e0a5e8df7c56519bc to your computer and use it in GitHub Desktop.
freeze tensorflow checkpoint file to pb format.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
freeze tensorflow checkpoint file to pb format. | |
""" | |
import argparse | |
import os | |
import tensorflow as tf | |
import logging | |
from tensorflow.python.framework import graph_util | |
def freeze_graph(model_folder, output_node_names, pb_name="frozen_model.pb"):
    """Freeze the latest checkpoint found in ``model_folder`` into a pb file.

    The meta graph stored next to the checkpoint is imported into the default
    graph, the checkpointed weights are restored, every variable needed by
    ``output_node_names`` is converted to a constant, and the resulting
    GraphDef is serialized next to the checkpoint files.

    Args:
        model_folder: Directory containing the TensorFlow ``checkpoint`` state
            file and the checkpoint it points to.
        output_node_names: List of node names (without the ``:0`` tensor
            suffix) that must be kept in the frozen graph.
        pb_name: File name of the frozen graph; it is written into the
            directory of the checkpoint.

    Raises:
        ValueError: If no checkpoint state can be found in ``model_folder``.
    """
    # Retrieve checkpoint path and pb path.
    # get_checkpoint_state() returns None when the folder holds no checkpoint;
    # fail with an explicit message instead of an AttributeError on None.
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError(
            "No checkpoint found in model folder: {!r}".format(model_folder))
    checkpoint_path = checkpoint.model_checkpoint_path
    absolute_model_folder = os.path.dirname(checkpoint_path)
    pb_path = os.path.join(absolute_model_folder, pb_name)

    # Import the meta graph and retrieve a Saver,
    # stripping device information from the graph.
    saver = tf.train.import_meta_graph(checkpoint_path + '.meta',
                                       clear_devices=True)

    # Retrieve the pb graph definition.
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # Start a session and restore the graph weights.
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)

        # Export variables to constants.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_node_names
        )

    # Serialize and dump the output graph to the filesystem.
    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    logging.info("%d ops in the final graph.", len(output_graph_def.node))
if __name__ == '__main__':
    # Command-line entry point: parse the model folder and the output node
    # names, then freeze the checkpoint found in that folder.

    # The root logger defaults to WARNING, which would silently drop the
    # INFO-level summary emitted by freeze_graph(). basicConfig() is a no-op
    # if a handler is already installed, so the level is also set explicitly.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    # Required: without it args.model_folder would be None and the failure
    # would only surface deep inside TensorFlow.
    parser.add_argument("--model_folder", type=str, required=True,
                        help="Model folder to export.")
    parser.add_argument("--output_names", type=str,
                        default="output/probs,attention_word_score,attention_sentence_score",
                        help="Output node names, separated by commas.")
    args = parser.parse_args()

    # Strip spaces and drop empty entries so that a trailing or doubled comma
    # does not produce an empty node name.
    output_names = [
        name for name in args.output_names.replace(" ", "").split(",") if name
    ]
    freeze_graph(args.model_folder, output_names)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment