-
-
Save haochunchang/f251deec78195e700865358a629e7798 to your computer and use it in GitHub Desktop.
Listing operations in frozen .pb TensorFlow graphs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import argparse, os, sys | |
import tensorflow as tf | |
def _exit_with_error(message, status=2):
    """Write *message* (plus a newline) to stderr and exit with *status*.

    Mirrors ``argparse.ArgumentParser.exit`` without depending on a
    module-level ``parser`` object, so the caller also works when this module
    is imported rather than run as a script.
    """
    sys.stderr.write(message + '\n')
    sys.exit(status)


def dump_graph_operations(filename):
    """Load a frozen GraphDef from *filename* and print a report of its contents.

    Progress messages are written to stderr. The report itself goes to
    stdout, in this order:
    - an "Operations:" section listing every operation;
    - the source operations (those without inputs);
    - the inputs of every other operation;
    - every output tensor.

    Args:
        filename: Path of the serialized (binary) GraphDef file; it is opened
            with TensorFlow's ``gfile.GFile``.

    Exits:
        Raises ``SystemExit`` with status 2 after writing a message to stderr
        if the file cannot be read/parsed or the graph cannot be imported.
    """
    # TF 2.x dropped tf.gfile / tf.GraphDef from the top-level namespace; the
    # compat.v1 module provides them (TF >= 1.13 and 2.x). Older releases
    # have no compat.v1, so fall back to the top-level module there.
    tf_v1 = getattr(getattr(tf, 'compat', None), 'v1', tf)

    print('Loading graph definition ...', file=sys.stderr)
    try:
        with tf_v1.gfile.GFile(filename, "rb") as f:
            graph_def = tf_v1.GraphDef()
            graph_def.ParseFromString(f.read())
    except Exception as e:  # not BaseException: let Ctrl-C / SystemExit propagate
        _exit_with_error('Error loading the graph definition: {}'.format(str(e)))

    print('Importing graph ...', file=sys.stderr)
    try:
        with tf_v1.Graph().as_default() as graph:  # type: tf.Graph
            # name='' keeps the original op names instead of prefixing them
            # with the default "import/" scope.
            tf_v1.import_graph_def(graph_def, name='')
    except Exception as e:
        _exit_with_error('Error importing the graph: {}'.format(str(e)))

    print()
    print('Operations:')
    ops = graph.get_operations()
    print_all_ops(ops)
    print_sources(ops)
    print_ops_inputs(ops)
    print_all_tensors(ops)
def print_all_ops(ops):
    """Print a one-line summary for each operation, then a blank line.

    Each line holds the op's ``type`` left-aligned in a 20-character field,
    the op's ``name`` in double quotes, and the number of entries in its
    ``outputs``.
    """
    template = '- {0:20s} "{1}" ({2} outputs)'
    for operation in ops:
        output_count = len(operation.outputs)
        print(template.format(operation.type, operation.name, output_count))
    print()
def print_sources(ops):
    """Print the names of the operations that have no inputs.

    Emits a header line, then ``- <name>`` for every op whose ``inputs`` is
    empty (in the order given), and finally a blank line.
    """
    print('Sources (operations without inputs):')
    source_names = (op.name for op in ops if len(op.inputs) == 0)
    for name in source_names:
        print('- {0}'.format(name))
    print()
def print_ops_inputs(ops):
    """Print each operation that has inputs, followed by its input tensor names.

    Ops with an empty ``inputs`` are skipped. Each remaining op produces two
    lines:
    - ``- <name>``, with the name padded to 20 characters;
    - the comma-separated ``name`` values of its inputs.
    A blank line closes the section.
    """
    print('Operation inputs:')
    for operation in ops:
        if len(operation.inputs) != 0:
            print('- {0:20}'.format(operation.name))
            joined = ', '.join(tensor.name for tensor in operation.inputs)
            print(' {0}'.format(joined))
    print()
def print_all_tensors(ops):
    """Print every output tensor of every operation, then a blank line.

    Each line shows the following, followed by the tensor's quoted ``name``:
    - ``str(tensor.shape)``, padded to 20 characters;
    - the tensor's ``dtype.name``, padded to 10 characters.
    """
    print('Tensors:')
    row = '- {0:20} {1:10} "{2}"'
    for tensor in (out for op in ops for out in op.outputs):
        print(row.format(str(tensor.shape), tensor.dtype.name, tensor.name))
    print()
if __name__ == "__main__":
    # Command-line entry point: parse the graph path, sanity-check it, and
    # print the report.
    parser = argparse.ArgumentParser(
        description='List the operations and tensors of a frozen .pb TensorFlow graph.')
    parser.add_argument('file', type=str, help='The file name of the frozen graph.')
    args = parser.parse_args()
    # Only pre-check plain local paths. URL-style paths such as gs://... are
    # opened through TensorFlow's gfile layer, which os.path cannot see; read
    # errors there are reported by dump_graph_operations instead.
    if '://' not in args.file and not os.path.exists(args.file):
        # parser.exit writes the message verbatim, so supply the newline.
        parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))
    dump_graph_operations(args.file)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment