Skip to content

Instantly share code, notes, and snippets.

@SuoXC
Last active December 14, 2018 03:09
Show Gist options
  • Save SuoXC/0ee494a339adb3d489b537527cd4d5dd to your computer and use it in GitHub Desktop.
Save SuoXC/0ee494a339adb3d489b537527cd4d5dd to your computer and use it in GitHub Desktop.
Convert a Keras, TensorFlow, or tf.keras model into a frozen SavedModel that can serve predictions through TensorFlow Serving or a Python predictor.
import tensorflow as tf
def export_submodel(sess, inputs, outputs, output_dir):
    """Freeze a sub-graph and export it as a servable SavedModel.

    Variables needed to compute the given tensors are folded into constants
    using the values held by ``sess``. The frozen graph is then imported into
    a fresh graph and written to ``output_dir`` with the ``SERVING`` tag and a
    predict signature registered under the default serving signature key.

    Args:
        sess: Session whose graph contains every tensor in ``inputs`` and
            ``outputs`` and which holds the variable values to freeze.
        inputs: Mapping from signature input key to tensor.
        outputs: Mapping from signature output key to tensor.
        output_dir: Export directory passed to ``SavedModelBuilder``.
    """
    endpoint_tensors = list(inputs.values()) + list(outputs.values())
    # Tensor names look like "<op_name>:<output_index>" while freezing expects
    # op names, so drop whatever index is present (not only ":0").
    output_node_names = [
        tensor.name.rsplit(":", 1)[0] for tensor in endpoint_tensors
    ]

    # Take the graph from the session itself so that freezing still works when
    # the session's graph is not the current default graph.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        output_node_names,
    )

    # The export has to happen in a new graph and a new session so that only
    # the frozen nodes are written. Entering the session as a context manager
    # makes it the default and closes it when the export is done.
    with tf.Graph().as_default(), tf.Session() as export_sess:
        # name="" keeps the original node names, so the tensor infos built
        # below from the caller's tensors still resolve in the new graph.
        tf.import_graph_def(frozen_graph_def, name="")

        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={
                    key: tf.saved_model.utils.build_tensor_info(tensor)
                    for key, tensor in inputs.items()
                },
                outputs={
                    key: tf.saved_model.utils.build_tensor_info(tensor)
                    for key, tensor in outputs.items()
                },
                method_name=(
                    tf.saved_model.signature_constants.PREDICT_METHOD_NAME
                ),
            )
        )

        signature_key = (
            tf.saved_model.signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY
        )
        builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
        builder.add_meta_graph_and_variables(
            export_sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={signature_key: prediction_signature},
        )
        builder.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment