{
"cells": [
{
"cell_type": "markdown",
"id": "2030ff9d-7ebd-45c1-8310-746fd4b9ac3f",
"metadata": {},
"source": [
"Recall is the measure that says, how many things you were suppose to find did you actually find?\n",
"\n",
"That is, recall = true_positive/(true_positive + false_negative)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "edba6a06",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# recall: true_positives / (true_positives + false_negatives)\n",
"# i.e. stuff you got right / all the right stuff\n",
"def recall(gold, retrieved):\n",
" tp = 0\n",
" for item in retrieved:\n",
" if item in gold:\n",
" tp += 1\n",
" return tp/len(gold)\n",
" #one liner: len([x for x in retrieved if x in gold])/len(gold)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0ab61af0-20a1-4a9f-ad3a-23e1e1fa9b2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall of ground truth: [1, 2]\n",
"[0, 0, 0, 2, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] => 1.0\n"
]
}
],
"source": [
"ground_truth = {'q1': [1, 2], 'q2': [1]}\n",
"\n",
"model0 = {'q1': [0, 0, 0, 2, 0, 0, 0, 0, 0, 0]}\n",
"\n",
"print(\"Recall of ground truth:\", ground_truth['q1'])\n",
"print(model0['q1'], '=>', recall(ground_truth['q1'], model0['q1']))\n",
"\n",
"model1 = {'q1': [0, 1, 0, 2, 0, 0, 0, 0, 0, 0]}\n",
"print(model1['q1'], '=>', recall(ground_truth['q1'], model1['q1']))"
]
},
{
"cell_type": "markdown",
"id": "1cdb6934-0283-4f60-bb02-70a896a8c7b5",
"metadata": {},
"source": [
"One thing to note is that recall is not rank-sensitive. It doesn't care in what rank you put the true positive items in."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3faec046-284c-43b1-b8a9-a224896e844b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall of ground truth: [1, 2]\n",
"[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] => 1.0\n"
]
}
],
"source": [
"model2 = {'q1': [1, 2, 0, 0, 0, 0, 0, 0, 0, 0]}\n",
"print(\"Recall of ground truth:\", ground_truth['q1'])\n",
"print(model2['q1'], '=>', recall(ground_truth['q1'], model2['q1']))\n",
"\n",
"model3 = {'q1': [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]}\n",
"print(model3['q1'], '=>', recall(ground_truth['q1'], model3['q1']))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "56c7941c-6122-4645-94f2-003c99fa8d1f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Recall of ground truth: [1, 2]\n",
"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.5\n",
"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.5\n",
"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.5\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.5\n"
]
}
],
"source": [
"print(\"Recall of ground truth:\", ground_truth['q1'])\n",
"\n",
"for i in range(9):\n",
" model = [0]*i + [1] + [0]*(9-i)\n",
" print(model, '=>', recall(ground_truth['q1'], model))"
]
},
{
"cell_type": "markdown",
"id": "bc327b59-ed81-478c-ad99-d507c0a54f65",
"metadata": {},
"source": [
"Mean Average Precision (MAP) is a metric that has a long history in Information Retrieval research. It is called Mean Average Precision, because you calculate the Average Precision (AP) of each query, and then calculate the mean of the AP scores across your query set. I know, it's silly.\n",
"\n",
"So, what's Average Precision?"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d68485c9-879a-46c4-bd37-8494ff8e42c4",
"metadata": {},
"outputs": [],
"source": [
"def average_precision(gold, retrieved, verbose=False):\n",
" tp = 0\n",
" sum_p = 0\n",
" p_at_rank = 0\n",
" for rank, item in enumerate(retrieved, 1):\n",
" if item in gold:\n",
" tp += 1\n",
" # precision: true_positives / all_retrieved\n",
" # i.e. stuff you got right / all the stuff you got\n",
" p_at_rank = tp/rank\n",
" \n",
" sum_p += p_at_rank\n",
" if verbose:\n",
" print('Precision at rank', rank, '=', p_at_rank, '| Sum so far =', sum_p)\n",
" \n",
" if verbose:\n",
" print(sum_p, '/', len(gold))\n",
" return sum_p/len(gold)\n"
]
},
{
"cell_type": "markdown",
"id": "ba8b67e1-33b7-4eb7-880d-56abd1a906b9",
"metadata": {},
"source": [
"At each point in your retrieved list where there is a true positive item, you calculate the precision at that rank. P = TP/(TP+FN) \n",
"\n",
"The precisions are summed, then divided by the total number of positive items you could have retrieved. This is equivalent to calculating the area under the precision-recall curve."
]
},
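{
"cell_type": "markdown",
"id": "ap-pr-area-md",
"metadata": {},
"source": [
"A quick sketch of that equivalence: each hit bumps recall up by 1/len(gold), so summing precision-at-hit times that recall step is the rectangle-by-rectangle area under the step-shaped precision-recall curve. A minimal sketch (the helper name `ap_as_pr_area` is made up for illustration):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ap-pr-area-code",
"metadata": {},
"outputs": [],
"source": [
"def ap_as_pr_area(gold, retrieved):\n",
"    # area under the step-function PR curve: delta_recall * precision per hit\n",
"    tp, area = 0, 0.0\n",
"    for rank, item in enumerate(retrieved, 1):\n",
"        if item in gold:\n",
"            tp += 1\n",
"            area += (1 / len(gold)) * (tp / rank)\n",
"    return area\n",
"\n",
"# same value as average_precision on the example below\n",
"print(ap_as_pr_area(ground_truth['q1'], model1['q1']))"
]
},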
{
"cell_type": "code",
"execution_count": 8,
"id": "cb2acafa-a4eb-4eaf-aa69-152bba9dc950",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average Precision of ground truth: [1, 2]\n",
"[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] =>\n",
"Precision at rank 2 = 0.5 | Sum so far = 0.5\n",
"Precision at rank 4 = 0.5 | Sum so far = 1.0\n",
"1.0 / 2\n",
"0.5\n"
]
}
],
"source": [
"print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
"print(model1['q1'], '=>')\n",
"print(average_precision(ground_truth['q1'], model1['q1'], verbose=True))"
]
},
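{
"cell_type": "markdown",
"id": "map-from-ap-md",
"metadata": {},
"source": [
"And that's all MAP adds on top: average the per-query AP scores. A minimal sketch, where the `retrieved` dict is made up for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "map-from-ap-code",
"metadata": {},
"outputs": [],
"source": [
"def mean_average_precision(gold_by_query, retrieved_by_query):\n",
"    # MAP: mean of per-query Average Precision\n",
"    aps = [average_precision(gold_by_query[q], retrieved_by_query[q])\n",
"           for q in gold_by_query]\n",
"    return sum(aps) / len(aps)\n",
"\n",
"retrieved = {'q1': [0, 1, 0, 2, 0], 'q2': [1, 0, 0, 0, 0]}\n",
"# AP(q1) = (1/2 + 2/4)/2 = 0.5, AP(q2) = 1.0  =>  MAP = 0.75\n",
"print(mean_average_precision(ground_truth, retrieved))"
]
},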
{
"cell_type": "markdown",
"id": "7df67054-8134-448a-9a63-e7f81b261ad6",
"metadata": {},
"source": [
"A desirable property of AP is that ranks matter!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "6899c3c7-8e6e-4538-9e32-31b6e5fe8112",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average Precision of ground truth: [1, 2]\n",
"[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] =>\n",
"Precision at rank 1 = 1.0 | Sum so far = 1.0\n",
"Precision at rank 2 = 1.0 | Sum so far = 2.0\n",
"2.0 / 2\n",
"1.0\n",
"\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] =>\n",
"Precision at rank 9 = 0.1111111111111111 | Sum so far = 0.1111111111111111\n",
"Precision at rank 10 = 0.2 | Sum so far = 0.3111111111111111\n",
"0.3111111111111111 / 2\n",
"0.15555555555555556\n"
]
}
],
"source": [
"print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
"print(model2['q1'], '=>')\n",
"print(average_precision(ground_truth['q1'], model2['q1'], verbose=True))\n",
"print()\n",
"print(model3['q1'], '=>')\n",
"print(average_precision(ground_truth['q1'], model3['q1'], verbose=True))"
]
},
{
"cell_type": "markdown",
"id": "fcd11f07-9c84-4314-87e6-e558848e5820",
"metadata": {
"tags": []
},
"source": [
"Another desirable property of AP is that the top ranks of the retrieved result matter more than the bottom ranks. Note that the difference of getting the right item in rank 1 vs. 2 makes a 0.25 difference in score, while the difference for rank 9 vs. 10 is ~0.055."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "27b52df1-0b67-48ee-ad4f-5755976c3713",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average Precision of ground truth: [1, 2]\n",
"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n",
"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 0.25\n",
"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.16666666666666666\n",
"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.125\n",
"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.1\n",
"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.08333333333333333\n",
"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.07142857142857142\n",
"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.0625\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.05555555555555555\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 1] => 0.05\n"
]
}
],
"source": [
"print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
"\n",
"for i in range(10):\n",
" model = [0]*i + [1] + [0]*(9-i)\n",
" print(model, '=>', average_precision(ground_truth['q1'], model))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2a163929-727a-40be-b2b5-995091471097",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Average Precision of ground truth: [1, 2]\n",
"[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0 | delta%= 0.0\n",
"[1, 0, 2, 0, 0, 0, 0, 0, 0, 0] => 0.833 | delta%= -16.7\n",
"[1, 0, 0, 2, 0, 0, 0, 0, 0, 0] => 0.75 | delta%= -10.0\n",
"[1, 0, 0, 0, 2, 0, 0, 0, 0, 0] => 0.7 | delta%= -6.7\n",
"[1, 0, 0, 0, 0, 2, 0, 0, 0, 0] => 0.667 | delta%= -4.8\n",
"[1, 0, 0, 0, 0, 0, 2, 0, 0, 0] => 0.643 | delta%= -3.6\n",
"[1, 0, 0, 0, 0, 0, 0, 2, 0, 0] => 0.625 | delta%= -2.8\n",
"[1, 0, 0, 0, 0, 0, 0, 0, 2, 0] => 0.611 | delta%= -2.2\n",
"[1, 0, 0, 0, 0, 0, 0, 0, 0, 2] => 0.6 | delta%= -1.8\n",
"[0, 1, 2, 0, 0, 0, 0, 0, 0, 0] => 0.583 | delta%= -2.8\n",
"[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] => 0.5 | delta%= -14.3\n",
"[0, 1, 0, 0, 2, 0, 0, 0, 0, 0] => 0.45 | delta%= -10.0\n",
"[0, 1, 0, 0, 0, 2, 0, 0, 0, 0] => 0.417 | delta%= -7.4\n",
"[0, 1, 0, 0, 0, 0, 2, 0, 0, 0] => 0.393 | delta%= -5.7\n",
"[0, 1, 0, 0, 0, 0, 0, 2, 0, 0] => 0.375 | delta%= -4.5\n",
"[0, 1, 0, 0, 0, 0, 0, 0, 2, 0] => 0.361 | delta%= -3.7\n",
"[0, 1, 0, 0, 0, 0, 0, 0, 0, 2] => 0.35 | delta%= -3.1\n",
"[0, 0, 1, 2, 0, 0, 0, 0, 0, 0] => 0.417 | delta%= 19.0\n",
"[0, 0, 1, 0, 2, 0, 0, 0, 0, 0] => 0.367 | delta%= -12.0\n",
"[0, 0, 1, 0, 0, 2, 0, 0, 0, 0] => 0.333 | delta%= -9.1\n",
"[0, 0, 1, 0, 0, 0, 2, 0, 0, 0] => 0.31 | delta%= -7.1\n",
"[0, 0, 1, 0, 0, 0, 0, 2, 0, 0] => 0.292 | delta%= -5.8\n",
"[0, 0, 1, 0, 0, 0, 0, 0, 2, 0] => 0.278 | delta%= -4.8\n",
"[0, 0, 1, 0, 0, 0, 0, 0, 0, 2] => 0.267 | delta%= -4.0\n",
"[0, 0, 0, 1, 2, 0, 0, 0, 0, 0] => 0.325 | delta%= 21.9\n",
"[0, 0, 0, 1, 0, 2, 0, 0, 0, 0] => 0.292 | delta%= -10.3\n",
"[0, 0, 0, 1, 0, 0, 2, 0, 0, 0] => 0.268 | delta%= -8.2\n",
"[0, 0, 0, 1, 0, 0, 0, 2, 0, 0] => 0.25 | delta%= -6.7\n",
"[0, 0, 0, 1, 0, 0, 0, 0, 2, 0] => 0.236 | delta%= -5.6\n",
"[0, 0, 0, 1, 0, 0, 0, 0, 0, 2] => 0.225 | delta%= -4.7\n",
"[0, 0, 0, 0, 1, 2, 0, 0, 0, 0] => 0.267 | delta%= 18.5\n",
"[0, 0, 0, 0, 1, 0, 2, 0, 0, 0] => 0.243 | delta%= -8.9\n",
"[0, 0, 0, 0, 1, 0, 0, 2, 0, 0] => 0.225 | delta%= -7.4\n",
"[0, 0, 0, 0, 1, 0, 0, 0, 2, 0] => 0.211 | delta%= -6.2\n",
"[0, 0, 0, 0, 1, 0, 0, 0, 0, 2] => 0.2 | delta%= -5.3\n",
"[0, 0, 0, 0, 0, 1, 2, 0, 0, 0] => 0.226 | delta%= 13.1\n",
"[0, 0, 0, 0, 0, 1, 0, 2, 0, 0] => 0.208 | delta%= -7.9\n",
"[0, 0, 0, 0, 0, 1, 0, 0, 2, 0] => 0.194 | delta%= -6.7\n",
"[0, 0, 0, 0, 0, 1, 0, 0, 0, 2] => 0.183 | delta%= -5.7\n",
"[0, 0, 0, 0, 0, 0, 1, 2, 0, 0] => 0.196 | delta%= 7.1\n",
"[0, 0, 0, 0, 0, 0, 1, 0, 2, 0] => 0.183 | delta%= -7.1\n",
"[0, 0, 0, 0, 0, 0, 1, 0, 0, 2] => 0.171 | delta%= -6.1\n",
"[0, 0, 0, 0, 0, 0, 0, 1, 2, 0] => 0.174 | delta%= 1.3\n",
"[0, 0, 0, 0, 0, 0, 0, 1, 0, 2] => 0.163 | delta%= -6.4\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] => 0.156 | delta%= -4.3\n"
]
}
],
"source": [
"print()\n",
"print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
"prev_score = 1\n",
"for i in range(9):\n",
" for j in range(i,9):\n",
" model = [0]*i + [1] + [0]*(j-i) + [2] + [0]*(8-j)\n",
" score = average_precision(ground_truth['q1'], model)\n",
" print(model, '=>', round(score, 3), ' | delta%= ', round((score-prev_score)/prev_score*100, 1))\n",
" prev_score = score"
]
},
{
"cell_type": "markdown",
"id": "f0f49bfc-ab0f-4372-b305-b48642c2e6f6",
"metadata": {},
"source": [
"Finally, MAP is useful because it is a _stable_ metric. That is, if system A is better than system B in MAP@100, A will likely better than B in MAP@1000, as well. This is often _not_ the case with recall at different cut-offs. Generating examples for this is left as an exercise for the reader. "
]
},
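{
"cell_type": "markdown",
"id": "stability-sketch-md",
"metadata": {},
"source": [
"One quick sketch, with made-up rankings: system A front-loads one hit and buries the rest just past rank 5, while system B puts two hits at ranks 3-4. Recall's preference flips between the two cut-offs; AP's does not."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "stability-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"gold = [1, 2, 3, 4]\n",
"sys_a = [1, 0, 0, 0, 0, 2, 3, 4, 0, 0]\n",
"sys_b = [0, 0, 1, 2, 0, 0, 0, 0, 0, 0]\n",
"\n",
"for k in (5, 10):\n",
"    print('recall@%2d:' % k, recall(gold, sys_a[:k]), 'vs', recall(gold, sys_b[:k]))\n",
"    print('AP@%2d:    ' % k,\n",
"          round(average_precision(gold, sys_a[:k]), 3), 'vs',\n",
"          round(average_precision(gold, sys_b[:k]), 3))\n",
"# recall@5 prefers B, recall@10 prefers A; AP prefers A at both cut-offs"
]
},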
{
"cell_type": "markdown",
"id": "7cebe78b-30a4-43e8-8b0f-14833a3644bd",
"metadata": {},
"source": [
"## MAP vs NDCG"
]
},
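{
"cell_type": "markdown",
"id": "map-vs-ndcg-note",
"metadata": {},
"source": [
"NDCG (Normalized Discounted Cumulative Gain) also rewards putting relevant items near the top, but it discounts positions logarithmically: with binary relevance, a hit at rank r contributes 1/log2(r+1), and the total is normalized by the best possible ordering. As the comparison below shows, NDCG decays more gently than AP as the relevant items slide down the list."
]
},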
{
"cell_type": "code",
"execution_count": 14,
"id": "ee7bf95a-e926-4086-9636-bca1fcb05eb3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ground truth: [1, 2]\n",
"AP [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0\n",
"NDCG [1, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0\n",
"--\n",
"AP [1, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.8333333333333333\n",
"NDCG [1, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.9197207891481877\n",
"--\n",
"AP [1, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.75\n",
"NDCG [1, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.8772153153380493\n",
"--\n",
"AP [1, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.7\n",
"NDCG [1, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.8503449055347547\n",
"--\n",
"AP [1, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.6666666666666666\n",
"NDCG [1, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.8315546295836228\n",
"--\n",
"AP [1, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.6428571428571428\n",
"NDCG [1, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.8175295903539447\n",
"--\n",
"AP [1, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.625\n",
"NDCG [1, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.8065735963827294\n",
"--\n",
"AP [1, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.6111111111111112\n",
"NDCG [1, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.7977228895450265\n",
"--\n",
"AP [1, 0, 0, 0, 0, 0, 0, 0, 0, 1] => 0.6\n",
"NDCG [1, 0, 0, 0, 0, 0, 0, 0, 0, 1] => 0.7903864795495061\n",
"--\n",
"AP [0, 1, 1, 0, 0, 0, 0, 0, 0, 0] => 0.5833333333333333\n",
"NDCG [0, 1, 1, 0, 0, 0, 0, 0, 0, 0] => 0.6934264036172708\n",
"--\n",
"AP [0, 1, 0, 1, 0, 0, 0, 0, 0, 0] => 0.5\n",
"NDCG [0, 1, 0, 1, 0, 0, 0, 0, 0, 0] => 0.6509209298071323\n",
"--\n",
"AP [0, 1, 0, 0, 1, 0, 0, 0, 0, 0] => 0.45\n",
"NDCG [0, 1, 0, 0, 1, 0, 0, 0, 0, 0] => 0.6240505200038378\n",
"--\n",
"AP [0, 1, 0, 0, 0, 1, 0, 0, 0, 0] => 0.41666666666666663\n",
"NDCG [0, 1, 0, 0, 0, 1, 0, 0, 0, 0] => 0.6052602440527058\n",
"--\n",
"AP [0, 1, 0, 0, 0, 0, 1, 0, 0, 0] => 0.39285714285714285\n",
"NDCG [0, 1, 0, 0, 0, 0, 1, 0, 0, 0] => 0.5912352048230278\n",
"--\n",
"AP [0, 1, 0, 0, 0, 0, 0, 1, 0, 0] => 0.375\n",
"NDCG [0, 1, 0, 0, 0, 0, 0, 1, 0, 0] => 0.5802792108518124\n",
"--\n",
"AP [0, 1, 0, 0, 0, 0, 0, 0, 1, 0] => 0.3611111111111111\n",
"NDCG [0, 1, 0, 0, 0, 0, 0, 0, 1, 0] => 0.5714285040141095\n",
"--\n",
"AP [0, 1, 0, 0, 0, 0, 0, 0, 0, 1] => 0.35\n",
"NDCG [0, 1, 0, 0, 0, 0, 0, 0, 0, 1] => 0.5640920940185892\n",
"--\n",
"AP [0, 0, 1, 1, 0, 0, 0, 0, 0, 0] => 0.41666666666666663\n",
"NDCG [0, 0, 1, 1, 0, 0, 0, 0, 0, 0] => 0.57064171895532\n",
"--\n",
"AP [0, 0, 1, 0, 1, 0, 0, 0, 0, 0] => 0.3666666666666667\n",
"NDCG [0, 0, 1, 0, 1, 0, 0, 0, 0, 0] => 0.5437713091520254\n",
"--\n",
"AP [0, 0, 1, 0, 0, 1, 0, 0, 0, 0] => 0.3333333333333333\n",
"NDCG [0, 0, 1, 0, 0, 1, 0, 0, 0, 0] => 0.5249810332008935\n",
"--\n",
"AP [0, 0, 1, 0, 0, 0, 1, 0, 0, 0] => 0.30952380952380953\n",
"NDCG [0, 0, 1, 0, 0, 0, 1, 0, 0, 0] => 0.5109559939712155\n",
"--\n",
"AP [0, 0, 1, 0, 0, 0, 0, 1, 0, 0] => 0.29166666666666663\n",
"NDCG [0, 0, 1, 0, 0, 0, 0, 1, 0, 0] => 0.5000000000000001\n",
"--\n",
"AP [0, 0, 1, 0, 0, 0, 0, 0, 1, 0] => 0.2777777777777778\n",
"NDCG [0, 0, 1, 0, 0, 0, 0, 0, 1, 0] => 0.4911492931622972\n",
"--\n",
"AP [0, 0, 1, 0, 0, 0, 0, 0, 0, 1] => 0.26666666666666666\n",
"NDCG [0, 0, 1, 0, 0, 0, 0, 0, 0, 1] => 0.48381288316677684\n",
"--\n",
"AP [0, 0, 0, 1, 1, 0, 0, 0, 0, 0] => 0.325\n",
"NDCG [0, 0, 0, 1, 1, 0, 0, 0, 0, 0] => 0.5012658353418871\n",
"--\n",
"AP [0, 0, 0, 1, 0, 1, 0, 0, 0, 0] => 0.29166666666666663\n",
"NDCG [0, 0, 0, 1, 0, 1, 0, 0, 0, 0] => 0.4824755593907551\n",
"--\n",
"AP [0, 0, 0, 1, 0, 0, 1, 0, 0, 0] => 0.26785714285714285\n",
"NDCG [0, 0, 0, 1, 0, 0, 1, 0, 0, 0] => 0.468450520161077\n",
"--\n",
"AP [0, 0, 0, 1, 0, 0, 0, 1, 0, 0] => 0.25\n",
"NDCG [0, 0, 0, 1, 0, 0, 0, 1, 0, 0] => 0.4574945261898617\n",
"--\n",
"AP [0, 0, 0, 1, 0, 0, 0, 0, 1, 0] => 0.2361111111111111\n",
"NDCG [0, 0, 0, 1, 0, 0, 0, 0, 1, 0] => 0.44864381935215875\n",
"--\n",
"AP [0, 0, 0, 1, 0, 0, 0, 0, 0, 1] => 0.225\n",
"NDCG [0, 0, 0, 1, 0, 0, 0, 0, 0, 1] => 0.4413074093566384\n",
"--\n",
"AP [0, 0, 0, 0, 1, 1, 0, 0, 0, 0] => 0.26666666666666666\n",
"NDCG [0, 0, 0, 0, 1, 1, 0, 0, 0, 0] => 0.45560514958746057\n",
"--\n",
"AP [0, 0, 0, 0, 1, 0, 1, 0, 0, 0] => 0.24285714285714285\n",
"NDCG [0, 0, 0, 0, 1, 0, 1, 0, 0, 0] => 0.4415801103577825\n",
"--\n",
"AP [0, 0, 0, 0, 1, 0, 0, 1, 0, 0] => 0.225\n",
"NDCG [0, 0, 0, 0, 1, 0, 0, 1, 0, 0] => 0.43062411638656717\n",
"--\n",
"AP [0, 0, 0, 0, 1, 0, 0, 0, 1, 0] => 0.2111111111111111\n",
"NDCG [0, 0, 0, 0, 1, 0, 0, 0, 1, 0] => 0.4217734095488642\n",
"--\n",
"AP [0, 0, 0, 0, 1, 0, 0, 0, 0, 1] => 0.2\n",
"NDCG [0, 0, 0, 0, 1, 0, 0, 0, 0, 1] => 0.4144369995533439\n",
"--\n",
"AP [0, 0, 0, 0, 0, 1, 1, 0, 0, 0] => 0.22619047619047616\n",
"NDCG [0, 0, 0, 0, 0, 1, 1, 0, 0, 0] => 0.4227898344066506\n",
"--\n",
"AP [0, 0, 0, 0, 0, 1, 0, 1, 0, 0] => 0.20833333333333331\n",
"NDCG [0, 0, 0, 0, 0, 1, 0, 1, 0, 0] => 0.4118338404354352\n",
"--\n",
"AP [0, 0, 0, 0, 0, 1, 0, 0, 1, 0] => 0.19444444444444442\n",
"NDCG [0, 0, 0, 0, 0, 1, 0, 0, 1, 0] => 0.4029831335977323\n",
"--\n",
"AP [0, 0, 0, 0, 0, 1, 0, 0, 0, 1] => 0.18333333333333335\n",
"NDCG [0, 0, 0, 0, 0, 1, 0, 0, 0, 1] => 0.3956467236022119\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 1, 1, 0, 0] => 0.19642857142857142\n",
"NDCG [0, 0, 0, 0, 0, 0, 1, 1, 0, 0] => 0.3978088012057571\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 1, 0, 1, 0] => 0.18253968253968253\n",
"NDCG [0, 0, 0, 0, 0, 0, 1, 0, 1, 0] => 0.38895809436805423\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 1, 0, 0, 1] => 0.17142857142857143\n",
"NDCG [0, 0, 0, 0, 0, 0, 1, 0, 0, 1] => 0.38162168437253385\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 0, 1, 1, 0] => 0.1736111111111111\n",
"NDCG [0, 0, 0, 0, 0, 0, 0, 1, 1, 0] => 0.37800210039683885\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 0, 1, 0, 1] => 0.1625\n",
"NDCG [0, 0, 0, 0, 0, 0, 0, 1, 0, 1] => 0.3706656904013185\n",
"--\n",
"AP [0, 0, 0, 0, 0, 0, 0, 0, 1, 1] => 0.15555555555555556\n",
"NDCG [0, 0, 0, 0, 0, 0, 0, 0, 1, 1] => 0.3618149835636156\n",
"--\n"
]
}
],
"source": [
"\n",
"import numpy as np\n",
"from sklearn.metrics import ndcg_score\n",
"\n",
"true_scores = np.asarray([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"prev_score = 1\n",
"print(\"Ground truth:\", ground_truth['q1'])\n",
"for i in range(9):\n",
" for j in range(i,9):\n",
" model = [0]*i + [1] + [0]*(j-i) + [1] + [0]*(8-j)\n",
" \n",
" score_pairs = list(enumerate(model))\n",
" score_pairs.sort(key=lambda x: x[1], reverse=True)\n",
" pred_scores = np.asarray([[-rank for rank, score in score_pairs]])\n",
" \n",
" \n",
" print('AP ', model, '=>', average_precision(ground_truth['q1'], model))\n",
" print('NDCG', model, '=>', ndcg_score(true_scores, pred_scores))\n",
" print('--')\n",
" \n",
" score = ndcg_score(true_scores, pred_scores)\n",
" #print(model, '=>', round(score, 3), ' | delta = ', round(prev_score-score, 3))\n",
" #print(model, '=>', round(score, 3), ' | delta%= ', round((score-prev_score)/prev_score*100, 1))\n",
" prev_score = score"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}