Skip to content

Instantly share code, notes, and snippets.

@vabarbosa
Created November 1, 2018 17:33
Show Gist options
Save vabarbosa/1e27ae6917bc96ffb8112f4927720612 to your computer and use it in GitHub Desktop.
import os

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import strip_unused_lib
# Graph endpoints handed to strip_unused_lib.strip_unused: the names of the
# node(s) the stripped graph is fed through, and the node(s) it must produce.
input_node_names = ["decode/DecodeJpeg"]
output_node_names = ["softmax"]

# Path of the frozen graph to read, and the path of the stripped graph to
# write. Note: the second path is opened as a *file* for writing below, not
# as a directory.
frozen_graph_path = "/Users/va/models/frozen_graph.pb"
frozen_graph_stripped_path = "/Users/va/models/stripped_graph"
# Load the frozen graph file and parse its serialized protobuf bytes into a
# GraphDef. The bytes are read inside the `with` so the file is closed before
# the (potentially large) parse happens.
with tf.gfile.GFile(frozen_graph_path, "rb") as f:
    serialized_graph = f.read()

restored_graph_def = tf.GraphDef()
restored_graph_def.ParseFromString(serialized_graph)
# Prune the restored graph down to what is needed to compute
# `output_node_names` from `input_node_names`. The input nodes are turned into
# placeholders of the dtype given here (uint8; presumably matching the
# decoded-JPEG image input — confirm against the model).
placeholder_dtype_enum = dtypes.uint8.as_datatype_enum
gdef = strip_unused_lib.strip_unused(
    input_graph_def=restored_graph_def,
    input_node_names=input_node_names,
    output_node_names=output_node_names,
    placeholder_type_enum=placeholder_dtype_enum,
)
# Serialize the stripped GraphDef and write it to `frozen_graph_stripped_path`
# (a file path), then report where it went and how large it is.
# Requires the module-level `import os` for os.path.getsize.
with tf.gfile.GFile(frozen_graph_stripped_path, "wb") as f:
    f.write(gdef.SerializeToString())

print("Stripped frozen graph file: {}".format(frozen_graph_stripped_path))
# Report fractional MiB: an integer right-shift (`>> 20`) would print 0 for
# any graph smaller than 1 MiB.
stripped_size_mib = os.path.getsize(frozen_graph_stripped_path) / (1 << 20)
print(" File size: {:.2f} MiB".format(stripped_size_mib))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment