Last active
August 1, 2022 07:13
-
-
Save Chiang97912/5ebca5f3fa58eff4a096119dd356e032 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding:utf-8 -*- | |
def MRR(ranked_list, ground_truth):
    """Mean Reciprocal Rank over all users.

    For each user ``i``, find the first 1-based position ``rank`` in
    ``ranked_list[i]`` whose item appears in ``ground_truth[i]``, and add
    ``1 / rank``. Users with no hit contribute 0. The sum is averaged over
    ``len(ground_truth)``.

    Args:
        ranked_list: sequence of per-user recommendation lists, best first.
        ground_truth: sequence of per-user collections of relevant items,
            aligned by index with ``ranked_list``.

    Returns:
        The MRR as a float; 0.0 when ``ground_truth`` is empty.
    """
    rr = 0.0
    for i, recs in enumerate(ranked_list):
        for rank, item in enumerate(recs, start=1):
            if item in ground_truth[i]:
                # Explicit float numerator: under Python 2, ``1 / rank`` is
                # integer division and would yield 0 for every rank > 1.
                rr += 1.0 / rank
                break  # only the first relevant hit counts
    n_users = len(ground_truth)
    # Guard the average so empty input does not raise ZeroDivisionError.
    return rr / n_users if n_users else 0.0
def HitRatio(ranked_list, ground_truth):
    """Hit ratio: the fraction of users with at least one relevant recommendation.

    Args:
        ranked_list: sequence of per-user recommendation lists.
        ground_truth: sequence of per-user collections of relevant items,
            aligned by index with ``ranked_list``. Items must be hashable.

    Returns:
        Number of users whose recommendations share at least one item with
        their ground truth, divided by ``len(ground_truth)``, as a float.
        Returns 0.0 when ``ground_truth`` is empty.
    """
    hits = 0
    for i, recs in enumerate(ranked_list):
        # isdisjoint answers "any overlap?" without building the full intersection.
        if not set(recs).isdisjoint(ground_truth[i]):
            hits += 1
    n_users = len(ground_truth)
    # float() keeps true division on Python 2; guard avoids ZeroDivisionError.
    return float(hits) / n_users if n_users else 0.0
def AP(ranked_list, ground_truth):
    """Average Precision of a single ranked list.

    Walks ``ranked_list`` top-down. Each time a relevant item that has not
    been seen yet appears at 1-based position ``k``, it adds precision@k,
    i.e. (distinct relevant items found so far) / k. The sum is divided by
    the number of distinct relevant items in ``ground_truth``.

    Args:
        ranked_list: one user's recommendation list, best first.
        ground_truth: that user's relevant items (must be hashable).

    Returns:
        AP as a float in [0, 1]; 0.0 when no relevant item is recommended
        (including when ``ground_truth`` is empty).
    """
    relevant = set(ground_truth)
    found = set()
    sum_precs = 0.0
    for k, item in enumerate(ranked_list, start=1):
        # Count each relevant item only once. Previously a repeated relevant
        # item in ranked_list incremented hits again and could push AP above 1.
        if item in relevant and item not in found:
            found.add(item)
            sum_precs += float(len(found)) / k
    if not found:
        return 0.0
    # Divide by the distinct relevant items so that duplicates in
    # ground_truth do not deflate the score.
    return sum_precs / len(relevant)
def MAP(ranked_list, ground_truth):
    """Mean Average Precision over all users.

    Sums ``AP(ranked_list[i], ground_truth[i])`` for every user ``i`` in
    ``ranked_list`` and averages the total over ``len(ground_truth)``.

    Args:
        ranked_list: sequence of per-user recommendation lists, best first.
        ground_truth: sequence of per-user relevant-item collections,
            aligned by index with ``ranked_list``.

    Returns:
        The MAP as a float; 0.0 when ``ground_truth`` is empty.
    """
    total_ap = 0.0
    for i, recs in enumerate(ranked_list):
        total_ap += AP(recs, ground_truth[i])
    n_users = len(ground_truth)
    # Guard the average so empty input does not raise ZeroDivisionError.
    return total_ap / n_users if n_users else 0.0
def Precision(ranked_list, ground_truth):
    """Micro-averaged precision over all users.

    For each user, counts the distinct recommended items that also appear
    in that user's ground truth. Sums these counts over all users and
    divides by the total number of distinct recommended items.

    Args:
        ranked_list: sequence of per-user recommendation lists.
        ground_truth: sequence of per-user relevant-item collections,
            aligned by index with ``ranked_list``. Items must be hashable.

    Returns:
        Precision as a float; 0.0 when nothing was recommended (including
        empty input).
    """
    n_hit = 0
    n_recommended = 0
    for i, recs in enumerate(ranked_list):
        recom_set = set(recs)
        n_hit += len(recom_set & set(ground_truth[i]))
        n_recommended += len(recom_set)
    if n_recommended == 0:
        # Undefined precision; report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return float(n_hit) / n_recommended
def Recall(ranked_list, ground_truth):
    """Micro-averaged recall over all users.

    For each user, counts the distinct recommended items that also appear
    in that user's ground truth. Sums these counts over all users and
    divides by the total number of distinct ground-truth items of the users
    visited.

    Args:
        ranked_list: sequence of per-user recommendation lists.
        ground_truth: sequence of per-user relevant-item collections,
            aligned by index with ``ranked_list``. Items must be hashable.

    Returns:
        Recall as a float; 0.0 when there are no ground-truth items
        (including empty input).
    """
    n_hit = 0
    n_relevant = 0
    for i, recs in enumerate(ranked_list):
        truth_set = set(ground_truth[i])
        n_hit += len(set(recs) & truth_set)
        n_relevant += len(truth_set)
    if n_relevant == 0:
        # Undefined recall; report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return float(n_hit) / n_relevant
def Precision_V1(recommends, tests):
    """Micro-averaged precision for dict-keyed inputs.

    Args:
        recommends: mapping of user id to that user's recommended items.
        tests: mapping of user id to that user's relevant items. It must
            contain every key of ``recommends``; a missing key raises KeyError.

    Returns:
        Total distinct recommended-and-relevant items divided by total
        distinct recommended items, as a float. Returns 0.0 when nothing
        was recommended (including an empty ``recommends``).
    """
    n_hit = 0
    n_recommended = 0
    for user_id, items in recommends.items():
        recommend_set = set(items)
        n_hit += len(recommend_set & set(tests[user_id]))
        n_recommended += len(recommend_set)
    if n_recommended == 0:
        # Undefined precision; report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return float(n_hit) / n_recommended
def Recall_V1(recommends, tests):
    """Micro-averaged recall for dict-keyed inputs.

    Args:
        recommends: mapping of user id to that user's recommended items.
        tests: mapping of user id to that user's relevant items. It must
            contain every key of ``recommends``; a missing key raises KeyError.
            Only users present in ``recommends`` are counted.

    Returns:
        Total distinct recommended-and-relevant items divided by total
        distinct relevant items of those users, as a float. Returns 0.0 when
        there are no relevant items (including an empty ``recommends``).
    """
    n_hit = 0
    n_relevant = 0
    for user_id, items in recommends.items():
        test_set = set(tests[user_id])
        n_hit += len(set(items) & test_set)
        n_relevant += len(test_set)
    if n_relevant == 0:
        # Undefined recall; report 0.0 rather than raising ZeroDivisionError.
        return 0.0
    return float(n_hit) / n_relevant
if __name__ == '__main__':
    # Demo data: R holds one recommendation list per user (best first),
    # T holds the items each user actually interacted with.
    R = [[3, 10, 15, 12, 17], [20, 15, 18, 14, 30], [2, 5, 7, 8, 15], [56, 14, 25, 12, 19], [21, 24, 36, 54, 45]]
    T = [[12], [3], [5], [14], [20]]
    # Alternative ground truth with several relevant items per user:
    # T = [[12, 3, 17, 15], [3], [5, 15, 8], [14], [20, 24]]

    # Print each metric on its own line, in a fixed order.
    for metric in (MRR, HitRatio, MAP, Precision, Recall):
        print(metric(R, T))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment