Created
July 8, 2020 06:12
-
-
Save jacobobryant/cefc012c0e86afa9582c4ac6bcb60b82 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import json
import sys
import threading
import time

import numpy as np
from surprise import Dataset
from surprise import KNNBaseline
from surprise import Reader
def synchronized(func):
    """Decorate ``func`` so that calls to it never run concurrently.

    One ``threading.Lock`` is created per decorated function and stored on
    the original function as ``func.__lock__``. Every call through the
    returned wrapper holds that lock for the duration of the call; the lock
    is released even when ``func`` raises.

    Args:
        func: The callable to guard.

    Returns:
        A wrapper with the same call signature and metadata as ``func``.
    """
    func.__lock__ = threading.Lock()

    # wraps() keeps __name__/__doc__ intact so the decorated function still
    # identifies itself correctly in logs, tracebacks and help().
    @functools.wraps(func)
    def synced_func(*args, **kws):
        with func.__lock__:
            return func(*args, **kws)

    return synced_func
def train(filename):
    """Fit a KNNBaseline model on a ratings file and return the fitted model.

    ``filename`` is a CSV whose rows are ``user_id,item_id,rating``. The
    reader is configured for ratings between -1 and 1, and similarities are
    computed between items (``user_based`` is False).
    """
    csv_reader = Reader(
        line_format='user item rating',
        sep=',',
        rating_scale=[-1, 1],
    )
    ratings = Dataset.load_from_file(filename, reader=csv_reader)
    model = KNNBaseline(sim_options={'user_based': False})
    # Train on every rating in the file; nothing is held out.
    model.fit(ratings.build_full_trainset())
    return model
@synchronized
def set_rating(algo, user_id, item_id, rating):
    """Set, replace or remove a single rating on an already-fitted model.

    The model is not refitted; only the state needed by ``top`` is patched:
    ``algo.yr``, ``algo.bu``, ``algo.by`` and
    ``algo.trainset._raw2inner_id_users``. Users missing from the training
    data are added on the fly with a baseline of 0.

    Items missing from the training data are ignored. Supporting them would
    additionally require extending ``algo.bi``, ``algo.bx``, ``algo.sim``,
    ``algo.ir`` and ``algo.trainset._raw2inner_id_items``.

    Args:
        algo: Fitted algorithm, as returned by ``train``.
        user_id: Raw user id; may be new.
        item_id: Raw item id.
        rating: The new rating, or ``None`` to unset any existing rating.
    """
    trainset = algo.trainset

    # Translate the raw id first: Trainset.knows_item() takes an *inner* id,
    # so passing it the raw id cannot tell whether the item is known.
    # to_inner_iid() raises ValueError for items outside the trainset.
    try:
        inner_item_id = trainset.to_inner_iid(item_id)
    except ValueError:
        return

    try:
        inner_user_id = trainset.to_inner_uid(user_id)
        new_user = False
    except ValueError:
        # Unknown user: claim the next inner id and grow the user baselines
        # by one zero entry. bu and by are deliberately the same array.
        inner_user_id = len(trainset._raw2inner_id_users)
        new_user = True
        new_bu = np.append(algo.bu, 0)
        algo.bu = new_bu
        algo.by = new_bu
        algo.yr[inner_user_id] = []

    # Drop any previous rating for this item, then add the new one (if any).
    new_ratings = [(i, r) for i, r in algo.yr[inner_user_id]
                   if i != inner_item_id]
    if rating is not None:
        new_ratings.append((inner_item_id, rating))
    algo.yr[inner_user_id] = new_ratings

    if new_user:
        # Publish the id mapping last so a concurrent top() call never sees
        # a user whose baselines and ratings are not yet in place.
        trainset._raw2inner_id_users[user_id] = inner_user_id
def normalize_pred(pred):
    """Convert a prediction object into a plain recommendation dict.

    The dict carries the estimate under ``'score'``, the item id under
    ``'item-id'`` and, under ``'knn/k'``, the ``actual_k`` entry of the
    prediction details (0 when the details have no such entry).
    """
    details = pred.details
    neighbor_count = details['actual_k'] if 'actual_k' in details else 0
    return {
        'score': pred.est,
        'item-id': pred.iid,
        'knn/k': neighbor_count,
    }
def top(algo, user_id, item_ids, verbose=False, limit=20):
    """Return the best-scoring candidate items for a user, highest first.

    Every id in ``item_ids`` is scored with ``algo.predict`` (unclipped) and
    converted with ``normalize_pred``. Results are ordered by score, with
    ties broken in favour of the larger ``'knn/k'`` value.

    Args:
        algo: Fitted algorithm exposing ``predict``.
        user_id: Raw id of the user to recommend for.
        item_ids: Iterable of raw candidate item ids.
        verbose: When true, print the elapsed wall time of the call.
        limit: Maximum number of results to return (default 20, the
            previously hard-coded cutoff); ``None`` returns all of them.

    Returns:
        A list of dicts, each containing a ``'score'`` key plus the other
        keys that should be recorded along with the recommendation.
    """
    start = time.monotonic()
    candidates = (
        normalize_pred(algo.predict(user_id, item_id, clip=False))
        for item_id in item_ids
    )
    ranked = sorted(
        candidates,
        key=lambda x: (x['score'], x['knn/k']),
        reverse=True,
    )
    ret = ranked[:limit]
    if verbose:
        print("top:", time.monotonic() - start)
    return ret
def main(ratings_file, user_candidates_file):
    """Train on ``ratings_file`` and rank the candidates for one user.

    Args:
        ratings_file: CSV of ``user_id,item_id,rating`` rows, passed to
            ``train``.
        user_candidates_file: JSON file holding a two-element array
            ``[user_id, item_ids]``.

    Returns:
        The list of recommendation dicts produced by ``top``.
    """
    algo = train(ratings_file)
    with open(user_candidates_file, 'r', encoding='utf-8') as f:
        user_id, item_ids = json.load(f)
    # The ranking function in this module is ``top``; no ``score`` function
    # is defined here, so calling that name raised NameError.
    return top(algo, user_id, item_ids, verbose=True)


if __name__ == "__main__":
    main(*sys.argv[1:])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment