Created April 23, 2018 12:09
-
-
Save thisisjl/04e5fae58883c58b5df727c183e64c71 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare a random predicted ranking against a random reference ranking.
#
# Builds a catalog of random alphanumeric item IDs, draws a reference and a
# predicted ranking from it, derives binary relevance labels for the predicted
# items, and prints precision@k, NDCG and RankDCG as computed by the local
# ``measures`` module.
import random
import string

import numpy as np

import measures

# Seed BOTH generators: the catalog is built with the stdlib ``random`` module,
# which ``np.random.seed`` does not affect, so without this the catalog (and
# therefore every printed score) changes on every run.
random.seed(0)
np.random.seed(0)

n_digits = 3  # length of each item ID
n_catalog = 50  # number of distinct items in the catalog
n_ranking = 10  # length of the reference and predicted rankings
k = 20  # precision cutoff; larger than n_ranking, so the [:k] slices below cover whole rankings

# Create a catalog of unique item IDs. Duplicate IDs would let
# ``np.random.choice(..., replace=False)`` (which draws distinct positions,
# not distinct values) return the same ID twice in a ranking, skewing the
# membership-based relevance labels below.
alphabet = string.ascii_uppercase + string.digits
catalog = []
while len(catalog) < n_catalog:
    item_id = ''.join(random.choices(alphabet, k=n_digits))
    if item_id not in catalog:
        catalog.append(item_id)

# Reference and predicted rankings: n_ranking distinct catalog items each.
reference_ranking = np.random.choice(catalog, n_ranking, replace=False)
predicted_ranking = np.random.choice(catalog, n_ranking, replace=False)

# Binary ground-truth relevance of each predicted item: 1 if it also appears
# in the reference ranking, else 0.
ground_truth_relevance = np.array(
    [int(predicted_item in reference_ranking) for predicted_item in predicted_ranking]
)
# Every item the predictor returned is labelled as relevant by the predictor.
predicted_relevance = np.ones(n_ranking, dtype=int)

# Order items by descending ground-truth relevance (relevant items first) and
# apply that same order to both relevance vectors.
true_order = np.argsort(-ground_truth_relevance)
reference = [ground_truth_relevance[x] for x in true_order]
hypothesis = [predicted_relevance[x] for x in true_order]  # all ones, same as predicted_relevance

# Evaluate.
## Precision at k computed directly from the two rankings. Dividing by k (not
## by the ranking length) follows the usual precision@k definition, so with
## k > n_ranking the maximum attainable value here is n_ranking / k.
my_precision = len(set(reference_ranking[:k]) & set(predicted_ranking[:k])) / k
## Precision at k, NDCG and RankDCG from the local ``measures`` module.
## NOTE(review): their exact definitions and expected argument formats live
## in ``measures``; confirm they accept relevance lists ordered as above.
precision_k = measures.find_precision_k(reference, hypothesis, k)
ndcg = measures.find_ndcg(reference, hypothesis)
rankdcg = measures.find_rankdcg(reference, hypothesis)
# Alternative sanity check: score a random permutation of the reference instead.
# rankdcg = measures.find_rankdcg(reference, np.random.permutation(reference))

print(reference)
print(hypothesis)
print('my_precision', my_precision)
print('precision_k', precision_k)
print('ndcg', ndcg)
print('rankndcg', rankdcg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.