@gaphex
Last active August 10, 2020 17:27
import tensorflow as tf

from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference


def freeze_keras_model(model, export_path=None, clear_devices=True):
    """
    Freezes the state of the current Keras session into a pruned computation graph.

    @param model         The Keras model to be optimized for inference.
    @param export_path   Optional path to which the frozen GraphDef is serialized.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    sess = tf.keras.backend.get_session()
    graph = sess.graph

    with graph.as_default():
        # Collect the model's input and output tensors, their dtypes,
        # and the corresponding op names (tensor name without the ":0" suffix).
        input_tensors = model.inputs
        output_tensors = model.outputs
        dtypes = [t.dtype.as_datatype_enum for t in input_tensors]
        input_ops = [t.name.rsplit(":", maxsplit=1)[0] for t in input_tensors]
        output_ops = [t.name.rsplit(":", maxsplit=1)[0] for t in output_tensors]

        tmp_g = graph.as_graph_def()

        # Strip device placement directives so the graph loads on any machine.
        if clear_devices:
            for node in tmp_g.node:
                node.device = ""

        # Remove training-only nodes and fold trivial ops for inference.
        tmp_g = optimize_for_inference(
            tmp_g, input_ops, output_ops, dtypes, False)

        # Replace variables with constants holding their current values.
        tmp_g = convert_variables_to_constants(sess, tmp_g, output_ops)

        # Optionally serialize the frozen GraphDef to disk.
        if export_path is not None:
            with tf.gfile.GFile(export_path, "wb") as f:
                f.write(tmp_g.SerializeToString())

    return tmp_g
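
A minimal usage sketch, assuming a TF 1.x graph-mode tf.keras setup (the toy model and output path below are hypothetical, purely for illustration):

# Hypothetical example: build a small Keras model, then freeze and export it.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation="softmax", input_shape=(784,))
])

frozen_graph_def = freeze_keras_model(model, export_path="frozen_model.pb")

# The returned GraphDef contains only constant weights and inference ops.
print(len(frozen_graph_def.node), "nodes in the frozen graph")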