Skip to content

Instantly share code, notes, and snippets.

@russau
Created March 11, 2020 05:41
Show Gist options
  • Save russau/b261d6f70959bd45ab01c59bd983f35e to your computer and use it in GitHub Desktop.
Save russau/b261d6f70959bd45ab01c59bd983f35e to your computer and use it in GitHub Desktop.
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except
# in compliance with the License. A copy of the License is located at
#
# https://aws.amazon.com/apache-2-0/
#
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
"Serverless front end for factorization machines"
# https://aws.amazon.com/blogs/machine-learning/build-a-movie-recommender-with-factorization-machines-on-amazon-sagemaker/
import os
import sqlite3
import json
import random
import boto3
from jinja2 import Environment
from jinja2 import FileSystemLoader
from chalice import Chalice, Response, NotFoundError, ChaliceViewError
# Chalice application object; the @app.route decorators below register views on it.
app = Chalice(app_name='movie-database')
# Debug mode is enabled unconditionally.
# NOTE(review): debug mode presumably exposes extra error detail to clients --
# confirm this is intended outside of development.
app.debug = True
def _render_template(**kwargs):
    """Render ``chalicelib/main.html`` with the given template variables.

    Args:
        **kwargs: Names made available to the Jinja template.

    Returns:
        The rendered template as a string.
    """
    # Templates are looked up relative to this module's directory rather than
    # the process working directory.
    base_dir = os.path.abspath(os.path.dirname(__file__))
    # NOTE(review): the Environment is created without autoescape; values that
    # come from the database are therefore emitted as-is unless the template
    # escapes them itself -- confirm against chalicelib/main.html.
    env = Environment(loader=FileSystemLoader(base_dir))
    template = env.get_template('chalicelib/main.html')
    return template.render(kwargs)
@app.route('/')
def index():
    """Redirect the homepage to the page of user 1.

    Returns:
        A 302 ``Response`` whose ``Location`` header is the relative URL ``./1``.
    """
    return Response(
        status_code=302,
        body='',
        headers={'Location': './1'},
    )
@app.route('/random_user')
def random_user():
    """Redirect to the page of a randomly chosen user.

    Returns:
        A 302 ``Response`` whose ``Location`` header is ``./<userId>``.

    Raises:
        NotFoundError: If the user table contains no rows.
    """
    database = get_db()
    try:
        # Only the id is needed to build the redirect target.
        users = query_db(database, "select userId from user")
    finally:
        # Release the connection even when the query fails.
        close_connection(database)
    if not users:
        # random.choice() raises IndexError on an empty sequence; report a
        # clean 404 instead.
        raise NotFoundError("User not found")
    user_id = random.choice(users)["userId"]
    return Response(
        status_code=302,
        body='',
        headers={'Location': './%s' % user_id},
    )
@app.route('/{user_id}')
def user_display(user_id):
    """Render the HTML page for one user together with that user's ratings.

    Args:
        user_id: The ``{user_id}`` path segment, as a string.

    Returns:
        A 200 ``Response`` with ``Content-Type: text/html`` containing the
        rendered template. The template receives ``user``, ``ratings`` and
        ``error_message`` (set when SAGEMAKER_ENDPOINT is not configured).

    Raises:
        NotFoundError: If ``user_id`` is not alphanumeric or no such user
            exists in the database.
    """
    if not user_id.isalnum():
        raise NotFoundError("User not found")

    error_message = None
    endpoint = os.environ.get("SAGEMAKER_ENDPOINT")
    # A value still containing "REPLACE_WITH" is treated as unconfigured.
    if endpoint is None or "REPLACE_WITH" in endpoint:
        error_message = "This application hasn't been configured with a SAGEMAKER_ENDPOINT. Follow the instructions to add the SAGEMAKER_ENDPOINT environment variable."

    database = get_db()
    try:
        user = query_db(database, "select * from user where userId = ?", (user_id,), one=True)
        if user is None:
            # Well-formed id that matches no row: answer the same way as for a
            # malformed id instead of rendering a page for a missing user.
            raise NotFoundError("User not found")
        ratings = query_db(database, """select b.rating, i.movieId, i.movieTitle
            from base b join item i on b.movieId = i.movieId
            where userId = ? order by i.movieTitle""",
                           (user_id,))
    finally:
        # Always release the connection, including when a query raises.
        close_connection(database)

    return Response(_render_template(user=user, ratings=ratings, error_message=error_message),
                    status_code=200,
                    headers={'Content-Type': 'text/html'})
@app.route('/inference/{user_id}')
def inference(user_id):
    """Return the ten highest-scoring positive movie predictions for a user.

    One request row is built per movie. Each row has ``user_count +
    movie_count`` columns and is one-hot encoded: column ``user_id - 1`` marks
    the user and column ``user_count + row`` marks the movie. All rows are sent
    to the SageMaker endpoint in a single recordio-protobuf request.

    Args:
        user_id: The ``{user_id}`` path segment, as a string.

    Returns:
        A list of at most ten prediction dicts as returned by the endpoint,
        each extended with ``movieId`` and ``movieTitle``, sorted by ``score``
        in descending order. Only predictions whose ``predicted_label`` equals
        1 are considered.

    Raises:
        NotFoundError: If ``user_id`` is not an integer in
            ``1..user_count``.
        ChaliceViewError: If SAGEMAKER_ENDPOINT is missing or still contains
            the "REPLACE_WITH" placeholder.
    """
    if not user_id.isalnum():
        raise NotFoundError("Not found")

    endpoint = os.environ.get("SAGEMAKER_ENDPOINT")
    if endpoint is None or "REPLACE_WITH" in endpoint:
        raise ChaliceViewError("No SAGEMAKER_ENDPOINT configured")

    # isalnum() also accepts letters; reject anything that is not an integer
    # here instead of failing later with an unhandled ValueError.
    try:
        user_index = int(user_id) - 1
    except ValueError:
        raise NotFoundError("Not found")

    # Imported inside the view, as the original code did, so that the other
    # routes do not need these packages to be loaded.
    import io
    from scipy.sparse import lil_matrix
    import sagemaker.amazon.common as smac

    database = get_db()
    try:
        user_count = query_db(database, 'select count(*) as c from user;', one=True)['c']
        movie_count = query_db(database, 'select count(*) as c from item;', one=True)['c']

        # An index outside the user columns would silently encode the wrong
        # feature (a negative index wraps around, a large one lands in the
        # movie columns), so treat it as an unknown user.
        if not 0 <= user_index < user_count:
            raise NotFoundError("Not found")

        feature_count = user_count + movie_count

        # One row per movie, each pairing this user with that movie.
        sparse_matrix = lil_matrix((movie_count, feature_count), dtype='float32')
        for row in range(movie_count):
            sparse_matrix[row, user_index] = 1
            sparse_matrix[row, user_count + row] = 1

        buf = io.BytesIO()
        smac.write_spmatrix_to_sparse_tensor(buf, sparse_matrix)
        buf.seek(0)

        client = boto3.client('sagemaker-runtime')
        response = client.invoke_endpoint(
            EndpointName=endpoint,
            Body=buf,
            ContentType="application/x-recordio-protobuf"
        )
        inferences = json.loads(response['Body'].read())
        predictions = inferences['predictions']

        # Predictions come back in request order, so position i corresponds to
        # matrix row i.
        # NOTE(review): this assumes movieIds are exactly 1..movie_count with
        # no gaps -- confirm against the item table.
        for movie_id, prediction in enumerate(predictions, start=1):
            prediction['movieId'] = movie_id

        # Keep only the predicted_label == 1 results, then take the ten with
        # the highest score.
        positive_predictions = [p for p in predictions if p['predicted_label'] == 1]
        top_ten = sorted(positive_predictions, key=lambda m: m['score'], reverse=True)[:10]

        for movie in top_ten:
            movie_info = query_db(database, """select movieId, movieTitle
                from item where movieId = ?""", (movie['movieId'],), one=True)
            # A missing row yields a null title instead of an IndexError.
            movie['movieTitle'] = movie_info['movieTitle'] if movie_info else None
    finally:
        # Always release the connection, including on endpoint or query errors.
        close_connection(database)

    return top_ten
#######
# DATABASE HELPERS
#######
def make_dicts(cursor, row):
    """Row factory that turns a result row into a dict keyed by column name.

    Args:
        cursor: The cursor that produced ``row``; the first element of each
            entry in its ``description`` is used as the key.
        row: The sequence of column values for one result row.

    Returns:
        A dict mapping each column name to the corresponding value in ``row``.
    """
    return {column[0]: value for column, value in zip(cursor.description, row)}
def get_db():
    """Open a connection to the bundled ``chalicelib/movies.db`` database.

    Returns:
        A ``sqlite3.Connection`` whose rows are returned as dicts (see
        ``make_dicts``). The caller is responsible for closing it.
    """
    # Resolve the database next to this module, the same way the template
    # directory is resolved, so the lookup does not depend on the process
    # working directory.
    base_dir = os.path.abspath(os.path.dirname(__file__))
    database = sqlite3.connect(os.path.join(base_dir, 'chalicelib', 'movies.db'))
    database.row_factory = make_dicts
    return database
def close_connection(database):
    """Close a database connection if one was opened.

    Args:
        database: The connection to close, or ``None`` (in which case nothing
            happens).
    """
    if database is not None:
        database.close()
def query_db(database, query, args=(), one=False):
    """Run a query and return its result rows.

    Args:
        database: An open database connection.
        query: The SQL text, using ``?`` placeholders for parameters.
        args: The parameter values bound to the placeholders.
        one: When true, return only the first row.

    Returns:
        A list of all rows when ``one`` is false. When ``one`` is true, the
        first row, or ``None`` if the query produced no rows.
    """
    cur = database.execute(query, args)
    try:
        if one:
            # fetchone() already returns None when there is no row, and avoids
            # materialising the rest of the result set.
            return cur.fetchone()
        return cur.fetchall()
    finally:
        # Close the cursor even if fetching raises.
        cur.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment