-
-
Save moodoki/e37a85fb0258b045c005ca3db9cbc7f6 to your computer and use it in GitHub Desktop.
import os, argparse | |
import tensorflow as tf | |
from tensorflow.python.framework import graph_util | |
# Absolute path of the directory containing this script.
# NOTE(review): this name shadows the builtin ``dir`` and is not referenced by
# the visible code; it is kept only in case external code imports it.
dir = os.path.dirname(os.path.realpath(__file__))
def _split_names(names):
    """Return the non-empty, whitespace-stripped entries of a comma-separated string."""
    return [name.strip() for name in names.split(',') if name.strip()]


def _fix_batch_norm_nodes(graph_def):
    """Rewrite batch-norm update ops in *graph_def* in place before freezing.

    ``RefSwitch`` nodes become ``Switch`` and any input whose name contains
    ``moving_`` is redirected to its ``/read`` tensor; ``AssignSub`` nodes
    become ``Sub`` and lose their ``use_locking`` attribute.

    NOTE(review): presumably needed because these variable-updating ops cannot
    be kept once variables are converted to constants; confirm for the
    TensorFlow version in use.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # range() instead of the Python-2-only xrange(), which raises
            # NameError on Python 3.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']


def freeze_graph(model_folder, output_nodes='y_hat',
                 output_filename='frozen-graph.pb',
                 rename_outputs=None):
    """Freeze the latest checkpoint in *model_folder* into a single GraphDef file.

    Variables are replaced by constants holding their checkpointed values and
    nodes unrelated to the requested outputs are discarded. Uses TensorFlow
    1.x graph-mode APIs (``tf.train``, ``tf.Session``, ``tf.gfile``).

    Args:
        model_folder: Folder that contains the checkpoint files -- not a
            single checkpoint file.
        output_nodes: Comma-separated names of the output nodes to keep.
        output_filename: Path of the serialized frozen GraphDef to write.
        rename_outputs: Optional comma-separated new names, one per entry of
            *output_nodes* and in the same order. An identity node with each
            new name is added and becomes the exported output.

    Raises:
        IOError: If no checkpoint can be found in *model_folder*.
        ValueError: If *rename_outputs* does not list exactly one name per
            output node.
    """
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        # Without this check the failure surfaces as the opaque
        # "'NoneType' object has no attribute 'model_checkpoint_path'".
        raise IOError('No checkpoint found in %r; pass the folder that '
                      'contains the checkpoint files, not a single file.'
                      % (model_folder,))
    input_checkpoint = checkpoint.model_checkpoint_path

    # Devices are cleared so TensorFlow can control placement of the graph
    # when it is loaded on a different machine.
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()

    onames = _split_names(output_nodes)
    # https://stackoverflow.com/a/34399966/4190475
    if rename_outputs is not None:
        nnames = _split_names(rename_outputs)
        if len(nnames) != len(onames):
            # zip() would silently drop the unmatched outputs.
            raise ValueError('rename_outputs has %d names but output_nodes '
                             'has %d' % (len(nnames), len(onames)))
        with graph.as_default():
            for old_name, new_name in zip(onames, nnames):
                tf.identity(graph.get_tensor_by_name(old_name + ':0'),
                            name=new_name)
        onames = nnames

    input_graph_def = graph.as_graph_def()
    _fix_batch_norm_nodes(input_graph_def)

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)
        # Weights no longer need updating in production, so every variable
        # is turned into a constant.
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def,
            onames  # unrelated nodes will be discarded
        )

    with tf.gfile.GFile(output_filename, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
def build_parser():
    """Return the command-line parser for the freeze_graph script."""
    parser = argparse.ArgumentParser(
        description='Prune and freeze weights from checkpoints into production models')
    parser.add_argument("--checkpoint_path",
                        default='ckpt',
                        type=str,
                        help="Folder containing the checkpoint files "
                             "(not a single checkpoint file)")
    parser.add_argument("--output_nodes",
                        default='y_hat',
                        type=str,
                        help="Names of output nodes, comma separated")
    parser.add_argument("--output_graph",
                        default='frozen-graph.pb',
                        type=str,
                        help="Output graph filename")
    parser.add_argument("--rename_outputs",
                        default=None,
                        type=str,
                        help="Rename output nodes for better readability in "
                             "production graph, to be specified in the same "
                             "order as output_nodes")
    return parser


def main(argv=None):
    """Parse *argv* (default: ``sys.argv[1:]``) and run freeze_graph."""
    args = build_parser().parse_args(argv)
    freeze_graph(args.checkpoint_path, args.output_nodes,
                 args.output_graph, args.rename_outputs)


if __name__ == '__main__':
    main()
@szm2015: do you mean that you cannot do it for a pretrained model that is not from slim?
What do we put in `--output_nodes`? Please help!
Hi, thanks for sharing your implementation.
I use this code to convert the checkpoint model files (.meta, .data and .index) produced by running DCGAN from https://github.com/carpedm20/DCGAN-tensorflow
into one graph.pb file.
Please help me set `output_nodes`. When I run the code, it shows the error below.
WARNING:tensorflow:From C:\Users\Baran\Desktop\exportgraph.py:55: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.compat.v1.graph_util.convert_variables_to_constants
WARNING:tensorflow:From C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py:245: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
Traceback (most recent call last):
File "C:\Users\Baran\Desktop\exportgraph.py", line 83, in
freeze_graph(args.checkpoint_path, args.output_nodes, args.output_graph, args.rename_outputs)
File "C:\Users\Baran\Desktop\exportgraph.py", line 55, in freeze_graph
onames # unrelated nodes will be discarded
File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 245, in convert_variables_to_constants
inference_graph = extract_sub_graph(input_graph_def, output_node_names)
File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\util\deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 181, in extract_sub_graph
_assert_nodes_are_present(name_to_node, dest_nodes)
File "C:\Users\Baran\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\graph_util_impl.py", line 137, in _assert_nodes_are_present
assert d in name_to_node, "%s is not in graph" % d
AssertionError: y_hat is not in graph
You can use `summarize_graph` to get the input / output node names:
https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#inspecting-graphs
AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'
Wow, somehow I didn't get any notifications on the comments made on this code snippet until the last comment by @charmby. Thanks, everyone that helped answer some queries by others :).
TensorFlow has evolved quite significantly since I shared this implementation, and using tf.SavedModel might be a far easier approach for new code.
For working with old code/models, this script might still be useful. So let me answer some queries if they were not taken care of by others :)
@lihan, @selcouthlyBlue I did this mainly as it saves everything in a single file and makes it much easier for distributing. There should be some associated inference performance benefits as all model parameters are converted to constants instead of variables. The downside is that this saved model is no longer fine-tunable.
@charmby That error is due to the script not being able to find the model files in the path specified. The model path parameter should be the folder containing all the checkpoint files and not the .chkpt file itself.
@elham1992 This error is due to the graph not having a `y_hat` node. You may want to check the output node name of the model that you are using.
@elham1992, @soufianesabiri, @achalshah20 Apart from using `summarize_graph` as mentioned by others, if you have access to the code that was used to build the pre-trained graph, look at it and get it to print the output variable name. All TensorFlow graphs should work; slim isn't necessary. As a good practice, you may want to name important nodes using the `name=''` parameter when creating the graph. If this parameter isn't specified, TensorFlow has a default naming convention that appends a running count to each node of the same type, e.g. `conv_0`, `conv_1`, etc. There's also this tool that lets you explore graphs in many different formats: https://github.com/lutzroeder/netron
@szm2015, can you let me know how this approach differs from the normal `bazel-bin/tensorflow/python/tools/freeze_graph`? Also, do you know if we can hack the weights in a `.pb` file and save them back to the frozen model? I can't seem to make `tf.Assign()` work on these `const` nodes. Thanks