Created
March 11, 2020 05:41
-
-
Save russau/b261d6f70959bd45ab01c59bd983f35e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except | |
# in compliance with the License. A copy of the License is located at | |
# | |
# https://aws.amazon.com/apache-2-0/ | |
# | |
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations under the License. | |
"Serverless front end for factorization machines" | |
# https://aws.amazon.com/blogs/machine-learning/build-a-movie-recommender-with-factorization-machines-on-amazon-sagemaker/ | |
import os | |
import sqlite3 | |
import json | |
import random | |
import boto3 | |
from jinja2 import Environment | |
from jinja2 import FileSystemLoader | |
from chalice import Chalice, Response, NotFoundError, ChaliceViewError | |
# Chalice application object; the @app.route decorators in this file register
# their handlers on it.
app = Chalice(app_name='movie-database')
# Debug mode is switched on unconditionally.
# NOTE(review): presumably intended for development only -- confirm before
# deploying, since debug mode usually exposes more error detail to clients.
app.debug = True
def _render_template(**kwargs):
    """Render ``chalicelib/main.html`` and return the resulting text.

    The keyword arguments are handed to the template as its context. The
    template is looked up relative to the directory containing this module.
    """
    module_dir = os.path.abspath(os.path.dirname(__file__))
    jinja_env = Environment(loader=FileSystemLoader(module_dir))
    return jinja_env.get_template('chalicelib/main.html').render(kwargs)
@app.route('/')
def index():
    """Homepage: answer with a 302 redirect to the page of user 1."""
    redirect_headers = {'Location': './1'}
    return Response(body='', status_code=302, headers=redirect_headers)
@app.route('/random_user')
def random_user():
    """Redirect to the page of a randomly chosen user.

    Returns:
        A 302 ``Response`` whose ``Location`` header is ``./<userId>``.

    Raises:
        NotFoundError: if the user table contains no rows.
    """
    database = get_db()
    try:
        # Only the ids are needed to choose a redirect target.
        users = query_db(database, "select userId from user")
    finally:
        # Release the connection even when the query fails.
        close_connection(database)

    if not users:
        # random.choice() would raise IndexError on an empty list.
        raise NotFoundError("User not found")

    user_id = random.choice(users)["userId"]
    return Response(
        status_code=302,
        body='',
        headers={'Location': './%s' % user_id})
@app.route('/{user_id}')
def user_display(user_id):
    """Render the HTML page showing a user and the movies they rated.

    Args:
        user_id: user id taken from the URL path.

    Returns:
        A 200 ``text/html`` ``Response`` with the rendered template. The
        template receives ``user`` (``None`` when no row matches),
        ``ratings`` (ordered by movie title) and ``error_message`` (set when
        ``SAGEMAKER_ENDPOINT`` is missing, empty or still a placeholder).

    Raises:
        NotFoundError: if ``user_id`` is not alphanumeric.
    """
    if not user_id.isalnum():
        raise NotFoundError("User not found")

    # A missing, empty or placeholder endpoint is reported on the page rather
    # than failing the request: the ratings can still be displayed.
    endpoint = os.environ.get("SAGEMAKER_ENDPOINT", "")
    error_message = None
    if not endpoint or "REPLACE_WITH" in endpoint:
        error_message = "This application hasn't been configured with a SAGEMAKER_ENDPOINT. Follow the instructions to add the SAGEMAKER_ENDPOINT environment variable."

    database = get_db()
    try:
        user = query_db(database, "select * from user where userId = ?", (user_id,), one=True)
        ratings = query_db(database, """select b.rating, i.movieId, i.movieTitle
                                     from base b join item i on b.movieId = i.movieId
                                     where userId = ? order by i.movieTitle""",
                           (user_id,))
    finally:
        # Release the connection even when one of the queries fails.
        close_connection(database)

    return Response(_render_template(user=user, ratings=ratings, error_message=error_message),
                    status_code=200,
                    headers={'Content-Type': 'text/html'})
@app.route('/inference/{user_id}')
def inference(user_id):
    """Return the top movie predictions for a user.

    One one-hot encoded row per movie (a 1 in the user's column and a 1 in
    that movie's column) is sent to the SageMaker endpoint in a single
    recordio-protobuf request. Predictions with ``predicted_label == 1`` are
    ranked by ``score`` and the best ten are returned.

    Args:
        user_id: user id taken from the URL path.

    Returns:
        A list of at most ten prediction dicts as returned by the endpoint,
        each extended with ``movieId`` and ``movieTitle`` (``None`` when the
        title cannot be found), sorted by descending ``score``.

    Raises:
        NotFoundError: if ``user_id`` is not a number in ``1..user_count``.
        ChaliceViewError: if ``SAGEMAKER_ENDPOINT`` is missing, empty or
            still a placeholder.
    """
    # Imported here rather than at module level: only this route uses them.
    from scipy.sparse import lil_matrix
    import sagemaker.amazon.common as smac
    import io

    if not user_id.isalnum():
        raise NotFoundError("Not found")
    try:
        # isalnum() also accepts letters, which int() rejects.
        user_number = int(user_id)
    except ValueError:
        raise NotFoundError("Not found") from None

    endpoint = os.environ.get("SAGEMAKER_ENDPOINT", "")
    if not endpoint or "REPLACE_WITH" in endpoint:
        raise ChaliceViewError("No SAGEMAKER_ENDPOINT configured")

    database = get_db()
    try:
        user_count = query_db(database, 'select count(*) as c from user;', one=True)['c']
        movie_count = query_db(database, 'select count(*) as c from item;', one=True)['c']

        # The one-hot layout below uses column (user id - 1) for the user, so
        # ids outside 1..user_count would wrap around to the last column or
        # land in a movie column.
        # NOTE(review): assumes user ids are contiguous from 1 -- confirm.
        if not 1 <= user_number <= user_count:
            raise NotFoundError("Not found")

        feature_count = user_count + movie_count
        user_column = user_number - 1

        # Row r encodes the pair (this user, movie r + 1).
        sparse_matrix = lil_matrix((movie_count, feature_count), dtype='float32')
        for row in range(movie_count):
            sparse_matrix[row, user_column] = 1
            sparse_matrix[row, user_count + row] = 1

        buf = io.BytesIO()
        smac.write_spmatrix_to_sparse_tensor(buf, sparse_matrix)
        buf.seek(0)

        client = boto3.client('sagemaker-runtime')
        response = client.invoke_endpoint(
            EndpointName=endpoint,
            Body=buf,
            ContentType="application/x-recordio-protobuf"
        )
        inferences = json.loads(response['Body'].read())

        # Predictions come back in request order, so position i maps to
        # movie id i + 1.
        predictions = inferences['predictions']
        for i, prediction in enumerate(predictions):
            prediction['movieId'] = i + 1

        # we only want the predicted_label = 1 results
        positive_predictions = [p for p in predictions if p['predicted_label'] == 1]

        # sort by score and grab the top ten
        top_ten = sorted(positive_predictions, key=lambda m: m['score'], reverse=True)[:10]
        for movie in top_ten:
            movie_info = query_db(database, """select movieId, movieTitle
                                            from item where movieId = ?""",
                                  (movie['movieId'],), one=True)
            # Do not fail the whole request when a title row is missing.
            movie['movieTitle'] = movie_info['movieTitle'] if movie_info else None
        return top_ten
    finally:
        # Release the connection on every exit path, including errors.
        close_connection(database)
####### | |
# DATABASE HELPERS | |
####### | |
def make_dicts(cursor, row):
    """Row factory: turn a result row into a dict keyed by column name.

    Column names are read from the first element of each entry in
    ``cursor.description``, matched to the row values by position.
    """
    record = {}
    for position, value in enumerate(row):
        column_name = cursor.description[position][0]
        record[column_name] = value
    return record
def get_db():
    """Open a connection to the bundled SQLite movie database.

    The database file is located relative to this module instead of the
    current working directory, so the same file is opened regardless of where
    the process was started. Rows are returned as dicts because
    ``make_dicts`` is installed as the row factory.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    module_dir = os.path.abspath(os.path.dirname(__file__))
    database = sqlite3.connect(os.path.join(module_dir, 'chalicelib', 'movies.db'))
    database.row_factory = make_dicts
    return database
def close_connection(database):
    """Close ``database`` if one was given; ``None`` is silently ignored."""
    if database is None:
        return
    database.close()
def query_db(database, query, args=(), one=False):
    """Run a query and return its result rows.

    Args:
        database: open connection offering ``execute(query, args)``.
        query: SQL text, using ``?`` placeholders for parameters.
        args: parameter values bound to the placeholders.
        one: when true, return only the first row.

    Returns:
        A list of all rows, or -- when ``one`` is true -- the first row, or
        ``None`` if the query produced no rows.
    """
    cur = database.execute(query, args)
    try:
        if one:
            # fetchone() yields None on an empty result and avoids loading
            # every row just to keep the first.
            return cur.fetchone()
        return cur.fetchall()
    finally:
        # Close the cursor even when fetching raises.
        cur.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment