Created July 3, 2020 16:58
-
-
Save khuangaf/5864d31d6c685ca5a68d938a9dc7a5a2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compute precision@k over all validation users, treating each user's
# highest-rated validation movie as the single relevant ("clicked") item.
def precision_at_k(predictions, k, user_ids=None, movie_ids=None, ratings=None):
    '''
    Return precision@k averaged over validation users.

    For each user, the "clicked" item is the validation movie with the highest
    rating (the first such row on ties). A hit is scored when that item is
    within the user's top-k predicted items. Each user has exactly one
    relevant item, so per-user precision@k is hit / k. The returned value is
    the mean of that over users, i.e. total_hits / (n_users * k).

    args:
        predictions: np.array user-item score matrix. User ids are used
            directly as row indices, and movie ids are compared against
            column indices, so both must be 0-based indices into it.
        k: int, cutoff of the top-k list; must be >= 1.
        user_ids, movie_ids, ratings: optional aligned 1-D arrays describing
            the validation interactions. When omitted, the module-level
            ``val_user_ids``, ``val_movie_ids`` and ``val_ratings`` are used.
    returns:
        precision: float in [0, 1/k]; 0.0 when there are no users.
    raises:
        ValueError: if k < 1.
        IndexError: if a clicked movie id is not a column of ``predictions``.
    '''
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Fall back to the module-level validation arrays for backward compatibility.
    if user_ids is None:
        user_ids = val_user_ids
    if movie_ids is None:
        movie_ids = val_movie_ids
    if ratings is None:
        ratings = val_ratings
    user_ids = np.asarray(user_ids)
    movie_ids = np.asarray(movie_ids)
    ratings = np.asarray(ratings)

    unique_users = np.unique(user_ids)
    if unique_users.size == 0:
        # No users to average over; avoid a ZeroDivisionError.
        return 0.0

    hits = 0
    for target_user in unique_users:
        # Build the per-user mask once instead of twice.
        mask = user_ids == target_user
        target_movie_ids = movie_ids[mask]
        target_ratings = ratings[mask]
        # np.argmax returns the first maximum, so rating ties go to the earliest row.
        clicked_movie_id = target_movie_ids[np.argmax(target_ratings)]
        # Stable sort so equally-scored items are ranked deterministically.
        predicted_order = np.argsort(-predictions[target_user], kind="stable")
        positions = np.flatnonzero(predicted_order == clicked_movie_id)
        if positions.size == 0:
            raise IndexError(
                f"movie id {clicked_movie_id!r} of user {target_user!r} is not "
                f"a column of predictions (shape {np.shape(predictions)})"
            )
        # Hit if the clicked item ranks within the top k predictions.
        if positions[0] < k:
            hits += 1
    return hits / (len(unique_users) * k)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.