@sunsided
Last active October 27, 2021 21:26
Listing operations in frozen .pb TensorFlow graphs in GraphDef format (see comments for SavedModel)

import argparse
import os
import sys
from typing import Iterable

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('file', type=str, help='The file name of the frozen graph.')
args = parser.parse_args()

if not os.path.exists(args.file):
    parser.exit(1, 'The specified file does not exist: {}'.format(args.file))

graph_def = None
graph = None

# Assuming a `.pb` file in `GraphDef` format.
# See comments on https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.

print('Loading graph definition ...', file=sys.stderr)
try:
    # Parse the serialized GraphDef protocol buffer from disk.
    with tf.gfile.GFile(args.file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
except BaseException as e:
    parser.exit(2, 'Error loading the graph definition: {}'.format(str(e)))

print('Importing graph ...', file=sys.stderr)
try:
    # Import the GraphDef into a fresh graph so its operations can be inspected.
    assert graph_def is not None
    with tf.Graph().as_default() as graph:  # type: tf.Graph
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name='',
            op_dict=None,
            producer_op_list=None
        )
except BaseException as e:
    parser.exit(2, 'Error importing the graph: {}'.format(str(e)))

print()
print('Operations:')
assert graph is not None
ops = graph.get_operations()  # type: Iterable[tf.Operation]
for op in ops:
    print('- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs)))

print()
print('Sources (operations without inputs):')
for op in ops:
    if len(op.inputs) > 0:
        continue
    print('- {0}'.format(op.name))

print()
print('Operation inputs:')
for op in ops:
    if len(op.inputs) == 0:
        continue
    print('- {0:20}'.format(op.name))
    print('  {0}'.format(', '.join(i.name for i in op.inputs)))

print()
print('Tensors:')
for op in ops:
    for out in op.outputs:
        print('- {0:20} {1:10} "{2}"'.format(str(out.shape), out.dtype.name, out.name))
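
As a possible follow-up (not part of the gist itself): once the listing reveals the input and output tensor names, the imported graph can be evaluated directly, continuing from the `graph` built by the script above. The tensor names `input:0` and `output:0` and the input shape below are placeholder assumptions; substitute the values reported in the 'Tensors:' section.

import numpy as np

# Hypothetical tensor names; use the names printed by the listing above.
input_tensor = graph.get_tensor_by_name('input:0')
output_tensor = graph.get_tensor_by_name('output:0')

with tf.Session(graph=graph) as sess:
    # Dummy input; the shape is an assumption and must match the model's placeholder.
    dummy_input = np.zeros((1, 224, 224, 3), dtype=np.float32)
    result = sess.run(output_tensor, feed_dict={input_tensor: dummy_input})
    print(result.shape)
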
@sunsided (Author) commented Sep 9, 2020

It's been a while, but any Python 3 and some TensorFlow >= 1.8 and < 2 should do. Didn't try with TF 2, but its upgrade converter might help.
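
For the record, if someone wants to try it under TF 2 without running the converter, my untested guess is that switching the import at the top of the script to the v1 compatibility layer is enough:

import tensorflow.compat.v1 as tf

# Untested assumption: restores the graph-mode semantics the script expects under TF 2.x.
tf.disable_v2_behavior()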

@peter197321 commented Sep 9, 2020 via email

@sunsided (Author) commented Sep 9, 2020

I just toyed around with it a bit and here is my best guess at what's happening. 🙂

First of all, to get rid of the warnings, try replacing `import tensorflow as tf` with `import tensorflow.compat.v1 as tf`.

I'm assuming the culprit is this: the code above loads a frozen model containing a `GraphDef` (defined in `graph.proto`); however, that format is not compatible with graphs stored as a `SavedModel` (defined in `saved_model.proto`).

Specifically, this block assumes the `GraphDef` format, tries to decode the file as such (this is where it blows up), and then imports the result into a new graph:

import tensorflow.compat.v1 as tf

model_file = "path/to/model/file.pb"

with tf.gfile.GFile(model_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

assert graph_def is not None
with tf.Graph().as_default() as graph:  # type: tf.Graph
    tf.import_graph_def(
        graph_def,
        input_map=None,
        return_elements=None,
        name='',
        op_dict=None,
        producer_op_list=None
    )

If you have a `SavedModel` however, there's a much quicker way to achieve the same result using `tf.saved_model.loader`. Assuming the `tf.saved_model.tag_constants.SERVING` tag:

import tensorflow.compat.v1 as tf

model_path = "path/to/model"  # <-- sneaky one, expects a `saved_model.pb` file in there

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)

This method is deprecated according to the `tf.compat.v1.saved_model.load` documentation, but upgrading shouldn't be too hard.
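
For reference, a rough sketch of the non-deprecated TF 2 route (untested here; the "serving_default" signature key is the conventional default but may differ per model):

import tensorflow as tf  # TF 2.x

loaded = tf.saved_model.load("path/to/model")
# Pick a signature; "serving_default" is the usual key for serving models.
concrete_fn = loaded.signatures["serving_default"]
graph = concrete_fn.graph  # a tf.Graph, so graph.get_operations() works as in the script above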

Once you have the graph populated with the model, the rest of the code works as before.
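
Concretely, continuing from the loader snippet above, this is just the listing loop from the gist again:

# After tf.saved_model.loader.load(...) has run, the graph passed to the session is populated.
for op in graph.get_operations():
    print('- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs)))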
