Created November 5, 2019 23:30
-
-
Save tonyreina/89a284b7cb6441a4afc3a4bbefd05199 to your computer and use it in GitHub Desktop.
Summarize TensorFlow Graph for Inputs and Outputs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
import sys

# Set the TensorFlow C++ log level BEFORE importing tensorflow: the native
# runtime may read this variable once while it loads, so assigning it after
# the import is not guaranteed to silence messages emitted during loading.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf  # noqa: E402  (deliberately imported after the env var)

# Op types that are almost never genuine model outputs. summarize_graph()
# compares both a node's op type and the last '/'-separated component of its
# name against this list when guessing which dangling nodes are outputs.
unlikely_output_types = ['Const', 'Assign', 'NoOp', 'Placeholder', 'Assert']
def dump_for_tensorboard(graph_def: tf.GraphDef, logdir: str):
    """Write *graph_def* as a TensorBoard event file into *logdir*.

    Args:
        graph_def: Graph definition to visualise.
        logdir: Directory the event file is written to.

    Raises:
        RuntimeError: if the event file cannot be written; the underlying
            exception is chained as ``__cause__``.
    """
    # FileWriter's ``graph_def`` argument is deprecated, so import the
    # definition into a fresh Graph and hand that over instead.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    try:
        print('Writing an event file for the tensorboard...')
        with tf.summary.FileWriter(logdir=logdir, graph=graph) as writer:
            writer.flush()
        print('Done writing an event file.')
    except Exception as err:
        # The original referenced undefined names (``Error``,
        # ``refer_to_faq_msg``) and never formatted ``logdir`` into the
        # message, so failures surfaced as a confusing NameError.
        raise RuntimeError(
            'Cannot write an event file for the tensorboard to directory "{}".'.format(logdir)
        ) from err
def children(op_name: str, graph: tf.Graph):
    """Collect every operation in *graph* that reads an output of *op_name*.

    Returns a set of operations; it is empty when nothing consumes any of the
    named operation's output tensors.
    """
    parent = graph.get_operation_by_name(op_name)
    consumer_ops = set()
    for tensor in parent.outputs:
        consumer_ops.update(tensor.consumers())
    return consumer_ops
def summarize_graph(graph_def):
    """Guess the inputs and outputs of a TensorFlow graph.

    Inputs are all ``Placeholder`` nodes. Candidate outputs are nodes whose
    outputs are consumed by no other op. Nodes are skipped when their op type,
    or the last '/'-separated component of their name, appears in
    ``unlikely_output_types``.

    Args:
        graph_def: GraphDef to analyse. It is imported into a fresh Graph
            with an empty name prefix, so node names are preserved.

    Returns:
        dict with two keys:
          'inputs':  {placeholder name: {'type': dtype name, 'shape': str}};
                     spaces are removed from the shape and '?' dimensions
                     are written as -1.
          'outputs': list of candidate output node names, in graph order.
    """
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')

    placeholders = dict()
    outputs = list()
    for node in graph.as_graph_def().node:
        if node.op == 'Placeholder':
            # A graph saved with default attributes stripped may omit 'shape'.
            # Indexing the protobuf map would then yield an empty
            # TensorShapeProto, which TensorShape reads as a scalar "()".
            # Report an unknown shape instead (presumably the Placeholder op's
            # default -- verify for the TF version in use).
            if 'shape' in node.attr:
                shape = tf.TensorShape(node.attr['shape'].shape)
            else:
                shape = tf.TensorShape(None)
            placeholders[node.name] = {
                'type': tf.DType(node.attr['dtype'].type).name,
                'shape': str(shape).replace(' ', '').replace('?', '-1'),
            }

        # A node nothing consumes is a potential output, unless it is a
        # bookkeeping op (constants, assignments, asserts, ...).
        if not children(node.name, graph):
            short_name = node.name.split('/')[-1]
            if node.op not in unlikely_output_types and short_name not in unlikely_output_types:
                outputs.append(node.name)

    return {'inputs': placeholders, 'outputs': outputs}
def print_summary(summary):
    """Pretty-print a summary dict as produced by summarize_graph().

    Prints each input with its dtype and shape, then each candidate output
    name. The two sections are separated by dashed rules and followed by a
    blank line.
    """
    rule = '------------'

    inputs = summary['inputs']
    print(rule)
    print('{} input(s) detected:'.format(len(inputs)))
    for name, info in inputs.items():
        print("Name: {}, type: {}, shape: {}".format(name, info['type'], info['shape']))

    outputs = summary['outputs']
    print(rule)
    print('{} output(s) detected:'.format(len(outputs)))
    for name in outputs:
        print('Name: %s' % name)
    print('')
def _str_to_bool(value):
    """Parse a command-line boolean such as 'true'/'false'/'yes'/'0'.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is True, so any
    non-empty value would switch the flag on.

    Raises:
        argparse.ArgumentTypeError: for an unrecognised spelling; argparse
            reports it as a usage error.
    """
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Load a SavedModel (SERVING tag), then summarise it or freeze it.

    All node names of the loaded graph are printed first, to help choose
    ``--output``.

    With ``--summary``, the guessed inputs and outputs are printed. Otherwise
    the variables are converted to constants for the comma-separated
    ``--output`` node names, and the result is written to ``--frozen_graph``
    (default ``frozen.pb``) and summarised. In that mode, ``--logs``
    additionally writes a TensorBoard event file to ``./logs``.
    """
    parser = argparse.ArgumentParser(description='Freeze saved model')
    parser.add_argument('--model', type=str, help='Path to TF model folder', required=True)
    parser.add_argument('--output', type=str, help='Output layer name', required=False)
    # nargs='?' with const=True keeps the old "--summary True" form working
    # and also accepts a bare "--summary"; "--summary False" now means False.
    parser.add_argument('--summary', type=_str_to_bool, nargs='?', const=True, default=False,
                        help='Summarize only', required=False)
    parser.add_argument('--logs', type=_str_to_bool, nargs='?', const=True, default=False,
                        help='Dump logs for tensorboard', required=False)
    parser.add_argument('--frozen_graph', type=str, default='frozen.pb',
                        help='Where to write the frozen graph (default: frozen.pb)')
    args = parser.parse_args()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], args.model)
        for node in tf.get_default_graph().as_graph_def().node:
            print(node.name)

        if args.summary:
            print_summary(summarize_graph(sess.graph_def))
            return

        if not args.output:
            # Exit non-zero so scripts notice the missing argument
            # (the message is printed to stderr).
            sys.exit('Please provide output layer name')

        # Freeze the graph: fold variable values into constants, keeping only
        # what the requested output nodes need.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            args.output.split(','))

        with open(args.frozen_graph, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())

        print_summary(summarize_graph(frozen_graph_def))

        if args.logs:
            dump_for_tensorboard(frozen_graph_def, 'logs')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.