Created
December 5, 2020 00:36
-
-
Save mfitton/f4ef50dee03ed8d318443734da70ff24 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid | |
import ray | |
from ray import serve | |
from util import ImpressionStore, choose_ensemble_results | |
@ray.remote
class ComposedModel:
    """Ray actor that ensembles the "random" and "plot" Serve models.

    It queries both underlying model endpoints, mixes their predictions
    according to the session's per-model like counts kept by an
    ``ImpressionStore`` actor, and caches the latest mix so a UI can poll it
    through ``get_recommendation``.
    """

    def __init__(self):
        # Get handles to the two underlying models.
        client = serve.connect()
        self.random_handle = client.get_handle("random")
        self.plot_handle = client.get_handle("plot")
        # Store user click data in a detached actor.
        self.impressions = ImpressionStore.options(
            lifetime="detached", name="impressions").remote()
        # (distribution, impressions, chosen) from the most recent call to
        # record_liked_id, or None until it has run once.
        self.last_impression = None

    def record_liked_id(self, liked_id):
        """Record that the user liked ``liked_id`` and build a new ensemble.

        Fetches fresh predictions from both models, asks the impression store
        to credit the like and return the session's like distribution, mixes
        the predictions accordingly, caches the result in
        ``last_impression`` and records the newly chosen movies as
        impressions.

        Args:
            liked_id: id of the movie the user picked; passed to the "plot"
                model and to the impression store.
        """
        # In reality, we'd want to use something like a session key to
        # differentiate between users. Here, we'll always use the same one
        # for simplicity.
        session_key = "abc123"

        # Submit both model calls before blocking on either, so they run
        # concurrently instead of back to back.
        random_ref = self.random_handle.remote()
        plot_ref = self.plot_handle.remote(liked_id=liked_id)
        random_preds, plot_preds = ray.get([random_ref, plot_ref])
        results = {"random": random_preds, "plot": plot_preds}

        # Credit the like and get the current model distribution.
        model_distribution = ray.get(self.impressions.model_distribution.remote(
            session_key, liked_id))

        # Select which results to send to the user based on their clicks.
        distribution, impressions, chosen = choose_ensemble_results(
            model_distribution, results)
        self.last_impression = (distribution, impressions, chosen)

        # Record these recommendations as impressions.
        ray.get(self.impressions.record_impressions.remote(
            session_key, impressions))

    def get_recommendation(self):
        """Return the latest recommendations for display.

        Returns:
            A dict with ``"dist"`` (the normalized model distribution, or
            None before any like has been recorded) and ``"recs"`` (the
            chosen movies, or the "random" model's output as a cold-start
            fallback).
        """
        if self.last_impression is None:
            # Cold start: nothing recorded yet, so serve random picks.
            random_rec = ray.get(self.random_handle.remote())
            return {
                "dist": None,
                "recs": random_rec,
            }
        dist, _, chosen = self.last_impression
        return {
            "dist": dist,
            "recs": chosen,
        }
if __name__ == "__main__":
    import os

    # Deploy the ensemble endpoint.
    try:
        ray.init(address="auto")
    except Exception as e:
        # Catch ordinary errors only (a bare ``except`` would also trap
        # KeyboardInterrupt/SystemExit) and chain the cause so the real
        # connection error stays visible in the traceback.
        raise Exception("Failed to connect to Ray Serve. Did you forget to run setup.py first?") from e

    # Assumes the Streamlit UI script sits next to this file (it previously
    # pointed at a user-specific absolute path) — confirm for your layout.
    streamlit_script_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "recommender_streamlit.py")

    # The Streamlit app looks this actor up by name from another process.
    # A non-detached actor is fate-shared with this driver and would be
    # terminated as soon as this script exits, so create it detached.
    ComposedModel.options(
        streamlit_script_path=streamlit_script_path,
        name="composed_model",
        lifetime="detached").remote()

    # NOTE(review): nothing in this script creates a Serve HTTP route, so the
    # /rec/ensemble URL below looks stale — confirm before relying on it.
    print("Deployed ensemble recommender to /rec/ensemble.")
    print("Try it out with: 'curl \"http://localhost:8000/rec/ensemble?liked_id=322259\"'")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
import time | |
from datetime import datetime, timedelta | |
from util import get_db_connection | |
from PIL import Image | |
import ray | |
def retry_until_success(f, timeout=15, interval=1):
    """Call ``f`` repeatedly until it succeeds or ``timeout`` seconds pass.

    Args:
        f: zero-argument callable to invoke.
        timeout: seconds, measured from the first attempt, during which
            failures are retried.
        interval: seconds to sleep between failed attempts.

    Returns:
        Whatever ``f`` returns on its first successful call.

    Raises:
        The last exception raised by ``f`` once the deadline has passed.
    """
    # time.monotonic is unaffected by wall-clock adjustments, unlike
    # datetime.now(), so the timeout cannot be stretched or cut short.
    deadline = time.monotonic() + timeout
    while True:
        try:
            return f()
        except Exception:
            if time.monotonic() >= deadline:
                raise
            time.sleep(interval)
def rerun():
    """Trigger a Streamlit script rerun by raising its internal RerunException.

    Uses Streamlit-internal modules (``script_runner`` and
    ``script_request_queue``) rather than a public API, so this may break
    across Streamlit versions.
    """
    exc_cls = st.script_runner.RerunException
    rerun_data = st.script_request_queue.RerunData(None)
    raise exc_cls(rerun_data)
if __name__ == "__main__":
    # Draw a title explicitly instead of relying on Streamlit "magic".
    st.markdown("# Movie recommendations!")

    ray.init(address="auto", ignore_reinit_error=True)

    progress_bar = st.sidebar.progress(0)
    status_text = st.sidebar.empty()
    img_width = st.sidebar.slider("Image size", 10, 300, 100)
    num_movies = st.sidebar.slider("Number of movies", 1, 20, 5)
    num_cols = st.sidebar.slider("Number of columns", 1, 5, 2)
    pick_img = st.sidebar.radio("Which movie?", list(range(1, num_movies + 1)))
    cols = st.beta_columns(num_cols)

    # The ensemble actor may still be starting; poll for it by name.
    ensemble_actor = retry_until_success(lambda: ray.get_actor("composed_model"))
    result = ray.get(ensemble_actor.get_recommendation.remote())
    # Each entry of result["recs"] is a movie dict; "id" and "title" are read below.
    movies = result.get("recs") if result else None

    if movies:
        # The ensemble may return fewer movies than the slider asks for, so
        # never index past the end of the list.
        shown = min(num_movies, len(movies))
        for i, movie in enumerate(movies[:shown], start=1):
            percent = int(i / shown * 100)
            status_text.text("%i%% Complete" % percent)
            progress_bar.progress(percent)
            # Fill columns left to right, starting with the first column.
            col = cols[(i - 1) % num_cols]
            # Close the image file once Streamlit has rendered it.
            with Image.open(f"assets/{movie['id']}.jpg") as image:
                col.image(image, caption=str(i) + ':' + movie['title'], width=img_width)
        progress_bar.empty()

        # Record only the movie the user picked, and only when they submit.
        # The radio may offer more choices than were actually shown.
        if st.button("Submit preference.") and pick_img <= shown:
            ensemble_actor.record_liked_id.remote(str(movies[pick_img - 1]["id"]))
    else:
        st.write("Loading.")
        # TODO: consider time.sleep(...) followed by rerun() to poll until
        # recommendations are available.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import json | |
from collections import defaultdict | |
from itertools import cycle | |
from sqlite3 import connect | |
import numpy as np | |
import pandas as pd | |
import faiss | |
import ray | |
from ray.util.metrics import Gauge | |
def get_db_connection(
        path="/Users/maxfitton/Downloads/serve-movie-rec-demo/composition-demo.sqlite3"):
    """Open a connection to the demo's SQLite database.

    Args:
        path: location of ``composition-demo.sqlite3``. Defaults to the
            original author's local path; pass an explicit path elsewhere.

    Returns:
        A ``sqlite3.Connection`` to ``path``.

    Raises:
        FileNotFoundError: if ``path`` does not exist (a subclass of
            ``Exception``, so existing broad handlers still catch it).
    """
    # Check first: sqlite3.connect would otherwise silently create a new,
    # empty database file at ``path``.
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"The database file {path!r} doesn't exist. Did you forget to "
            "download it?\n"
            "    wget https://ray-serve-blog.s3-us-west-2.amazonaws.com/"
            "composition-demo.sqlite3")
    return connect(path)
@ray.remote(num_cpus=0)
class ImpressionStore:
    """Actor tracking which model showed each movie, and which got liked.

    Every shown movie is attributed to the model that produced it, so a later
    like can be credited to that model. Per-model click-through rates are
    published on the ``impression_store_click_rate`` gauge.
    """

    def __init__(self):
        # session_key -> {movie id: name of the model that showed it}
        self.impressions = defaultdict(dict)
        # model name -> number of impressions recorded for that model
        self.num_impressions = defaultdict(int)
        # session_key -> {model name: number of likes credited to it}
        self.session_liked_model_count = defaultdict(
            lambda: defaultdict(int))
        # model name -> set of clicked movie ids
        self.model_clicked_ids = defaultdict(set)
        # model name -> set of shown movie ids
        self.model_shown_ids = defaultdict(set)
        # NOTE(review): the unit says "percent" but _refresh_ctr records a
        # clicks / impressions ratio, not a percentage — confirm intent.
        self.metric = Gauge(
            "impression_store_click_rate",
            "The click through rate of each model in impression store",
            "percent", ["model"])

    def _refresh_ctr(self):
        """Publish likes / impressions for every model that has likes."""
        # Total likes per model, summed over all sessions.
        model_counter = defaultdict(int)
        for liked_count in self.session_liked_model_count.values():
            for name, count in liked_count.items():
                model_counter[name] += count
        for model, clicks in model_counter.items():
            shown = self.num_impressions.get(model, 0)
            # Likes are only credited to models with recorded impressions,
            # but guard anyway rather than divide by zero.
            if shown:
                self.metric.record(clicks / shown, {"model": model})

    def _record_feedback(self, session_key, liked_id):
        """Credit a like to the model that showed ``liked_id`` in this session.

        Unknown ids (never shown in this session) are ignored.
        """
        # NOTE(review): the Streamlit UI sends str(...) ids; if impression ids
        # are ints this lookup never matches — confirm the id types agree.
        src_model = self.impressions[session_key].get(liked_id)
        # Can't find this impression source
        if src_model is None:
            return
        self.session_liked_model_count[session_key][src_model] += 1
        self.model_clicked_ids[src_model].add(liked_id)
        self._refresh_ctr()

    def record_impressions(self, session_key, impressions):
        """Remember which model served each movie shown to ``session_key``.

        Args:
            session_key: identifies the user session.
            impressions: mapping model name -> list of movie dicts, each with
                an ``"id"`` key.
        """
        for model, movies in impressions.items():
            for movie_payload in movies:
                movie_id = movie_payload["id"]
                self.impressions[session_key][movie_id] = model
                self.model_shown_ids[model].add(movie_id)
                self.num_impressions[model] += 1
        # Refresh once per batch: refreshing per movie rescans every session
        # and each intermediate value would be overwritten immediately.
        self._refresh_ctr()

    def model_distribution(self, session_key, liked_id):
        """Credit ``liked_id`` and return the session's per-model like counts.

        Returns an empty dict (and records nothing) for an empty session key.
        """
        if session_key == "":
            return {}
        self._record_feedback(session_key, liked_id)
        return self.session_liked_model_count[session_key]

    def count_for_model(self, model):
        """Return the total likes credited to ``model`` across all sessions."""
        return sum(
            counts.get(model, 0)
            for counts in self.session_liked_model_count.values())

    def get_model_clicks(self, model):
        """Return a DataFrame of ``model``'s shown ids with a 0/1 ``clicked`` flag."""
        positive = self.model_clicked_ids[model]
        negative = self.model_shown_ids[model] - positive
        return pd.DataFrame({
            "id": list(positive) + list(negative),
            "clicked": [1] * len(positive) + [0] * len(negative)
        })
def choose_ensemble_results(model_distribution, model_results, num_returns=7):
    """Interleave model predictions according to the session's like counts.

    Args:
        model_distribution: mapping model name -> like count. Unless it has
            exactly two entries, the counts are added on top of
            ``{"random": 1, "plot": 1}`` so neither model starts at zero.
        model_results: mapping model name -> list of movie dicts, each with
            an ``"id"`` key. The lists and dicts are not modified.
        num_returns: maximum number of movies to return. The default of 7
            matches the original ``len(chosen) <= 6`` loop.

    Returns:
        ``(normalized_distribution, impressions, chosen)``:
        ``normalized_distribution`` maps model name -> weight (summing to 1),
        ``impressions`` maps model name -> movies taken from that model, and
        ``chosen`` is the interleaved list. Each returned movie is a copy with
        a ``"model"`` key naming its source.
    """
    # Normalize dist
    if len(model_distribution) != 2:
        weights = {model: 1 for model in ["random", "plot"]}
        for name, count in model_distribution.items():
            weights[name] += count
    else:
        weights = dict(model_distribution)
    total_weights = sum(weights.values())
    if total_weights:
        normalized_distribution = {
            k: v / total_weights
            for k, v in weights.items()
        }
    else:
        # All-zero counts: split evenly instead of dividing by zero.
        normalized_distribution = {k: 1 / len(weights) for k in weights}

    # Visit models by descending weight (stable for ties). A strictly
    # dominant model gets an extra turn in every round-robin cycle.
    sorted_group = sorted(
        normalized_distribution, key=lambda k: -normalized_distribution[k])
    dominant_group = sorted_group[0]
    if normalized_distribution[sorted_group[0]] > normalized_distribution[sorted_group[1]]:
        sorted_group = [dominant_group] + sorted_group

    # Walk each model's predictions with an iterator so the caller's lists
    # are left intact (the original popped from them).
    remaining = {model: iter(preds) for model, preds in model_results.items()}
    exhausted = object()

    chosen = []
    seen_ids = set()
    impressions = defaultdict(list)
    # Rank based on weights
    for model in cycle(sorted_group):
        if len(chosen) >= num_returns:
            break
        movie = next(remaining[model], exhausted)
        if movie is exhausted:
            if model == dominant_group:
                # The preferred model has nothing left to offer: stop.
                break
            continue
        # De-duplicate by id: the same movie can come from both models, and
        # two copies tagged with different "model" values never compare equal.
        if movie["id"] in seen_ids:
            continue
        seen_ids.add(movie["id"])
        movie = dict(movie, model=model)
        impressions[model].append(movie)
        chosen.append(movie)
    return normalized_distribution, impressions, chosen
class LRMovieRanker:
    """Re-rank candidate movies by a classifier's positive-class probability."""

    def __init__(self, lr_model, features):
        # Any object exposing ``predict_proba`` whose column 1 is the
        # positive-class score (presumably a fitted scikit-learn logistic
        # regression — TODO confirm).
        self.lr_model = lr_model
        # movie id -> feature vector
        self.features = features

    def rank_movies(self, recommended_movies):
        """Return ``recommended_movies`` ordered from highest to lowest score.

        Args:
            recommended_movies: sequence of movie ids present in ``features``.

        Returns:
            A new list with the same ids, highest ``predict_proba[:, 1]`` first.
        """
        if len(recommended_movies) == 0:
            # np.array([]) is 1-D, which predict_proba implementations
            # generally reject; there is nothing to rank anyway.
            return []
        vectors = np.array([self.features[movie] for movie in recommended_movies])
        scores = self.lr_model.predict_proba(vectors)[:, 1]
        high_to_low_idx = np.argsort(scores)[::-1].tolist()
        return [recommended_movies[i] for i in high_to_low_idx]
class KNearestNeighborIndex:
    """Nearest-neighbor lookup over per-movie vectors using a faiss L2 index."""

    def __init__(self, db_cursor):
        """Build the index from ``(id, json_vector)`` rows.

        Args:
            db_cursor: iterable of rows whose first column is the movie id and
                whose second column is a JSON-encoded numeric array (the cover
                image palette), flattened to one vector per movie.

        Raises:
            ValueError: if ``db_cursor`` yields no rows.
        """
        # Query all the cover image palette
        self.id_to_arr = {
            row[0]: np.array(json.loads(row[1])).flatten()
            for row in db_cursor
        }
        if not self.id_to_arr:
            raise ValueError("Cannot build a nearest-neighbor index from zero rows.")
        vector_length = len(next(iter(self.id_to_arr.values())))
        # IndexFlatL2 does brute-force L2 search; IndexIDMap attaches our ids.
        self.index = faiss.IndexIDMap(faiss.IndexFlatL2(vector_length))
        # Build the index; faiss ids are 64-bit, so be explicit about dtype
        # ('int' is 32-bit on some platforms).
        arr = np.stack(list(self.id_to_arr.values())).astype('float32')
        ids = np.array(list(self.id_to_arr.keys())).astype(np.int64)
        self.index.add_with_ids(arr, ids)

    def search(self, request):
        """Return ids of the movies closest to ``request.args["liked_id"]``.

        Args:
            request: object whose ``args`` mapping holds ``"liked_id"`` and,
                optionally, ``"count"`` (default 5).

        Returns:
            Up to ``count`` neighbor ids as strings, excluding ``liked_id``.
        """
        liked_id = request.args["liked_id"]
        num_returns = int(request.args.get("count", 5))
        if num_returns <= 0:
            return []
        # NOTE(review): liked_id comes from request args (presumably a str);
        # this lookup only works if the DB ids have the same type — confirm.
        source_color = self.id_to_arr[liked_id]
        source_color = np.expand_dims(source_color, 0).astype('float32')
        # Ask for one extra hit: the query movie is normally its own closest match.
        _, ids = self.index.search(source_color, num_returns + 1)
        ret = []
        for n in ids.flatten().tolist():
            # faiss pads the result with -1 when fewer than k vectors exist.
            if n < 0 or str(n) == liked_id:
                continue
            ret.append(str(n))
            if len(ret) == num_returns:
                break
        return ret
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment