@zhanwenchen
Last active November 4, 2024 16:26
Minimal code to load a trained TensorFlow model from a checkpoint and export it with SavedModelBuilder

import os
import tensorflow as tf  # TensorFlow 1.x API

trained_checkpoint_prefix = 'checkpoints/dev'
# IMPORTANT: TensorFlow Serving expects each model version in a numeric
# subdirectory ('0', '1', ...). With any other name, serving will fail to find it.
export_dir = os.path.join('models', '0')

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Restore the graph structure from the .meta file, then the variable values
    loader = tf.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)

    # Export checkpoint to SavedModel (export_dir must not exist or must be empty)
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING],
        strip_default_attrs=True)
    builder.save()

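Since the export folder has to be a numeric version directory, and SavedModelBuilder refuses to write into an export directory that already exists and is not empty, re-running the script against a fixed '0' will collide with the previous export. A minimal stdlib-only sketch for picking the next free version folder, assuming versions are plain integer folder names under a base directory such as 'models'; next_version_dir is a hypothetical helper, not part of TensorFlow:

```python
import os


def next_version_dir(base_dir):
    """Return base_dir/<N>, where N is one more than the largest numeric
    subdirectory name already in base_dir (or 0 if there is none).

    Hypothetical helper: non-numeric entries are ignored, mirroring the
    convention that only integer-named folders count as model versions.
    """
    versions = []
    if os.path.isdir(base_dir):
        for name in os.listdir(base_dir):
            # isdecimal() avoids names like '²' that int() cannot parse
            if name.isdecimal() and os.path.isdir(os.path.join(base_dir, name)):
                versions.append(int(name))
    next_version = max(versions) + 1 if versions else 0
    # The folder itself is not created; SavedModelBuilder creates it on init
    return os.path.join(base_dir, str(next_version))
```

Usage: export_dir = next_version_dir('models') in place of the hard-coded os.path.join('models', '0'). The directory is deliberately left uncreated so the builder sees a fresh path.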
@Jonarod
Jonarod commented Aug 6, 2019

Thanks!
Here is the same code using tf.compat.v1 for TensorFlow 2.0:

import os
import tensorflow as tf

# trained_checkpoint_prefix = 'checkpoints/dev'
trained_checkpoint_prefix = 'model.ckpt'
export_dir = os.path.join('models', '0') # IMPORTANT: each model folder must be named '0', '1', ... Otherwise it will fail!

loaded_graph = tf.Graph()
with tf.compat.v1.Session(graph=loaded_graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)
    
    # Export checkpoint to SavedModel
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess, ["train", "serve"], strip_default_attrs=True)
    builder.save()

@Nerdyvedi
I am getting "ValueError: At least two variables have the same name".

Any idea how to solve this?

@372046933
The signature_def_map is lost this way.

@Litchilitchy
> signature_def_map is lost in this way

Yes, and it is needed for serving. How can we keep it?

@Aroueterra
Any resolution to this? @Litchilitchy

@Litchilitchy
@Aroueterra You have to add the inputs and outputs yourself.

According to TensorFlow, it seems you must define an entry point. It would make sense for the saved model to already include this; with a Keras model, for example, we could just use model.inputs and model.outputs.

Is it possible to get them directly from the .ckpt and .meta files?
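Following the suggestion to add the inputs and outputs yourself, here is a minimal sketch using tf.compat.v1 that restores the checkpoint as in the gist, looks up the entry-point tensors by name, and passes them through signature_def_map. A .meta file written by tf.train.Saver typically carries no signature, so the tensor names have to come from the code that built the graph, or from inspecting loaded_graph.get_operations() after import_meta_graph (Placeholder ops are usually the inputs). Assumptions: the tensor names 'input:0' and 'output:0', the signature keys 'x' and 'y', and the helpers make_toy_checkpoint and export_with_signature are all hypothetical; the toy checkpoint exists only so the sketch runs end to end.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical tensor names: replace them with the real input/output
# tensors of your graph (e.g. found via loaded_graph.get_operations()).
INPUT_TENSOR = 'input:0'
OUTPUT_TENSOR = 'output:0'


def make_toy_checkpoint(ckpt_dir):
    """Hypothetical helper: write a tiny checkpoint computing y = x @ w."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name='input')
        w = tf.compat.v1.get_variable(
            'w', shape=[2, 1], initializer=tf.compat.v1.ones_initializer())
        tf.matmul(x, w, name='output')
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            # Returns the checkpoint prefix; a .meta file is written alongside
            return saver.save(sess, os.path.join(ckpt_dir, 'model.ckpt'))


def export_with_signature(trained_checkpoint_prefix, export_dir):
    """Hypothetical helper: restore a checkpoint, export it with a signature."""
    loaded_graph = tf.Graph()
    with tf.compat.v1.Session(graph=loaded_graph) as sess:
        # Restore from checkpoint, as in the gist
        loader = tf.compat.v1.train.import_meta_graph(
            trained_checkpoint_prefix + '.meta')
        loader.restore(sess, trained_checkpoint_prefix)

        # Re-attach the entry points that the checkpoint does not record
        inputs = {'x': loaded_graph.get_tensor_by_name(INPUT_TENSOR)}
        outputs = {'y': loaded_graph.get_tensor_by_name(OUTPUT_TENSOR)}
        signature = tf.compat.v1.saved_model.predict_signature_def(inputs, outputs)

        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            [tf.compat.v1.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
            },
            strip_default_attrs=True)
        builder.save()


if __name__ == '__main__':
    work_dir = tempfile.mkdtemp()
    prefix = make_toy_checkpoint(work_dir)
    export_with_signature(prefix, os.path.join(work_dir, 'models', '0'))
```

With a real model you would skip make_toy_checkpoint and call export_with_signature('model.ckpt', os.path.join('models', '0')) after editing the two tensor names. Only the SERVING tag is used here because the signature is meant for TensorFlow Serving; add TRAINING back if you also need it.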
