Last active June 12, 2017 08:00
Gist: mohapatras/9abafffb6bd2c42cb4645bf782758ed6
Found in a comment by user @tspthomas.
'''
I managed to export a Keras model for TensorFlow Serving (I'm not sure whether this is the official way to
do it). Before creating my custom model, I first tried a pretrained model available in
Keras, VGG19.
Here is how I did it (I split it into separate cells to make it easier to follow, and because I use Jupyter :)):
'''
# Creating the model
import keras.backend as K
from keras.applications import VGG19
from keras.models import Model
# Very important to do this first: learning phase 0 builds the graph in test (inference) mode
K.set_learning_phase(0)
model = VGG19(include_top=True, weights='imagenet')
# The creation of a new model might be optional depending on the goal
config = model.get_config()
weights = model.get_weights()
new_model = Model.from_config(config)
new_model.set_weights(weights)
# Exporting the model
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
export_path = 'folder_to_export'
builder = saved_model_builder.SavedModelBuilder(export_path)
# 'images' and 'scores' are the input/output keys the client must use
signature = predict_signature_def(inputs={'images': new_model.input},
                                  outputs={'scores': new_model.output})
with K.get_session() as sess:
    builder.add_meta_graph_and_variables(sess=sess,
                                         tags=[tag_constants.SERVING],
                                         signature_def_map={'predict': signature})
builder.save()
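TensorFlow Serving's model server looks for integer-named version subdirectories under its `--model_base_path` and, by default, serves the highest one, while `SavedModelBuilder` refuses to overwrite an existing export directory. So in practice `export_path` usually points at something like `folder_to_export/1` rather than at `folder_to_export` itself. Below is a minimal stdlib sketch of picking the next free version directory, assuming that layout; `next_version_dir` is a hypothetical helper, not part of TensorFlow.

```python
import os


def next_version_dir(base_path):
    """Return base_path/<N+1>, where N is the largest integer-named subdirectory.

    Hypothetical helper: assumes TensorFlow Serving's layout of one numbered
    subdirectory per model version under the base path.
    """
    versions = []
    if os.path.isdir(base_path):
        for name in os.listdir(base_path):
            # Only directories whose name is a plain decimal number count as versions.
            if name.isdecimal() and os.path.isdir(os.path.join(base_path, name)):
                versions.append(int(name))
    next_version = max(versions) + 1 if versions else 1
    return os.path.join(base_path, str(next_version))


if __name__ == "__main__":
    # On a fresh base path this is folder_to_export/1
    print(next_version_dir("folder_to_export"))
```

With this helper, `export_path = next_version_dir('folder_to_export')` replaces the fixed path, and the model server's `--model_base_path` points at `folder_to_export`.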
'''
Some side notes:
- The details may vary with the Keras, TensorFlow, and TensorFlow Serving versions. I used the latest ones at the time of writing.
- Pay attention to the signature name ('predict') and the input/output keys ('images', 'scores'), since the client must use exactly the same names.
- The client must run every preprocessing step the model needs (preprocess_input(), for example). I didn't try to add such a step to the graph itself, as the Inception client example does.
If you're curious about the client side, it should look similar to the code below. I added a few extra steps to use Keras methods for decoding the predictions, but that could also be done on the serving side:
'''
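import numpy as np
import tensorflow as tf
from keras.applications.vgg19 import decode_predictions
from tensorflow_serving.apis import predict_pb2
# Assumes `stub` is a TensorFlow Serving PredictionService gRPC stub and
# `img` is an image batch already passed through preprocess_input()
request = predict_pb2.PredictRequest()
request.model_spec.name = 'vgg19'
request.model_spec.signature_name = 'predict'  # must match the signature_def_map key
request.inputs['images'].CopyFrom(tf.contrib.util.make_tensor_proto(img, dtype=tf.float32))
result = stub.Predict(request, 10.0)  # 10-second timeout
# 'scores' is the output key defined in predict_signature_def
to_decode = np.expand_dims(result.outputs['scores'].float_val, axis=0)
decoded = decode_predictions(to_decode, 5)
print(decoded)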
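The client has to reproduce the model's preprocessing and, here, also decodes the scores itself. For VGG19, Keras' `preprocess_input` uses the "caffe" convention: it converts RGB to BGR and subtracts the ImageNet per-channel mean without any scaling, and `decode_predictions` keeps the top-k of the 1000 class scores. The stdlib sketch below mimics both steps on plain Python lists only to show what those helpers do; `caffe_preprocess`, `top_k` and the tiny label list are hypothetical illustrations, not the Keras implementation (which works on NumPy arrays and maps indices to ImageNet class names).

```python
# ImageNet channel means in BGR order, as used by the 'caffe' preprocessing convention.
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)


def caffe_preprocess(pixels_rgb):
    """Convert a list of (R, G, B) pixels to mean-centred (B, G, R) pixels."""
    out = []
    for r, g, b in pixels_rgb:
        bgr = (b, g, r)  # swap channel order RGB -> BGR
        out.append(tuple(c - m for c, m in zip(bgr, IMAGENET_MEAN_BGR)))
    return out


def top_k(scores, labels, k=5):
    """Return the k highest (label, score) pairs, best first.

    `scores` stands in for the float_val list of the 'scores' output for one image.
    """
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], scores[i]) for i in ranked[:k]]


if __name__ == "__main__":
    print(caffe_preprocess([(255, 0, 0)]))
    labels = ["cat", "dog", "car", "tree"]  # hypothetical class names
    print(top_k([0.1, 0.6, 0.05, 0.25], labels, k=2))
```

Skipping the preprocessing step on the client does not raise an error; the server simply receives differently scaled inputs and returns poor predictions, which is why the note above insists on it.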