Skip to content

Instantly share code, notes, and snippets.

@hayatoy
Last active April 27, 2017 04:02
Show Gist options
  • Save hayatoy/5f0a8c978449f6227494fe3e08a7caab to your computer and use it in GitHub Desktop.
Save hayatoy/5f0a8c978449f6227494fe3e08a7caab to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### How to create a TensorFlow SavedModel for use with Cloud ML Engine Online Prediction\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Google Cloud project ID, used to build the GCS export path and the ML Engine resource names.\n",
"# NOTE(review): '{project_id}' looks like a placeholder to replace before running - confirm.\n",
"PROJECTID = '{project_id}'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model Definition\n",
"Define the model and train it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from __future__ import absolute_import\n",
"from __future__ import division\n",
"from __future__ import print_function\n",
"\n",
"from sklearn import datasets\n",
"from sklearn import metrics\n",
"# sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in 0.20;\n",
"# prefer model_selection and fall back to the old module only on older installs.\n",
"try:\n",
"    from sklearn.model_selection import train_test_split\n",
"except ImportError:\n",
"    from sklearn.cross_validation import train_test_split\n",
"\n",
"import tensorflow as tf\n",
"\n",
"# Load dataset.\n",
"iris = datasets.load_iris()\n",
"# Hold out 20% of the samples for evaluation; the fixed random_state keeps the split reproducible.\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    iris.data, iris.target, test_size=0.2, random_state=42)\n",
"\n",
"def feature_fn(features, labels):\n",
"    \"\"\"Feature-engineering hook: unwrap the 'inputs' entry of the feature dict.\n",
"\n",
"    Returns the pair (features['inputs'], labels); labels pass through untouched.\n",
"    \"\"\"\n",
"    inputs = features['inputs']\n",
"    return inputs, labels\n",
"\n",
"def get_train_inputs():\n",
"    \"\"\"Input function for training: the training split as constant tensors.\n",
"\n",
"    Returns a (features, labels) pair in which features is a dict holding the\n",
"    x_train tensor under the key 'inputs' and labels is the y_train tensor.\n",
"    \"\"\"\n",
"    features = {'inputs': tf.constant(x_train)}\n",
"    labels = tf.constant(y_train)\n",
"    return features, labels\n",
"\n",
"def get_test_inputs():\n",
"    \"\"\"Input function for evaluation: the held-out split as constant tensors.\n",
"\n",
"    Returns a (features, labels) pair in which features is a dict holding the\n",
"    x_test tensor under the key 'inputs' and labels is the y_test tensor.\n",
"    \"\"\"\n",
"    features = {'inputs': tf.constant(x_test)}\n",
"    labels = tf.constant(y_test)\n",
"    return features, labels\n",
"\n",
"# Three hidden layers (10, 20 and 10 units) over real-valued columns inferred from x_train.\n",
"feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(x_train)\n",
"classifier = tf.contrib.learn.DNNClassifier(\n",
"    hidden_units=[10, 20, 10],\n",
"    n_classes=3,\n",
"    feature_columns=feature_columns,\n",
"    feature_engineering_fn=feature_fn,\n",
"    model_dir='/var/tmp/iris')\n",
"\n",
"# Train for 200 steps, then score the predictions on the held-out split.\n",
"classifier.fit(input_fn=get_train_inputs, steps=200)\n",
"predicted_classes = classifier.predict(input_fn=get_test_inputs, as_iterable=True)\n",
"predictions = list(predicted_classes)\n",
"score = metrics.accuracy_score(y_test, predictions)\n",
"print('Accuracy: {0:f}'.format(score))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Export as SavedModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from tensorflow.contrib.learn.python.learn.utils import input_fn_utils\n",
"\n",
"def serving_input_fn():\n",
"    \"\"\"Serving-time input function passed to export_savedmodel.\n",
"\n",
"    Declares one float32 placeholder of shape [None, 4] under the key 'inputs'\n",
"    and hands InputFnOps a copy of that dict as the features, None as the\n",
"    second argument, and the placeholder dict itself as the third.\n",
"    \"\"\"\n",
"    placeholders = {'inputs': tf.placeholder(tf.float32, [None, 4])}\n",
"    # No preprocessing happens at serving time, so the features are a plain copy.\n",
"    features = dict(placeholders)\n",
"    return input_fn_utils.InputFnOps(features, None, placeholders)\n",
"\n",
"# Export under gs://<PROJECTID>-ml; the returned path is reused as the version's deploymentUri.\n",
"export_dir_base = 'gs://%s-ml/model/iris-saved01' % PROJECTID\n",
"gsfilepath = classifier.export_savedmodel(export_dir_base, serving_input_fn)\n",
"gsfilepath"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Register Model to ML Engine"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from oauth2client.client import GoogleCredentials\n",
"from googleapiclient import discovery\n",
"from googleapiclient import errors\n",
"\n",
"\n",
"# Resource names built from PROJECTID: 'projects/<id>' and 'projects/<id>/models/<name>'.\n",
"projectID = 'projects/{}'.format(PROJECTID)\n",
"modelName = 'irismodel01'\n",
"modelID = '{}/models/{}'.format(projectID, modelName)\n",
"\n",
"# Build a client for the 'ml' v1 API using the application default credentials.\n",
"credentials = GoogleCredentials.get_application_default()\n",
"ml = discovery.build('ml', 'v1', credentials=credentials)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Create Model**"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"\n",
"# Body of the model resource to create.\n",
"requestDict = {'name': modelName,\n",
"               'description': 'DNNClassifier for iris',\n",
"               }\n",
"\n",
"# Create a request to call projects.models.create.\n",
"request = ml.projects().models().create(\n",
"    parent=projectID, body=requestDict)\n",
"\n",
"# Make the call. Reset response first so a failed call cannot leave a stale value behind.\n",
"response = None\n",
"try:\n",
"    response = request.execute()\n",
"except errors.HttpError as err:\n",
"    # Report why the API rejected the request instead of hiding the cause.\n",
"    print('There was an HTTP error during the request:')\n",
"    print(err._get_reason())\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Create Version**"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Version 'v1' is deployed from the SavedModel path returned by export_savedmodel.\n",
"requestDict = {'name': 'v1',\n",
"               'deploymentUri': gsfilepath}\n",
"\n",
"# Create a request to call projects.models.versions.create.\n",
"request = ml.projects().models().versions().create(\n",
"    parent=modelID, body=requestDict)\n",
"\n",
"# Make the call. Reset response first so a failed call cannot leave a stale value behind.\n",
"response = None\n",
"try:\n",
"    response = request.execute()\n",
"except errors.HttpError as err:\n",
"    # Report why the API rejected the request instead of hiding the cause.\n",
"    print('There was an HTTP error during the request:')\n",
"    print(err._get_reason())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Wait until the deployment has finished before requesting predictions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Online Prediction"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{u'predictions': [{u'classes': 1,\n",
" u'scores': [0.00043015286792069674,\n",
" 0.9724891781806946,\n",
" 0.027080722153186798]},\n",
" {u'classes': 0,\n",
" u'scores': [0.9987931251525879,\n",
" 0.0012068357318639755,\n",
" 1.5598407981931572e-11]}]}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Two instances of four feature values each, keyed by 'inputs' like the serving placeholder.\n",
"request_body = {'instances':[{'inputs':[6.1, 2.8, 4.7, 1.2]},\n",
"                             {'inputs':[5.7, 3.8, 1.7, 0.3]}]}\n",
"\n",
"request = ml.projects().predict(name=modelID, body=request_body)\n",
"# Reset first so a failed call displays nothing rather than the stale\n",
"# response left over from an earlier cell.\n",
"response = None\n",
"try:\n",
"    response = request.execute()\n",
"except errors.HttpError as err:\n",
"    # Something went wrong with the HTTP transaction.\n",
"    print('There was an HTTP error during the request:')\n",
"    print(err._get_reason())\n",
"response"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment