Last active
August 21, 2022 13:20
-
-
Save eribeiro/4630eb4b5562f38fd478d9694aa41ce2 to your computer and use it in GitHub Desktop.
Search relevance evaluation metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
##
## Python implementations of the search relevance evaluation metrics described at
## https://opensourceconnections.com/blog/2020/02/28/choosing-your-search-relevance-metric/
##
##
def precision(docs):
    """Return the mean relevance value of a ranked result list.

    For binary judgments (1 = relevant, 0 = not relevant) this is the
    fraction of returned documents that are relevant.

    Args:
        docs: Sized iterable of numeric relevance judgments, in rank order.

    Returns:
        ``sum(docs) / len(docs)``, or ``0`` when ``docs`` is empty.
    """
    count = len(docs)
    # Guard the division: an empty result list has precision 0 by convention.
    if count == 0:
        return 0
    return sum(docs) / count
def avg_precision(docs):
    """Return the average precision of a ranked list of relevance judgments.

    Precision is sampled at every rank whose judgment equals 1, and the
    samples are averaged. Only judgments equal to 1 trigger a sample, but
    every value in the prefix contributes to that sample's precision.

    Args:
        docs: Iterable of numeric relevance judgments, in rank order.

    Returns:
        The mean of the sampled precisions, or ``0`` when no judgment
        equals 1 (including the empty list).
    """
    precisions_at_hits = []
    running_total = 0
    # A running sum yields the precision of each prefix in a single pass,
    # instead of re-summing the prefix at every relevant rank.
    for rank, doc in enumerate(docs, start=1):
        running_total += doc
        if doc == 1:
            precisions_at_hits.append(running_total / rank)
    if not precisions_at_hits:
        return 0
    return sum(precisions_at_hits) / len(precisions_at_hits)
def cumulative_gain(docs):
    """Return the cumulative gain: the plain sum of the relevance grades.

    Rank order is ignored, so any permutation of ``docs`` scores the same.

    Args:
        docs: Iterable of numeric relevance grades.

    Returns:
        The sum of all grades (``0`` for an empty input).
    """
    return sum(docs)
def discounted_cumulative_gain(docs):
    """Return the discounted cumulative gain (DCG) of a ranked list.

    Each grade is divided by ``log2(rank + 1)`` with ``rank`` starting at 1,
    so the first result is undiscounted and later results count less.

    Args:
        docs: Iterable of numeric relevance grades, in rank order.

    Returns:
        The sum of the discounted grades (``0`` for an empty input).
    """
    from math import log2

    return sum(
        grade / log2(rank + 1)
        for rank, grade in enumerate(docs, start=1)
    )
def alternative_discounted_cumulative_gain(docs):
    """Return the exponential-gain variant of discounted cumulative gain.

    Each grade ``g`` contributes ``(2**g - 1) / log2(rank + 1)`` with
    ``rank`` starting at 1, which weights high grades much more heavily
    than the linear form.

    Args:
        docs: Iterable of numeric relevance grades, in rank order.

    Returns:
        The sum of the discounted exponential gains (``0`` for an empty
        input).
    """
    from math import log2

    return sum(
        (2 ** grade - 1) / log2(rank + 1)
        for rank, grade in enumerate(docs, start=1)
    )
def normalized_discounted_cumulative_gain(docs, top_k=5):
    """Return the normalized DCG (nDCG) of a ranked list at a rank cutoff.

    The DCG of the first ``top_k`` results is divided by the DCG of the
    best possible ordering: all of ``docs`` sorted by grade, descending,
    then truncated to ``top_k``.

    Args:
        docs: Sliceable sequence of numeric relevance grades, in rank order.
        top_k: Rank cutoff. Defaults to 5; ``None`` evaluates the whole
            list.

    Returns:
        ``real / ideal``, or ``0`` when the ideal DCG is zero (for example
        an empty list or all-zero grades).

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    # A negative cutoff would slice from the end and silently drop results.
    if top_k is not None and top_k < 0:
        raise ValueError("top_k must be non-negative or None")
    real = discounted_cumulative_gain(docs[:top_k])
    # Sort the full list before truncating so the ideal ranking can draw
    # on high grades that sit below the cutoff.
    ideal = discounted_cumulative_gain(sorted(docs, reverse=True)[:top_k])
    return real / ideal if ideal else 0
if __name__ == '__main__':
    # Demo: print each metric for a few hand-picked rankings.

    # Binary judgments: same relevant count, best-first vs. worst-first order.
    docs1 = (1, 1, 1, 0, 0)
    docs2 = (0, 0, 1, 1, 1)
    print("precision")
    print(precision(docs1), precision(docs2), '\n')
    print("avg_precision")
    print(avg_precision(docs1), avg_precision(docs2), '\n')

    # example used in https://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
    docs3 = (1, 1, 0, 1, 0, 0, 1)
    print("avg_precision")
    print(avg_precision(docs3), '\n')

    # Graded judgments: an ordering and its reverse.
    print("cumulative gain")
    docsG1 = (4, 3, 2, 1, 0)
    docsG2 = tuple(reversed(docsG1))
    print(cumulative_gain(docsG1), cumulative_gain(docsG2), '\n')
    print("discounted_cumulative_gain")
    print(discounted_cumulative_gain(docsG1), discounted_cumulative_gain(docsG2), '\n')
    print("alternative_discounted_cumulative_gain")
    print(alternative_discounted_cumulative_gain(docsG1), alternative_discounted_cumulative_gain(docsG2), '\n')

    # Two queries whose grades differ in overall magnitude.
    query1 = (4, 4, 3, 3, 3)
    query2 = (2, 1, 1, 1, 0)
    print("alternative_discounted_cumulative_gain")
    print(alternative_discounted_cumulative_gain(query1), alternative_discounted_cumulative_gain(query2), '\n')
    print("normalized_discounted_cumulative_gain")
    print(normalized_discounted_cumulative_gain(query1), normalized_discounted_cumulative_gain(query2), '\n')

    print("normalized_discounted_cumulative_gain")
    query3 = (3, 2, 1, 4, 0)
    print(normalized_discounted_cumulative_gain(query3), '\n')

    print("normalized_discounted_cumulative_gain")
    query4 = (0, 1, 2, 3, 4)
    print(normalized_discounted_cumulative_gain(query4), '\n')

    # Ten results, so the rank cutoff inside nDCG takes effect.
    print("normalized_discounted_cumulative_gain")
    docs_all = (4, 3, 2, 1, 1, 0, 3, 4, 0, 0)
    print(normalized_discounted_cumulative_gain(docs_all), '\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment