Last active October 27, 2021 21:26
Listing operations in frozen .pb TensorFlow graphs in GraphDef format (see comments for SavedModel)
```python
import argparse
import os
import sys
from typing import Iterable

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('file', type=str, help='The file name of the frozen graph.')
args = parser.parse_args()

if not os.path.exists(args.file):
    parser.exit(1, 'The specified file does not exist: {}'.format(args.file))

graph_def = None
graph = None

# Assuming a `.pb` file in `GraphDef` format.
# See comments on https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.
print('Loading graph definition ...', file=sys.stderr)
try:
    with tf.gfile.GFile(args.file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
except BaseException as e:
    parser.exit(2, 'Error loading the graph definition: {}'.format(str(e)))

print('Importing graph ...', file=sys.stderr)
try:
    assert graph_def is not None
    with tf.Graph().as_default() as graph:  # type: tf.Graph
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name='',
            op_dict=None,
            producer_op_list=None
        )
except BaseException as e:
    parser.exit(2, 'Error importing the graph: {}'.format(str(e)))

print()
print('Operations:')
assert graph is not None
ops = graph.get_operations()  # type: Iterable[tf.Operation]
for op in ops:
    print('- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs)))

print()
print('Sources (operations without inputs):')
for op in ops:
    if len(op.inputs) > 0:
        continue
    print('- {0}'.format(op.name))

print()
print('Operation inputs:')
for op in ops:
    if len(op.inputs) == 0:
        continue
    print('- {0:20}'.format(op.name))
    print(' {0}'.format(', '.join(i.name for i in op.inputs)))

print()
print('Tensors:')
for op in ops:
    for out in op.outputs:
        print('- {0:20} {1:10} "{2}"'.format(str(out.shape), out.dtype.name, out.name))
```
I just toyed around with it a bit and here is my best guess at what's happening. 🙂
First of all, to get rid of the warnings, try replacing `import tensorflow as tf` with `import tensorflow.compat.v1 as tf`.

I'm assuming the culprit is this: the code above loads a frozen model containing a `GraphDef` (defined in `graph.proto`), and that format is not compatible with models stored as a `SavedModel` (defined in `saved_model.proto`). A `saved_model.pb` file holds a `SavedModel` message that wraps one or more `MetaGraphDef`s, not a bare `GraphDef`. Specifically, the loading block assumes the `GraphDef` format and tries to decode the file as such with `graph_def.ParseFromString(f.read())` (this is where it blows up); the next block then imports the result into a new graph with `tf.import_graph_def`.

If you have a `SavedModel`, however, there's a much quicker way to achieve the same result: load it with `tf.saved_model.loader.load`, passing the `tf.saved_model.tag_constants.SERVING` tag, into a session whose graph then contains the model. This method is deprecated according to the `tf.compat.v1.saved_model.load` documentation, but upgrading shouldn't be too hard.
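A minimal sketch of that SavedModel path, assuming a TensorFlow version that ships `tensorflow.compat.v1` and a model exported with the `serve` tag (`tf.saved_model.tag_constants.SERVING`). The helper name `load_saved_model_graph` is hypothetical; it stands in for the two `GraphDef` blocks of the script and returns a `tf.Graph` that the listing code can consume unchanged.

```python
# Sketch: load a SavedModel into a fresh graph using the TF1-compatible loader.
# Assumes the model was exported with the SERVING ("serve") tag.


def load_saved_model_graph(export_dir):
    """Return a tf.Graph populated from the SavedModel directory `export_dir`.

    `load_saved_model_graph` is a hypothetical helper, not a TensorFlow API.
    """
    # Imported lazily so this helper can be defined without TensorFlow installed.
    import tensorflow.compat.v1 as tf

    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        # Deprecated, but still available through tf.compat.v1.
        # Note: export_dir is the export *directory*, not saved_model.pb itself.
        tf.saved_model.loader.load(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            export_dir,
        )
    # The session is closed here, but the graph (and its operations) remains usable.
    return graph
```

Usage: `graph = load_saved_model_graph('path/to/export_dir')`, then run the listing loops starting from `ops = graph.get_operations()`. The session is only needed while loading; listing operations and tensors works on the graph alone.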
Once you have the `graph` populated with the model, the rest of the code works as before.
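Because the script only understands a single frozen `GraphDef` file, it can help to check up front which format a path holds before choosing a loader. A minimal sketch, assuming the standard SavedModel layout (an export directory containing `saved_model.pb` or `saved_model.pbtxt`). The function name `detect_model_format` and its return strings are hypothetical, and the check is a heuristic on the on-disk layout only; it does not parse the protobuf.

```python
import os


def detect_model_format(path):
    """Guess whether `path` is a SavedModel or a frozen GraphDef file.

    Returns 'savedmodel', 'graphdef', or 'unknown'. Hypothetical helper:
    it only looks at file and directory names, not at the file contents.
    """
    saved_model_files = ('saved_model.pb', 'saved_model.pbtxt')
    if os.path.isdir(path):
        # A SavedModel export directory holds saved_model.pb (or .pbtxt),
        # usually next to a variables/ subdirectory.
        for name in saved_model_files:
            if os.path.isfile(os.path.join(path, name)):
                return 'savedmodel'
        return 'unknown'
    if os.path.isfile(path):
        # A file with the SavedModel file name was most likely taken from an
        # export directory; the loader needs that directory, not the file.
        if os.path.basename(path) in saved_model_files:
            return 'savedmodel'
        # Assumption: any other single file is treated as a frozen GraphDef.
        return 'graphdef'
    return 'unknown'
```

A `'savedmodel'` result means the path (or its parent directory) should go through the SavedModel loader, while `'graphdef'` means the original parsing code applies.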