-
-
Save sunsided/88d24bf44068fe0fe5b88f09a1bee92a to your computer and use it in GitHub Desktop.
import argparse
import os
import sys
from typing import Iterable, List

import tensorflow as tf
# Inspect a frozen TensorFlow graph stored as a serialized `GraphDef` (usually
# a `.pb` file). Prints every operation, the source operations (no inputs),
# the inputs of all other operations and every output tensor to stdout.
# Progress and error messages go to stderr so stdout can be piped/redirected.
#
# Assumes a `.pb` file in `GraphDef` format. See the comments on
# https://gist.github.com/sunsided/88d24bf44068fe0fe5b88f09a1bee92a/
# for inspecting SavedModel graphs instead.


def _tf_v1():
    """Return the namespace providing the TF1-style graph API.

    TensorFlow 2.x only exposes names such as ``gfile`` and ``GraphDef`` under
    ``tf.compat.v1``. Releases without ``tf.compat.v1`` expose them at the top
    level, so fall back to ``tf`` itself there.
    """
    compat = getattr(tf, 'compat', None)
    return getattr(compat, 'v1', tf)


def _load_graph_def(parser, path):
    """Read *path* and parse its contents as a serialized ``GraphDef``.

    Exits the process with status 2 (via ``parser.exit``) when the file cannot
    be read or does not decode as a ``GraphDef``.
    """
    tf1 = _tf_v1()
    print('Loading graph definition ...', file=sys.stderr)
    try:
        with tf1.gfile.GFile(path, 'rb') as f:
            graph_def = tf1.GraphDef()
            graph_def.ParseFromString(f.read())
    except Exception as e:  # not BaseException: Ctrl-C must still interrupt
        parser.exit(2, 'Error loading the graph definition: {}\n'.format(str(e)))
    return graph_def


def _import_graph(parser, graph_def):
    """Import *graph_def* into a fresh ``Graph`` and return that graph.

    ``name=''`` imports the nodes without adding a name prefix. Exits the
    process with status 2 (via ``parser.exit``) if the import raises.
    """
    tf1 = _tf_v1()
    print('Importing graph ...', file=sys.stderr)
    try:
        with tf1.Graph().as_default() as graph:
            # All other import_graph_def arguments keep their default (None);
            # `op_dict` in particular is documented as deprecated.
            tf1.import_graph_def(graph_def, name='')
    except Exception as e:  # not BaseException: Ctrl-C must still interrupt
        parser.exit(2, 'Error importing the graph: {}\n'.format(str(e)))
    return graph


def describe_graph(graph):
    # type: (tf.Graph) -> List[str]
    """Build the textual report for *graph*, one string per output line.

    Four sections, each preceded by an empty line: all operations (type,
    name, number of outputs), source operations (those without inputs), the
    comma-separated inputs of every other operation, and every output tensor
    (shape, dtype name, tensor name).
    """
    # Materialize once: the report walks the operations four times.
    ops = list(graph.get_operations())  # type: List[tf.Operation]

    lines = ['', 'Operations:']
    lines.extend(
        '- {0:20s} "{1}" ({2} outputs)'.format(op.type, op.name, len(op.outputs))
        for op in ops
    )

    lines += ['', 'Sources (operations without inputs):']
    lines.extend('- {0}'.format(op.name) for op in ops if len(op.inputs) == 0)

    lines += ['', 'Operation inputs:']
    for op in ops:
        if len(op.inputs) == 0:
            continue
        lines.append('- {0:20}'.format(op.name))
        lines.append(' {0}'.format(', '.join(i.name for i in op.inputs)))

    lines += ['', 'Tensors:']
    lines.extend(
        '- {0:20} {1:10} "{2}"'.format(str(out.shape), out.dtype.name, out.name)
        for op in ops
        for out in op.outputs
    )
    return lines


def main(argv=None):
    """Command-line entry point; *argv* defaults to ``sys.argv[1:]``.

    Exit status 1 if the given file does not exist, 2 if it cannot be loaded
    or imported as a graph.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('file', type=str, help='The file name of the frozen graph.')
    args = parser.parse_args(argv)
    if not os.path.exists(args.file):
        parser.exit(1, 'The specified file does not exist: {}\n'.format(args.file))

    graph_def = _load_graph_def(parser, args.file)
    graph = _import_graph(parser, graph_def)
    for line in describe_graph(graph):
        print(line)


if __name__ == '__main__':
    main()
peter197321 commented on Sep 9, 2020 (via email):
I just toyed around with it a bit and here is my best guess at what's happening. 🙂
First of all, to get rid of the warnings, try replacing `import tensorflow as tf` with `import tensorflow.compat.v1 as tf`.
I'm assuming the culprit is this: the code above loads a frozen model containing a `GraphDef` (defined in `graph.proto`). However, that format is not compatible with graphs stored as a `SavedModel` (defined in `saved_model.proto`).
Specifically, this block assumes the `GraphDef` format, tries to decode the file as such (this is where it blows up), and then imports the result into a new graph:
import tensorflow.compat.v1 as tf

model_file = "path/to/model/file.pb"

# Decode the file as a serialized GraphDef (frozen-graph format). This is the
# step that fails when the file is not in GraphDef format, e.g. a SavedModel.
with tf.gfile.GFile(model_file, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

assert graph_def is not None
# Import the decoded definition into a fresh graph; name='' adds no prefix
# to the imported node names.
with tf.Graph().as_default() as graph:  # type: tf.Graph
    tf.import_graph_def(
        graph_def,
        input_map=None,
        return_elements=None,
        name='',
        op_dict=None,
        producer_op_list=None
    )
If you have a `SavedModel`, however, there's a much quicker way to achieve the same result using the `saved_model.loader` module. Assuming the model was exported with the standard `SERVING` tag:
import tensorflow.compat.v1 as tf

# Sneaky one: this is the directory that contains `saved_model.pb`,
# not the path to the .pb file itself.
model_path = "path/to/model"

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
    # Load the MetaGraph tagged SERVING into `graph` (the session's graph).
    # The tag constants live under tf.saved_model, not at the API top level.
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], model_path
    )
This method is deprecated according to the `tf.compat.v1.saved_model.load` documentation, but upgrading shouldn't be too hard. Once `graph` is populated with the model, the rest of the code works as before.