
@moodoki
Last active February 14, 2023 05:58
Freeze and export TensorFlow graph from checkpoint files
```python
import argparse

# Written against the TensorFlow 1.x API (tf.Session, tf.gfile, tf.train.import_meta_graph)
import tensorflow as tf
from tensorflow.python.framework import graph_util


def freeze_graph(model_folder, output_nodes='y_hat',
                 output_filename='frozen-graph.pb',
                 rename_outputs=None):
    # Load checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None:
        # get_checkpoint_state returns None when model_folder has no 'checkpoint' file
        raise ValueError("No checkpoint found in %r; pass the folder containing "
                         "the checkpoint files, not a .ckpt file" % model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    output_graph = output_filename

    # Devices should be cleared to allow TensorFlow to control placement of
    # the graph when loading on different machines
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    graph = tf.get_default_graph()

    onames = output_nodes.split(',')

    # https://stackoverflow.com/a/34399966/4190475
    if rename_outputs is not None:
        nnames = rename_outputs.split(',')
        with graph.as_default():
            for o, n in zip(onames, nnames):
                # Add an identity op that exposes output o under the new name n
                tf.identity(graph.get_tensor_by_name(o + ':0'), name=n)
        onames = nnames

    input_graph_def = graph.as_graph_def()

    # Fix batch norm nodes
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)

        # In production, graph weights no longer need to be updated;
        # graph_util provides a utility to change all variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def,
            onames  # unrelated nodes will be discarded
        )

        # Serialize and write to file
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Prune and freeze weights from checkpoints into production models')
    parser.add_argument("--checkpoint_path",
                        default='ckpt',
                        type=str, help="Path to the folder containing the checkpoint files")
    parser.add_argument("--output_nodes",
                        default='y_hat',
                        type=str, help="Names of output nodes, comma separated")
    parser.add_argument("--output_graph",
                        default='frozen-graph.pb',
                        type=str, help="Output graph filename")
    parser.add_argument("--rename_outputs",
                        default=None,
                        type=str, help="Rename output nodes for better "
                                       "readability in the production graph, specified in "
                                       "the same order as output_nodes")
    args = parser.parse_args()

    freeze_graph(args.checkpoint_path, args.output_nodes, args.output_graph, args.rename_outputs)
```
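Two easy mistakes with --output_nodes and --rename_outputs: zip() stops at the shorter list, so a rename list with fewer names than output nodes silently drops outputs, and because freeze_graph appends ':0' itself, passing a tensor name such as y_hat:0 turns into y_hat:0:0 in the renaming step. The sketch below is a hypothetical stdlib-only helper (pair_output_names is not part of the gist) that normalizes the comma-separated values and rejects mismatched lists before any TensorFlow work starts; the node names in the example are made up.

```python
def pair_output_names(output_nodes, rename_outputs=None):
    """Split the comma-separated CLI values into (node, new_name) pairs.

    Hypothetical helper, not part of the gist. It strips whitespace, drops a
    trailing ':0' (freeze_graph adds ':0' itself) and refuses mismatched
    lists instead of letting zip() truncate them silently.
    """
    def split(value):
        names = [n.strip() for n in value.split(',') if n.strip()]
        # 'y_hat:0' -> 'y_hat'; other output indices are left untouched
        return [n[:-2] if n.endswith(':0') else n for n in names]

    onames = split(output_nodes)
    if rename_outputs is None:
        # No renaming requested: every node keeps its own name
        return [(o, o) for o in onames]
    nnames = split(rename_outputs)
    if len(nnames) != len(onames):
        raise ValueError('got %d output nodes but %d new names'
                         % (len(onames), len(nnames)))
    return list(zip(onames, nnames))


print(pair_output_names('net/logits:0, net/probs', 'logits,probs'))
# -> [('net/logits', 'logits'), ('net/probs', 'probs')]
```

The resulting pairs could then be joined back into the two comma-separated strings that freeze_graph expects.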
@szm-R

szm-R commented May 17, 2018

Same question as @achalshah20. One way I know to do this is summarize_graph, but that requires the .pb model (or its text version, .pbtxt) to start with, which is exactly what we're trying to create here. So how else can we find the node names?

@szm-R

szm-R commented May 17, 2018

Hi, I'm writing this in case it helps someone as confused as I was! To find the input/output nodes, you can summarize the graph by running the following commands from the TensorFlow source root:
```
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=my_graph.pb
```

As you can see, you need a .pb file here. This .pb contains only the model's architecture (not the frozen model) and can be generated from the model definition with:
```
python3 export_inference_graph.py --model_name=mobilenet_v1 --output_file=unfrozen_graph.pb
```

The drawback is that this script, located in tensorflow-models/research/slim, only works with a handful of predefined models, which are listed in tensorflow-models/research/slim/nets/nets_factory.py. Once you have determined the output layer, you can freeze the model using tensorflow/python/tools/freeze_graph.py.

Hope it helps!

@amirjamez

@szm2015, how does this approach differ from the standard bazel-bin/tensorflow/python/tools/freeze_graph? Also, do you know whether we can modify the weights in a .pb file and save them back to the frozen model? I can't seem to get tf.Assign() to work on these const nodes. Thanks

@estelleaf

@szm2015: do you mean this doesn't work for pretrained models that aren't in slim?

@soufianesabiri

What do we put in --output_nodes? Please help!

@elham1992

Hi, thanks for sharing your implementation.
I'm using this code to convert the checkpoint model files (.meta, .data and .index) produced by running DCGAN from https://github.com/carpedm20/DCGAN-tensorflow into a single graph.pb file.
Please help me set output_nodes. When I run the code, it shows the error below.

```
WARNING:tensorflow:From C:\Users\Baran\Desktop\exportgraph.py:55: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
WARNING:tensorflow:From C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py:245: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
Traceback (most recent call last):
  File "C:\Users\Baran\Desktop\exportgraph.py", line 83, in <module>
    freeze_graph(args.checkpoint_path, args.output_nodes, args.output_graph, args.rename_outputs)
  File "C:\Users\Baran\Desktop\exportgraph.py", line 55, in freeze_graph
    onames # unrelated nodes will be discarded
  File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 245, in convert_variables_to_constants
    inference_graph = extract_sub_graph(input_graph_def, output_node_names)
  File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 181, in extract_sub_graph
    _assert_nodes_are_present(name_to_node, dest_nodes)
  File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 137, in _assert_nodes_are_present
    assert d in name_to_node, "%s is not in graph" % d
AssertionError: y_hat is not in graph
```

@mrgloom

mrgloom commented May 24, 2019

@charmby

charmby commented Jun 12, 2020

AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'

@moodoki (gist author)

moodoki commented Jun 16, 2020

Wow, somehow I didn't get notifications for comments on this snippet until the latest one by @charmby. Thanks to everyone who helped answer other people's questions :).

TensorFlow has evolved significantly since I shared this implementation, and using the SavedModel format (tf.saved_model) might be a far easier approach for new code.

For old code and models, this script might still be useful, so let me answer the questions that others haven't already covered :)

@lihan, @selcouthlyBlue I did this mainly because it saves everything in a single file, which makes distribution much easier. There may also be some inference performance benefits, since all model parameters are converted to constants instead of variables. The downside is that the frozen model can no longer be fine-tuned.

@charmby That error occurs because the script can't find the checkpoint files at the specified path: tf.train.get_checkpoint_state returns None when the folder has no checkpoint file. The --checkpoint_path parameter should be the folder containing all the checkpoint files, not the .ckpt file itself.
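To catch this case with a readable message before TensorFlow is involved, here is a stdlib-only sketch of the same lookup. It assumes the usual TF1 layout, where a small text file literally named `checkpoint` sits next to the .meta/.index/.data files and contains a line like `model_checkpoint_path: "model.ckpt-1000"`. find_latest_checkpoint is a hypothetical helper and a simplified stand-in for tf.train.get_checkpoint_state, not TensorFlow's own parser.

```python
import os
import re


def find_latest_checkpoint(model_folder):
    """Return the checkpoint prefix recorded in <model_folder>/checkpoint.

    Hypothetical, simplified stand-in for tf.train.get_checkpoint_state:
    it only reads the model_checkpoint_path line of the text state file.
    """
    if not os.path.isdir(model_folder):
        raise ValueError('%r is not a folder; pass the directory that holds '
                         'the checkpoint files, not a .ckpt file' % model_folder)
    state_file = os.path.join(model_folder, 'checkpoint')
    if not os.path.isfile(state_file):
        raise ValueError('no "checkpoint" file in %r' % model_folder)
    with open(state_file) as f:
        # The state file also lists all_model_checkpoint_paths; we want the latest
        match = re.search(r'^model_checkpoint_path:\s*"([^"]+)"', f.read(), re.M)
    if match is None:
        raise ValueError('model_checkpoint_path missing from %r' % state_file)
    path = match.group(1)
    # A relative prefix is interpreted relative to the checkpoint folder
    return path if os.path.isabs(path) else os.path.join(model_folder, path)
```

The returned prefix is what the script then extends with '.meta' before calling tf.train.import_meta_graph.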

@elham1992 This error occurs because the graph has no node named y_hat, which is the default value of --output_nodes. Check the output node name of the model you are using and pass it via --output_nodes.
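The assertion only reports that a name is missing. A small sketch that checks the requested names first and suggests near misses (for example, a node that lives under a scope prefix) could look like the following. check_output_nodes is a hypothetical helper; node_names is assumed to be the list of node names in the imported graph, e.g. `[n.name for n in tf.get_default_graph().as_graph_def().node]` after tf.train.import_meta_graph. The similarity matching uses the standard library's difflib.

```python
import difflib


def check_output_nodes(requested, node_names):
    """Raise ValueError, with close-match hints, for requested nodes not in the graph.

    Hypothetical helper: node_names is assumed to hold every node name of
    the imported graph. Mirrors the 'y_hat is not in graph' assertion.
    """
    known = set(node_names)
    problems = []
    for name in requested:
        if name not in known:
            # Up to three similar names, e.g. the same node under a scope prefix
            hints = difflib.get_close_matches(name, node_names, n=3, cutoff=0.5)
            problems.append('%s is not in graph (closest: %s)'
                            % (name, ', '.join(hints) or 'none'))
    if problems:
        raise ValueError('; '.join(problems))


try:
    check_output_nodes(['y_hat'], ['x', 'model/y_hat'])
except ValueError as err:
    print(err)  # y_hat is not in graph (closest: model/y_hat)
```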

@elham1992, @soufianesabiri, @achalshah20 Apart from using summarize_graph as mentioned by others, if you have access to the code that was used to build the pre-trained graph, look at it and have it print the name of the output tensor. Any TensorFlow graph should work; slim isn't necessary. As a good practice, name important nodes with the name='' parameter when building the graph. If this parameter isn't specified, TensorFlow uses a default naming convention that appends a running count to each additional node of the same type, e.g. conv, conv_1, conv_2, etc. There's also Netron, a tool that lets you explore graphs in many different formats: https://github.com/lutzroeder/netron
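If neither summarize_graph nor the model code is at hand, one rough heuristic is to list the nodes that no other node consumes, since an inference output is usually a sink of the graph. The sketch below is a hypothetical helper working on a plain dict; from an imported graph it could be built as `{n.name: list(n.input) for n in graph_def.node}`, where input strings follow GraphDef's form: 'name', 'name:1' for a later output, '^name' for a control dependency. In a training graph the sinks also include optimizer, init and save/restore ops, so treat the result as a shortlist to inspect, not an answer.

```python
def candidate_output_nodes(nodes):
    """Return the names of nodes that no other node consumes (graph sinks).

    Hypothetical heuristic: nodes maps node name -> list of input strings in
    GraphDef form ('name', 'name:1', '^name').
    """
    consumed = set()
    for inputs in nodes.values():
        for inp in inputs:
            # Drop the control-dependency marker and the output index
            consumed.add(inp.lstrip('^').split(':')[0])
    return sorted(name for name in nodes if name not in consumed)


graph = {
    'x': [],
    'dense/kernel': [],
    'dense/MatMul': ['x', 'dense/kernel'],
    'y_hat': ['dense/MatMul'],
    'init': ['^dense/kernel'],
}
print(candidate_output_nodes(graph))  # ['init', 'y_hat']
```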
