-
-
Save sunsided/88d24bf44068fe0fe5b88f09a1bee92a to your computer and use it in GitHub Desktop.
import argparse | |
import os | |
import sys | |
from typing import Iterable | |
import tensorflow as tf | |
# Command-line interface: a single positional argument naming the frozen graph file.
parser = argparse.ArgumentParser(
    description='Dump the operations, inputs and tensors of a frozen TensorFlow graph.')
parser.add_argument('file', type=str, help='The file name of the frozen graph.')
args = parser.parse_args()

# ArgumentParser.exit() writes `message` to stderr verbatim, so every message
# needs its own trailing newline or the shell prompt ends up on the same line.
if not os.path.exists(args.file):
    parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))
if os.path.isdir(args.file):
    # A directory (e.g. a SavedModel export folder) would otherwise only fail
    # later, while reading it as a graph file, with a much less helpful error.
    parser.exit(1, 'The specified path is a directory, not a graph file: {}\n'.format(args.file))
# Resolve the TF1-style graph API. TF >= 1.13 and TF 2.x expose it as
# `tf.compat.v1` (TF 2 has no `tf.gfile` / `tf.GraphDef` at the top level);
# older 1.x releases only provide it directly on `tf`.
try:
    tf_v1 = tf.compat.v1
except AttributeError:
    tf_v1 = tf

graph_def = None
graph = None

# Assuming a `.pb` file in `GraphDef` format.
# See comments on https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.
print('Loading graph definition ...', file=sys.stderr)
try:
    with tf_v1.gfile.GFile(args.file, "rb") as f:
        graph_def = tf_v1.GraphDef()
        graph_def.ParseFromString(f.read())
except Exception as e:
    # Catch Exception rather than BaseException so Ctrl+C (KeyboardInterrupt)
    # still aborts normally instead of being reported as a load error.
    parser.exit(2, 'Error loading the graph definition: {}\n'
                   'Note: this script expects a frozen GraphDef; a SavedModel '
                   '(saved_model.pb) cannot be read this way.\n'.format(e))

print('Importing graph ...', file=sys.stderr)
# Explicit check instead of `assert`, which is stripped under `python -O`.
if graph_def is None:
    parser.exit(2, 'Error importing the graph: no graph definition was loaded\n')
try:
    with tf_v1.Graph().as_default() as graph:  # type: tf.Graph
        # input_map, return_elements, op_dict and producer_op_list are left at
        # their None defaults; op_dict in particular is a deprecated argument.
        # name='' keeps the original node names (no import prefix).
        tf_v1.import_graph_def(graph_def, name='')
except Exception as e:
    parser.exit(2, 'Error importing the graph: {}\n'.format(e))
# Section 1: every operation in the graph with its type and number of outputs.
print()
print('Operations:')
# Narrows the type for checkers; the import step above exits if it failed.
assert graph is not None
# `ops` is reused by every section below, so it is iterated several times.
ops = graph.get_operations()  # type: Iterable[tf.Operation]
for operation in ops:
    output_count = len(operation.outputs)
    line = '- {0:20s} "{1}" ({2} outputs)'.format(operation.type, operation.name, output_count)
    print(line)
# Section 2: source operations, i.e. operations that consume no input tensors.
print()
print('Sources (operations without inputs):')
for source in (candidate for candidate in ops if len(candidate.inputs) == 0):
    print('- {0}'.format(source.name))
# Section 3: for every operation that has inputs, its name followed by a
# comma-separated list of the names of the tensors it consumes.
print()
print('Operation inputs:')
for consumer in ops:
    input_tensors = consumer.inputs
    if len(input_tensors) > 0:
        print('- {0:20}'.format(consumer.name))
        input_names = ', '.join(tensor.name for tensor in input_tensors)
        print(' {0}'.format(input_names))
# Section 4: every output tensor of every operation, with its shape, dtype and name.
print()
print('Tensors:')
for tensor in (produced for operation in ops for produced in operation.outputs):
    shape_text = str(tensor.shape)
    print('- {0:20} {1:10} "{2}"'.format(shape_text, tensor.dtype.name, tensor.name))
Which versions of TensorFlow and Python are needed?
(tf1.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint8 = np.dtype([("qint8", np.int8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Users\nxa18908\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorflow\python\framework\dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
ning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint8 = np.dtype([("quint8", np.uint8, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint16 = np.dtype([("qint16", np.int16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_quint16 = np.dtype([("quint16", np.uint16, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
_np_qint32 = np.dtype([("qint32", np.int32, 1)])
C:\Anaconda3\envs\tf1.x-cpu\lib\site-packages\tensorboard\compat\tensorflow_stub\dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.
np_resource = np.dtype([("resource", np.ubyte, 1)])
Loading graph definition ...
WARNING:tensorflow:From .\dump_operations.py:20: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.
WARNING:tensorflow:From .\dump_operations.py:21: The name tf.GraphDef is deprecated. Please use tf.compat.v1.GraphDef instead.
Error loading the graph definition: Wrong wire type in tag.
(tf1.x-cpu) PS C:> conda activate tf2.x-cpu
(tf2.x-cpu) PS C:> python .\dump_operations.py .\saved_model\saved_model.pb
Loading graph definition ...
Error loading the graph definition: module 'tensorflow' has no attribute 'gfile'
(tf2.x-cpu) PS C:>
It's been a while, but any Python 3 with TensorFlow >= 1.8 and < 2 should do. I didn't try it with TF 2, but its upgrade converter (`tf_upgrade_v2`) might help.
I just toyed around with it a bit and here is my best guess at what's happening. 🙂
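First of all, to get rid of the TensorFlow deprecation warnings (about `tf.gfile.GFile` and `tf.GraphDef`), try replacing `import tensorflow as tf` with `import tensorflow.compat.v1 as tf`. The NumPy `FutureWarning`s are emitted while TensorFlow itself is being imported, so this change does not affect them.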
I'm assuming the culprit is this: the code above loads a frozen model containing a `GraphDef` (defined in `graph.proto`). However, that format is not compatible with graphs stored as a `SavedModel` (defined in `saved_model.proto`).
Specifically, this block assumes the `GraphDef` format, tries to decode the file as such (this is where it blows up), and then imports the result into a new graph:
import tensorflow.compat.v1 as tf

# Path of a frozen graph serialized as a GraphDef protobuf.
model_file = "path/to/model/file.pb"

# Read the raw bytes, then decode them as a GraphDef. A SavedModel's
# saved_model.pb is a different message type, so decoding it here fails.
graph_def = tf.GraphDef()
with tf.gfile.GFile(model_file, "rb") as f:
    serialized = f.read()
graph_def.ParseFromString(serialized)
assert graph_def is not None

# Import the decoded definition into a fresh graph, keeping the original node names.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(
        graph_def,
        input_map=None,
        return_elements=None,
        name='',
        op_dict=None,
        producer_op_list=None,
    )
If you have a `SavedModel`, however, there's a much quicker way to achieve the same result using `saved_model.loader`. Assuming the model was exported with the standard `SERVING` tag:
import tensorflow.compat.v1 as tf

# Sneaky one: this is the *directory* that contains a `saved_model.pb` file,
# not the path of the `.pb` file itself.
model_path = "path/to/model"

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    # Load the meta graph exported with the SERVING tag into `graph`. The tag
    # constants live in `tf.saved_model.tag_constants`; there is no top-level
    # `tf.tag_constants` module.
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)
This method is deprecated according to the `tf.compat.v1.saved_model.load` documentation, but upgrading shouldn't be too hard.
Once you have the `graph` populated with the model, the rest of the code works as before.
Thank you mate!