Created
July 21, 2017 20:28
serving-input-reception-bug
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Mishandling of estimators with single, unnamed feature\n", | |
"\n", | |
"The [tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) interface allows users to provide a `model_fn` which accepts features either as a single tensor or as a dictionary mapping strings to tensors.\n", | |
"\n", | |
"The Estimator `export_savedmodel` method requires a `serving_input_receiver_fn` argument, which is a function of no arguments that produces a [ServingInputReceiver](https://www.tensorflow.org/api_docs/python/tf/estimator/export/ServingInputReceiver). The feature tensors from this `ServingInputReceiver` are passed to the `model_fn` for serving.\n", | |
"\n", | |
"Upon instantiation, the `ServingInputReceiver` wraps a single feature tensor in a dictionary. As a result, export raises an error for estimators whose `model_fn` expects a single tensor as its `features` argument, because the `model_fn` receives a dictionary instead." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Sample estimator" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def model_fn(features, labels, mode):\n", | |
"    # `features` is expected to be a single tensor, not a dictionary.\n", | |
"    output = tf.multiply(features, features)\n", | |
"\n", | |
"    prediction_output = tf.estimator.export.PredictOutput({'trololol': output})\n", | |
"\n", | |
"    # Dummy train_op and loss: this estimator only exists to exercise export.\n", | |
"    return tf.estimator.EstimatorSpec(\n", | |
"        mode=mode,\n", | |
"        predictions=features,\n", | |
"        train_op=tf.no_op(),\n", | |
"        loss=tf.constant(1, dtype=tf.float32),\n", | |
"        export_outputs={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_output}\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"MODEL_DIR = '/tmp/dummy-model'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [], | |
"source": [ | |
"dummy = tf.estimator.Estimator(model_fn, model_dir=MODEL_DIR)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dummy.train(lambda: (tf.constant([1, 2, 3, 4, 5], dtype=tf.float32), None), steps=1)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Export woes\n", | |
"\n", | |
"The following `serving_input_receiver_fn` uses `ServingInputReceiver`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def serving_input_receiver_fn():\n", | |
"    feature_tensor = tf.placeholder(tf.float32, [None, 1])\n", | |
"    return tf.estimator.export.ServingInputReceiver(feature_tensor, {'input': feature_tensor})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"BASE_EXPORT_DIR = '/tmp/dummy-model/export'" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dummy.export_savedmodel(BASE_EXPORT_DIR, serving_input_receiver_fn)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The export raises an error because the `model_fn` is handed a dictionary where it expects a tensor. There is no mechanism in the `model_fn` to handle this dictionary, and we can't expect there to be one because the user provides the implementation." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Fix\n", | |
"\n", | |
"Either we should change the estimator interface so that it only accepts a `model_fn` which takes its features in a dictionary, or we should provide a different type of `ServingInputReceiver` that passes a single feature tensor through unchanged. For example, something like this:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"class FlatServingInputReceiver(object):\n", | |
"    def __init__(self, feature):\n", | |
"        self.features = feature\n", | |
"        self.receiver_tensors = {'input': feature}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"def flat_serving_input_receiver_fn():\n", | |
"    feature_tensor = tf.placeholder(tf.float32, [None, 1])\n", | |
"    return FlatServingInputReceiver(feature_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"dummy.export_savedmodel(BASE_EXPORT_DIR, flat_serving_input_receiver_fn)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
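
Besides the two fixes proposed in the notebook, a user-side workaround might be to wrap the `model_fn` so that it unwraps a one-entry feature dictionary before calling the original function. The sketch below is only an illustration under stated assumptions: `unwrap_single_feature` and `square_model_fn` are hypothetical names that do not exist in TensorFlow, the `'feature'` key is a made-up example key, and the stand-in `model_fn` uses plain Python numbers so that the snippet runs without TensorFlow installed.

```python
import functools


def unwrap_single_feature(model_fn):
    """Hypothetical helper: let a single-tensor model_fn accept a one-entry dict."""

    @functools.wraps(model_fn)
    def wrapped(features, labels, mode):
        # At export time the features may arrive as a one-entry dictionary
        # instead of the bare tensor that the model_fn was written for.
        if isinstance(features, dict):
            if len(features) != 1:
                raise ValueError(
                    'Expected exactly one feature, got keys: %s' % sorted(features))
            # Pull out the only value, whatever key it was stored under.
            (features,) = features.values()
        return model_fn(features, labels, mode)

    return wrapped


# Stand-in for the notebook's model_fn (assumption: plain numbers, no TensorFlow).
def square_model_fn(features, labels, mode):
    return features * features


wrapped_model_fn = unwrap_single_feature(square_model_fn)
print(wrapped_model_fn(3.0, None, 'infer'))               # bare value
print(wrapped_model_fn({'feature': 3.0}, None, 'infer'))  # one-entry dict
```

With a real estimator, the wrapped function would be passed to `tf.estimator.Estimator` in place of the original `model_fn`. This only hides the mismatch on the user's side; it does not change how `ServingInputReceiver` treats single tensors.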