Last active
April 27, 2017 04:02
-
-
Save hayatoy/5f0a8c978449f6227494fe3e08a7caab to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### How to create a TensorFlow SavedModel for use with Cloud ML Engine Online Prediction\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"PROJECTID = '{project_id}'" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Model Definition\n", | |
"Define the model and train it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from __future__ import absolute_import\n", | |
"from __future__ import division\n", | |
"from __future__ import print_function\n", | |
"\n", | |
"from sklearn import datasets\n", | |
"from sklearn import metrics\n", | |
"from sklearn import cross_validation\n", | |
"\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"# Load dataset.\n", | |
"iris = datasets.load_iris()\n", | |
"x_train, x_test, y_train, y_test = cross_validation.train_test_split(\n", | |
" iris.data, iris.target, test_size=0.2, random_state=42)\n", | |
"\n", | |
"def feature_fn(features, labels):\n", | |
" return features['inputs'], labels\n", | |
"\n", | |
"def get_train_inputs():\n", | |
" x = tf.constant(x_train)\n", | |
" y = tf.constant(y_train)\n", | |
" return {'inputs':x}, y\n", | |
"\n", | |
"def get_test_inputs():\n", | |
" x = tf.constant(x_test)\n", | |
" y = tf.constant(y_test)\n", | |
" return {'inputs':x}, y\n", | |
"\n", | |
"# Build 3 layer DNN with 10, 20, 10 units respectively.\n", | |
"feature_columns = tf.contrib.learn.infer_real_valued_columns_from_input(\n", | |
" x_train)\n", | |
"classifier = tf.contrib.learn.DNNClassifier(\n", | |
" feature_engineering_fn=feature_fn,\n", | |
" feature_columns=feature_columns,\n", | |
" hidden_units=[10, 20, 10],\n", | |
" n_classes=3,\n", | |
" model_dir='/var/tmp/iris')\n", | |
"\n", | |
"# Fit and predict.\n", | |
"classifier.fit(input_fn=get_train_inputs, steps=200)\n", | |
"predictions = list(classifier.predict(input_fn=get_test_inputs, as_iterable=True))\n", | |
"score = metrics.accuracy_score(y_test, predictions)\n", | |
"print('Accuracy: {0:f}'.format(score))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Export as SavedModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"from tensorflow.contrib.learn.python.learn.utils import input_fn_utils\n", | |
"\n", | |
"def serving_input_fn():\n", | |
" feature_placeholders = {'inputs': tf.placeholder(tf.float32, [None, 4])}\n", | |
" features = {\n", | |
" key: tensor\n", | |
" for key, tensor in feature_placeholders.items()\n", | |
" } \n", | |
" return input_fn_utils.InputFnOps(\n", | |
" features,\n", | |
" None,\n", | |
" feature_placeholders\n", | |
" )\n", | |
"\n", | |
"gsfilepath = classifier.export_savedmodel('gs://%s-ml/model/iris-saved01' % PROJECTID,\n", | |
" serving_input_fn)\n", | |
"gsfilepath" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Register the Model with ML Engine" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from oauth2client.client import GoogleCredentials\n", | |
"from googleapiclient import discovery\n", | |
"from googleapiclient import errors\n", | |
"\n", | |
"\n", | |
"projectID = 'projects/{}'.format(PROJECTID)\n", | |
"modelName = 'irismodel01'\n", | |
"modelID = '{}/models/{}'.format(projectID, modelName)\n", | |
"\n", | |
"credentials = GoogleCredentials.get_application_default()\n", | |
"ml = discovery.build('ml', 'v1', credentials=credentials)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Create Model**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"requestDict = {'name': modelName,\n", | |
" 'description': 'DNNClassifier for iris',\n", | |
" }\n", | |
"\n", | |
"# Create a request to call projects.models.create.\n", | |
"request = ml.projects().models().create(\n", | |
" parent=projectID, body=requestDict)\n", | |
"\n", | |
"# Make the call.\n", | |
"try:\n", | |
" response = request.execute()\n", | |
"except:\n", | |
" print('wow, error')\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Create Version**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"requestDict = {'name': 'v1',\n", | |
" 'deploymentUri': gsfilepath}\n", | |
"\n", | |
"request = ml.projects().models().versions().create(\n", | |
" parent=modelID, body=requestDict)\n", | |
"# Make the call.\n", | |
"try:\n", | |
" response = request.execute()\n", | |
"except:\n", | |
" print('wow, error')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Wait until the deployment has finished before sending prediction requests." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Online Prediction" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{u'predictions': [{u'classes': 1,\n", | |
" u'scores': [0.00043015286792069674,\n", | |
" 0.9724891781806946,\n", | |
" 0.027080722153186798]},\n", | |
" {u'classes': 0,\n", | |
" u'scores': [0.9987931251525879,\n", | |
" 0.0012068357318639755,\n", | |
" 1.5598407981931572e-11]}]}" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"request_body = {'instances':[{'inputs':[6.1, 2.8, 4.7, 1.2]},\n", | |
" {'inputs':[5.7, 3.8, 1.7, 0.3]}]}\n", | |
"\n", | |
"request = ml.projects().predict(name=modelID, body=request_body)\n", | |
"try:\n", | |
" response = request.execute()\n", | |
"except errors.HttpError as err:\n", | |
" # Something went wrong with the HTTP transaction.\n", | |
" # To use logging, you need to 'import logging'.\n", | |
" print('There was an HTTP error during the request:')\n", | |
" print(err._get_reason())\n", | |
"response" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment