@sunsided
Last active October 27, 2021 21:26
Listing operations in frozen .pb TensorFlow graphs in GraphDef format (see comments for SavedModel)
import argparse
import os
import sys
from typing import Iterable

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('file', type=str, help='The file name of the frozen graph.')
args = parser.parse_args()

if not os.path.exists(args.file):
    parser.exit(1, 'The specified file does not exist: {}'.format(args.file))

graph_def = None
graph = None

# Assuming a `.pb` file in `GraphDef` format.
# See comments on https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.

print('Loading graph definition ...', file=sys.stderr)
try:
    with tf.gfile.GFile(args.file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
except BaseException as e:
    parser.exit(2, 'Error loading the graph definition: {}'.format(str(e)))

print('Importing graph ...', file=sys.stderr)
try:
    assert graph_def is not None
    with tf.Graph().as_default() as graph:  # type: tf.Graph
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name='',
            op_dict=None,
            producer_op_list=None
        )
except BaseException as e:
    parser.exit(2, 'Error importing the graph: {}'.format(str(e)))

print()
print('Operations:')
assert graph is not None
ops = graph.get_operations()  # type: Iterable[tf.Operation]
for op in ops:
    print('- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs)))

print()
print('Sources (operations without inputs):')
for op in ops:
    if len(op.inputs) > 0:
        continue
    print('- {0}'.format(op.name))

print()
print('Operation inputs:')
for op in ops:
    if len(op.inputs) == 0:
        continue
    print('- {0:20}'.format(op.name))
    print(' {0}'.format(', '.join(i.name for i in op.inputs)))

print()
print('Tensors:')
for op in ops:
    for out in op.outputs:
        print('- {0:20} {1:10} "{2}"'.format(str(out.shape), out.dtype.name, out.name))
@satyajithj

Thank you mate!

@peter197321 commented Sep 9, 2020

What versions of TensorFlow and Python are needed?

(tf1.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Users\nxa18908\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
Loading graph definition ...
WARNING:tensorflow:From .\dump_operations.py:20: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.

WARNING:tensorflow:From .\dump_operations.py:21: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.

Error loading the graph definition: Wrong wire type in tag.
(tf1.x-cpu) PS C:> conda activate tf2.x-cpu
(tf2.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
Loading graph definition ...
Error loading the graph definition: module 'tensorflow' has no attribute 'gfile'
(tf2.x-cpu) PS C:>

@sunsided (Author) commented Sep 9, 2020

It's been a while, but any Python 3 with TensorFlow >= 1.8 and < 2 should do. I didn't try it with TF 2, but its upgrade converter might help.
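
Alternatively, TF 2 still ships the v1 API under tensorflow.compat.v1, so a minimal sketch (my assumption, untested here, for a stock TF 2.x install) for running the gist's TF1-style loading code with only the import changed would be:

import tensorflow.compat.v1 as tf  # v1 compatibility shim bundled with TF 2.x

tf.disable_v2_behavior()  # restore TF1 graph-mode semantics

# The rest of the gist works as written, e.g. loading the GraphDef:
graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_graph.pb", "rb") as f:  # placeholder path
    graph_def.ParseFromString(f.read())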

@peter197321 commented Sep 9, 2020 via email

@sunsided (Author) commented Sep 9, 2020

I just toyed around with it a bit and here is my best guess at what's happening. 🙂

First of all - to get rid of the warnings, try replacing import tensorflow as tf with import tensorflow.compat.v1 as tf.

I'm assuming the culprit is this: the code above loads a frozen model containing a GraphDef (defined in graph.proto), and that format is not compatible with graphs stored as a SavedModel (defined in saved_model.proto).
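
If in doubt about which of the two formats a given .pb file uses, a rough heuristic (just a sketch; guess_format is a hypothetical helper, and it relies on the protobuf bindings bundled with TensorFlow) is to try parsing it as each message type:

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saved_model_pb2

def guess_format(path):
    """Best-effort guess whether `path` holds a SavedModel or a GraphDef."""
    with open(path, "rb") as f:
        data = f.read()
    saved_model = saved_model_pb2.SavedModel()
    try:
        saved_model.ParseFromString(data)
        if saved_model.meta_graphs:  # a real SavedModel has at least one MetaGraphDef
            return "SavedModel"
    except Exception:
        pass
    graph_def = graph_pb2.GraphDef()
    try:
        graph_def.ParseFromString(data)
        return "GraphDef"
    except Exception:
        return "unknown"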

Specifically, this block assumes the GraphDef format, tries to decode it as such (this is where it blows up), and then imports the result into a new graph:

import tensorflow.compat.v1 as tf

model_file = "path/to/model/file.pb"

with tf.gfile.GFile(model_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

assert graph_def is not None
with tf.Graph().as_default() as graph:  # type: tf.Graph
    tf.import_graph_def(
        graph_def,
        input_map=None,
        return_elements=None,
        name='',
        op_dict=None,
        producer_op_list=None
    )

If you have a SavedModel, however, there's a much quicker way to achieve the same result using saved_model.loader. Assuming the tf.saved_model.tag_constants.SERVING tag:

import tensorflow.compat.v1 as tf

model_path = "path/to/model"  # <-- sneaky one, expects a `saved_model.pb` file in there

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)

This method is deprecated according to the tf.compat.v1.saved_model.load documentation, but upgrading shouldn't be too hard.
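
For reference, a minimal non-deprecated sketch for TF 2.x (assuming the model was exported with the default "serving_default" signature) loads the SavedModel eagerly and lists the operations of the signature's concrete-function graph:

import tensorflow as tf  # TF 2.x

model_path = "path/to/model"  # placeholder; the directory containing saved_model.pb

loaded = tf.saved_model.load(model_path)
fn = loaded.signatures["serving_default"]  # assumption: default serving signature exists

for op in fn.graph.get_operations():
    print('- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs)))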

Once you have the graph populated with the model, the rest of the code works as before.
