Skip to content

Instantly share code, notes, and snippets.

@haochunchang
Forked from sunsided/dump_operations.py
Last active November 23, 2018 02:55
Show Gist options
  • Save haochunchang/f251deec78195e700865358a629e7798 to your computer and use it in GitHub Desktop.
Save haochunchang/f251deec78195e700865358a629e7798 to your computer and use it in GitHub Desktop.
Listing operations in frozen .pb TensorFlow graphs
# -*- coding: utf-8 -*-
import argparse, os, sys
import tensorflow as tf
def _exit_with_error(message, status=2):
    """Print *message* to stderr (newline-terminated) and exit with *status*.

    Replaces ``parser.exit(...)`` so that the loader does not depend on a
    ``parser`` global that only exists when this file is run as a script.

    Raises:
        SystemExit: always, carrying *status*.
    """
    print(message, file=sys.stderr)
    sys.exit(status)


def dump_graph_operations(filename):
    """Load a frozen (binary) GraphDef from *filename* and print a report.

    Progress messages are written to stderr. The report is written to
    stdout and contains:
      * every operation with its type and output count,
      * the source operations (those without inputs),
      * the inputs of every operation that has any,
      * every output tensor with its shape and dtype.

    Args:
        filename: Path of the serialized ``GraphDef`` protobuf to read.

    Raises:
        SystemExit: with status 2 if the file cannot be read or parsed, or
            if the parsed graph cannot be imported.
    """
    # Prefer the TF1-compatible namespace when it exists (TF >= 1.13 and
    # TF 2.x, where gfile / GraphDef / import_graph_def are no longer on the
    # top-level module). Older TF 1.x releases expose these names on ``tf``.
    tf_v1 = getattr(getattr(tf, 'compat', None), 'v1', tf)

    print('Loading graph definition ...', file=sys.stderr)
    try:
        with tf_v1.gfile.GFile(filename, "rb") as f:
            graph_def = tf_v1.GraphDef()
            graph_def.ParseFromString(f.read())
    except Exception as e:  # I/O and protobuf decode errors alike
        _exit_with_error('Error loading the graph definition: {}'.format(str(e)))

    print('Importing graph ...', file=sys.stderr)
    try:
        with tf_v1.Graph().as_default() as graph:  # type: tf.Graph
            # name='' keeps the original op names instead of prefixing them
            # with the default "import/" scope. All other arguments are left
            # at their defaults (passing the deprecated op_dict explicitly
            # only triggers a deprecation warning).
            tf_v1.import_graph_def(graph_def, name='')
    except Exception as e:
        _exit_with_error('Error importing the graph: {}'.format(str(e)))

    print()
    print('Operations:')
    ops = graph.get_operations()
    print_all_ops(ops)
    print_sources(ops)
    print_ops_inputs(ops)
    print_all_tensors(ops)
def print_all_ops(ops):
    """Write one line per operation: its type, quoted name and output count.

    The listing is followed by a single empty line. No header is printed;
    the caller is expected to emit one.
    """
    template = '- {0:20s} "{1}" ({2} outputs)'
    for operation in ops:
        output_count = len(operation.outputs)
        print(template.format(operation.type, operation.name, output_count))
    print()
def print_sources(ops):
    """List the operations that have no inputs, under a header line.

    A single empty line is printed after the list.
    """
    print('Sources (operations without inputs):')
    input_free = (candidate for candidate in ops if len(candidate.inputs) == 0)
    for source in input_free:
        print('- {0}'.format(source.name))
    print()
def print_ops_inputs(ops):
    """For each operation that has inputs, print its name and its input names.

    Each such operation produces two lines: its (padded) name, then a
    comma-separated list of its input tensor names. Operations without
    inputs are skipped. A single empty line ends the section.
    """
    print('Operation inputs:')
    consumers = (op for op in ops if len(op.inputs) != 0)
    for consumer in consumers:
        print('- {0:20}'.format(consumer.name))
        input_names = [tensor.name for tensor in consumer.inputs]
        print(' {0}'.format(', '.join(input_names)))
    print()
def print_all_tensors(ops):
    """Print every output tensor of every operation, in operation order.

    Each line shows the tensor's shape (via ``str``), its dtype name and its
    quoted tensor name. A single empty line ends the section.
    """
    print('Tensors:')
    produced = (tensor for op in ops for tensor in op.outputs)
    for tensor in produced:
        shape_text = str(tensor.shape)
        print('- {0:20} {1:10} "{2}"'.format(shape_text, tensor.dtype.name, tensor.name))
    print()
if __name__ == "__main__":
    # Command-line entry point: take one positional path to a frozen graph
    # and print its operations/tensors report.
    parser = argparse.ArgumentParser(
        description='List the operations and tensors of a frozen .pb TensorFlow graph.')
    parser.add_argument('file', type=str, help='The file name of the frozen graph.')
    args = parser.parse_args()
    # Fail early (exit status 1) with a clear message rather than letting the
    # loader report a less specific error. ArgumentParser.exit() writes the
    # message verbatim, so it must carry its own trailing newline.
    if not os.path.exists(args.file):
        parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))
    dump_graph_operations(args.file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment