Last active November 4, 2024 16:26
GitHub Gist: zhanwenchen/d628ef70e9f76525fd47d6213c30730d
Minimal code to load a trained TensorFlow model from a checkpoint and export it with SavedModelBuilder
import os
import tensorflow as tf

trained_checkpoint_prefix = 'checkpoints/dev'
# IMPORTANT: TensorFlow Serving expects each model version in a numerically named folder ('0', '1', ...). Otherwise loading will fail!
export_dir = os.path.join('models', '0')

loaded_graph = tf.Graph()
with tf.Session(graph=loaded_graph) as sess:
    # Restore the graph structure from the .meta file, then the variable values from the checkpoint
    loader = tf.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)

    # Export the restored session to a SavedModel (export_dir must not already exist with content)
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING],
        strip_default_attrs=True)
builder.save()
I am getting "ValueError: At least two variables have the same name".
Do you have any idea how to solve this?
signature_def_map is lost this way.
Yes, and it is needed for serving. How can this parameter be kept?
Any resolution to this? @Litchilitchy
@Aroueterra You have to add the inputs and outputs yourself.
According to TensorFlow, the model must have an entry point. It would make sense for the saved model to already include this; with a Keras model, for example, we could just use model.inputs and model.outputs.
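A minimal sketch of adding the inputs and outputs by hand so that the export keeps a signature_def_map. It assumes TensorFlow 2.x with the TF1 graph APIs reached through tf.compat.v1, and that the graph contains tensors named input:0 and output:0. Those tensor names, the helper functions make_demo_checkpoint and export_with_signature, and the tiny y = x * w + b model are all hypothetical, so substitute the tensors from your own graph. Unlike the gist, this sketch tags the graph with SERVING only.

```python
import os
import tempfile

import tensorflow as tf

# TF1 graph/session APIs; under TensorFlow 2.x they are reached through tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def make_demo_checkpoint(prefix):
    """Hypothetical stand-in for training: saves a tiny y = x * w + b graph to `prefix`."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="input")
        w = tf1.Variable([[2.0]], name="w")
        b = tf1.Variable([1.0], name="b")
        tf.identity(tf.matmul(x, w) + b, name="output")
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, prefix)  # writes prefix.meta, prefix.index, prefix.data-*


def export_with_signature(prefix, export_dir, input_name="input:0", output_name="output:0"):
    """Restore a checkpoint and export it as a SavedModel that keeps a serving signature."""
    graph = tf.Graph()
    with tf1.Session(graph=graph) as sess:
        loader = tf1.train.import_meta_graph(prefix + ".meta")
        loader.restore(sess, prefix)

        # Assumed tensor names: look up the real ones in your own graph.
        x = graph.get_tensor_by_name(input_name)
        y = graph.get_tensor_by_name(output_name)
        signature = tf1.saved_model.signature_def_utils.predict_signature_def(
            inputs={"x": x}, outputs={"y": y})

        builder = tf1.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            [tf1.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
            },
            strip_default_attrs=True)
    builder.save()


if __name__ == "__main__":
    workdir = tempfile.mkdtemp()
    prefix = os.path.join(workdir, "dev")
    export_dir = os.path.join(workdir, "models", "1")  # numeric version folder
    make_demo_checkpoint(prefix)
    export_with_signature(prefix, export_dir)
    print(f"SavedModel written to {export_dir}")
```

The signature is registered under the default serving key ("serving_default"), which is the key a serving client looks up when it does not name a signature explicitly.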
Is it possible to get them directly from the .ckpt and .meta files?
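A partial answer, offered as a heuristic sketch: the .meta file stores the full graph, so the placeholder ops in it can at least be listed as candidate inputs. This assumes the model was fed through tf.placeholder; the helper name list_placeholder_names is hypothetical. Outputs cannot be recovered this way and still have to be chosen by hand.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def list_placeholder_names(meta_path):
    """Return the names of Placeholder ops stored in a .meta graph file.

    Placeholders are usually the graph's feed points, so they are good candidates
    for signature inputs; this is a heuristic, not a guarantee.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Importing inside graph.as_default() keeps this in graph mode even if eager is on.
        tf1.train.import_meta_graph(meta_path)
    return [op.name for op in graph.get_operations() if op.type == "Placeholder"]
```

For the gist's checkpoint, calling list_placeholder_names('checkpoints/dev.meta') would return the candidate input names; append ":0" to each to get the tensor name that get_tensor_by_name expects.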
Thanks!
Using compat for TensorFlow 2.0: