Created
July 3, 2020 16:58
-
-
Save khuangaf/96f1103e837e91339039b010ae53a07d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compute precision@k averaged over all validation users.
def precision_at_k(predictions, k, user_ids=None, movie_ids=None, ratings=None):
    '''
    Return the mean precision@k over all validation users.

    For each user, the single "relevant" item is the validation movie that
    user rated highest (the first such movie on ties, via np.argmax). A hit
    is counted when that movie is among the k highest-scoring movies in the
    user's prediction row. Each user has exactly one relevant item, so
    per-user precision@k is hit / k. The returned value therefore equals
    hit_rate@k / k.

    args:
        predictions: np.array of user-item scores. It is indexed as
            predictions[user_id] and the argsort positions are compared to
            movie ids. This assumes user ids and movie ids are 0-based
            positional indices into the array (confirm against the data
            pipeline).
        k: int, top-k cutoff; must be >= 1.
        user_ids, movie_ids, ratings: optional parallel validation arrays.
            When omitted, each falls back to the module-level
            val_user_ids / val_movie_ids / val_ratings, as before.
    returns:
        float: mean precision@k, or 0.0 when there are no validation users.
    raises:
        ValueError: if k < 1.
    '''
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Fall back to the module-level validation split for backward compatibility.
    if user_ids is None:
        user_ids = val_user_ids
    if movie_ids is None:
        movie_ids = val_movie_ids
    if ratings is None:
        ratings = val_ratings

    user_ids = np.asarray(user_ids)
    movie_ids = np.asarray(movie_ids)
    ratings = np.asarray(ratings)

    unique_users = np.unique(user_ids)
    if unique_users.size == 0:
        # No users to evaluate; avoid dividing by zero.
        return 0.0

    hit = 0
    for target_user in unique_users:
        mask = user_ids == target_user
        # The movie this user rated highest is treated as the relevant item.
        clicked_movie_id = movie_ids[mask][np.argmax(ratings[mask])]
        # Movie ids ordered from highest to lowest predicted score, cut at k.
        top_k = np.argsort(-predictions[target_user])[:k]
        # A clicked id outside the prediction columns is simply a miss
        # (the previous np.where(...)[0][0] lookup raised IndexError).
        if clicked_movie_id in top_k:
            hit += 1
    return hit / (unique_users.size * k)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment