@zomglings
Created July 21, 2017 20:28
serving-input-reception-bug
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mishandling of estimators with single, unnamed feature\n",
"\n",
"The [tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator) interface allows users to provide a `model_fn` which accepts features either within a single tensor or within a dictionary mapping strings to tensors.\n",
"\n",
"The Estimator `export_savedmodel` method requires a `serving_input_receiver_fn` argument, which is a function of no arguments that produces a [ServingInputReceiver](https://www.tensorflow.org/api_docs/python/tf/estimator/export/ServingInputReceiver). The features tensors from this `ServingInputReceiver` are passed to the `model_fn` for serving.\n",
"\n",
"Upon instantiation, the `ServingInputReceiver` wraps single tensor features into a dictionary. This raises an error for estimators whose `model_fn` expects a single tensor as its `features` argument."
]
},
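{
"cell_type": "markdown",
"metadata": {},
"source": [
"The wrapping behaviour can be sketched without TensorFlow. The key name `'feature'` below is an assumption for illustration, not taken from the TensorFlow source; the point is only that a bare tensor becomes a single-entry dictionary before it reaches the `model_fn`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plain-Python sketch of the wrapping that ServingInputReceiver performs.\n",
"# The key name 'feature' is assumed for illustration only.\n",
"def wrap_features(features):\n",
"    if not isinstance(features, dict):\n",
"        features = {'feature': features}\n",
"    return features\n",
"\n",
"wrap_features(3.0)          # a bare value becomes {'feature': 3.0}\n",
"wrap_features({'x': 3.0})   # a dict passes through unchanged"
]
},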
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sample estimator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def model_fn(features, labels, mode):\n",
"    # `features` is expected to be a single tensor, not a dictionary.\n",
"    output = tf.multiply(features, features)\n",
"\n",
"    prediction_output = tf.estimator.export.PredictOutput({'trololol': output})\n",
"\n",
"    return tf.estimator.EstimatorSpec(\n",
"        mode=mode,\n",
"        predictions=output,\n",
"        train_op=tf.no_op(),\n",
"        loss=tf.constant(1, dtype=tf.float32),\n",
"        export_outputs={\n",
"            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_output\n",
"        }\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"MODEL_DIR = '/tmp/dummy-model'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"dummy = tf.estimator.Estimator(model_fn, model_dir=MODEL_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dummy.train(lambda: (tf.constant([1, 2, 3, 4, 5], dtype=tf.float32), None), steps=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Export woes\n",
"\n",
"The following `serving_input_receiver_fn` uses the stock `ServingInputReceiver`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def serving_input_receiver_fn():\n",
"    feature_tensor = tf.placeholder(tf.float32, [None, 1])\n",
"    return tf.estimator.export.ServingInputReceiver(feature_tensor, {'input': feature_tensor})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"BASE_EXPORT_DIR = '/tmp/dummy-model/export'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dummy.export_savedmodel(BASE_EXPORT_DIR, serving_input_receiver_fn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The export fails: the receiver wraps the single placeholder into a dictionary, but the `model_fn` above expects a bare tensor. There is no mechanism in the `model_fn` to handle this dictionary, and we cannot expect there to be one, because the user provides the implementation."
]
},
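{
"cell_type": "markdown",
"metadata": {},
"source": [
"The failure mode can be mimicked in plain Python: an operation written for a single value breaks when handed a dictionary instead. The `square` function below is a stand-in for the `tf.multiply(features, features)` call in the `model_fn`, not TensorFlow's actual behaviour."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Plain-Python stand-in for a model_fn that expects a single feature value.\n",
"def square(features):\n",
"    return features * features  # analogous to tf.multiply(features, features)\n",
"\n",
"square(3.0)  # works for a bare value\n",
"try:\n",
"    square({'feature': 3.0})  # the wrapped dict breaks the computation\n",
"except TypeError as err:\n",
"    print(err)"
]
},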
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fix\n",
"\n",
"Either we should change the estimator interface to only accept a `model_fn` which takes its features in a dictionary, OR we should provide a different type of `ServingInputReceiver`. For example, something like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class FlatServingInputReceiver(object):\n",
"    def __init__(self, feature):\n",
"        self.features = feature\n",
"        self.receiver_tensors = {'input': feature}"
]
},
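{
"cell_type": "markdown",
"metadata": {},
"source": [
"The class itself has no TensorFlow dependency, so its behaviour can be checked on a plain value standing in for a tensor: `features` stays unwrapped while `receiver_tensors` still exposes the named input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick check with a plain value in place of a tensor:\n",
"receiver = FlatServingInputReceiver(3.0)\n",
"print(receiver.features)          # still a bare value, not a dict\n",
"print(receiver.receiver_tensors)  # the named input mapping"
]
},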
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def flat_serving_input_receiver_fn():\n",
"    feature_tensor = tf.placeholder(tf.float32, [None, 1])\n",
"    return FlatServingInputReceiver(feature_tensor)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dummy.export_savedmodel(BASE_EXPORT_DIR, flat_serving_input_receiver_fn)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}