A brief tutorial on MAP (mean average precision)
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2030ff9d-7ebd-45c1-8310-746fd4b9ac3f",
   "metadata": {},
   "source": [
"Recall is the measure that says, how many things you were suppose to find did you actually find?\n", | |
"\n", | |
"That is, recall = true_positive/(true_positive + false_negative)" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "edba6a06",
   "metadata": {},
   "outputs": [],
   "source": [
"import pandas as pd\n", | |
"\n", | |
"# recall: true_positives / (true_positives + false_negatives)\n", | |
"# i.e. stuff you got right / all the right stuff\n", | |
"def recall(gold, retrieved):\n", | |
" tp = 0\n", | |
" for item in retrieved:\n", | |
" if item in gold:\n", | |
" tp += 1\n", | |
" return tp/len(gold)\n", | |
" #one liner: len([x for x in retrieved if x in gold])/len(gold)" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "0ab61af0-20a1-4a9f-ad3a-23e1e1fa9b2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall of ground truth: [1, 2]\n",
      "[0, 0, 0, 2, 0, 0, 0, 0, 0, 0] => 0.5\n",
      "[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] => 1.0\n"
     ]
    }
   ],
   "source": [
    "ground_truth = {'q1': [1, 2], 'q2': [1]}\n",
    "\n",
    "model0 = {'q1': [0, 0, 0, 2, 0, 0, 0, 0, 0, 0]}\n",
    "\n",
    "print(\"Recall of ground truth:\", ground_truth['q1'])\n",
    "print(model0['q1'], '=>', recall(ground_truth['q1'], model0['q1']))\n",
    "\n",
    "model1 = {'q1': [0, 1, 0, 2, 0, 0, 0, 0, 0, 0]}\n",
    "print(model1['q1'], '=>', recall(ground_truth['q1'], model1['q1']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cdb6934-0283-4f60-bb02-70a896a8c7b5",
   "metadata": {},
   "source": [
"One thing to note is that recall is not rank-sensitive. It doesn't care in what rank you put the true positive items in." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "3faec046-284c-43b1-b8a9-a224896e844b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall of ground truth: [1, 2]\n",
      "[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] => 1.0\n"
     ]
    }
   ],
   "source": [
    "model2 = {'q1': [1, 2, 0, 0, 0, 0, 0, 0, 0, 0]}\n",
    "print(\"Recall of ground truth:\", ground_truth['q1'])\n",
    "print(model2['q1'], '=>', recall(ground_truth['q1'], model2['q1']))\n",
    "\n",
    "model3 = {'q1': [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]}\n",
    "print(model3['q1'], '=>', recall(ground_truth['q1'], model3['q1']))"
   ]
  },
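  {
   "cell_type": "markdown",
   "id": "recall-at-k-note",
   "metadata": {},
   "source": [
    "One common fix is to impose a cutoff and score only the top k results (recall@k); with a cutoff, rank placement matters again. A minimal sketch, using `recall_at_k` as our own name for this helper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "recall-at-k-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# recall@k: evaluate recall on just the top k results\n",
    "# (a hypothetical helper; the closing remark below about recall at\n",
    "#  different cut-offs refers to this kind of truncated evaluation)\n",
    "def recall_at_k(gold, retrieved, k):\n",
    "    return recall(gold, retrieved[:k])\n",
    "\n",
    "# with a cutoff of 5, the two rankings above no longer tie:\n",
    "print(recall_at_k(ground_truth['q1'], model2['q1'], 5))  # => 1.0\n",
    "print(recall_at_k(ground_truth['q1'], model3['q1'], 5))  # => 0.0"
   ]
  },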
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "56c7941c-6122-4645-94f2-003c99fa8d1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall of ground truth: [1, 2]\n",
"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n", | |
"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n", | |
"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.5\n", | |
"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.5\n", | |
"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.5\n", | |
"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.5\n", | |
"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.5\n", | |
"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.5\n", | |
"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.5\n" | |
     ]
    }
   ],
   "source": [
    "print(\"Recall of ground truth:\", ground_truth['q1'])\n",
    "\n",
"for i in range(9):\n", | |
" model = [0]*i + [1] + [0]*(9-i)\n", | |
" print(model, '=>', recall(ground_truth['q1'], model))" | |
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc327b59-ed81-478c-ad99-d507c0a54f65",
   "metadata": {},
   "source": [
"Mean Average Precision (MAP) is a metric that has a long history in Information Retrieval research. It is called Mean Average Precision, because you calculate the Average Precision (AP) of each query, and then calculate the mean of the AP scores across your query set. I know, it's silly.\n", | |
"\n", | |
"So, what's Average Precision?" | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "id": "d68485c9-879a-46c4-bd37-8494ff8e42c4",
   "metadata": {},
   "outputs": [],
   "source": [
"def average_precision(gold, retrieved, verbose=False):\n", | |
" tp = 0\n", | |
" sum_p = 0\n", | |
" p_at_rank = 0\n", | |
" for rank, item in enumerate(retrieved, 1):\n", | |
" if item in gold:\n", | |
" tp += 1\n", | |
" # precision: true_positives / all_retrieved\n", | |
" # i.e. stuff you got right / all the stuff you got\n", | |
" p_at_rank = tp/rank\n", | |
" \n", | |
" sum_p += p_at_rank\n", | |
" if verbose:\n", | |
" print('Precision at rank', rank, '=', p_at_rank, '| Sum so far =', sum_p)\n", | |
" \n", | |
" if verbose:\n", | |
" print(sum_p, '/', len(gold))\n", | |
" return sum_p/len(gold)\n" | |
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba8b67e1-33b7-4eb7-880d-56abd1a906b9",
   "metadata": {},
   "source": [
"At each point in your retrieved list where there is a true positive item, you calculate the precision at that rank. P = TP/(TP+FN) \n", | |
"\n", | |
"The precisions are summed, then divided by the total number of positive items you could have retrieved. This is equivalent to calculating the area under the precision-recall curve." | |
   ]
  },
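  {
   "cell_type": "markdown",
   "id": "precision-at-k-note",
   "metadata": {},
   "source": [
    "Here is the precision-at-a-rank quantity on its own, as a minimal sketch -- `precision_at_k` is our own helper name, not something the AP code above depends on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "precision-at-k-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# precision@k: (relevant items in the top k) / k\n",
    "# a hypothetical helper, just to spell out the quantity AP sums up\n",
    "def precision_at_k(gold, retrieved, k):\n",
    "    top_k = retrieved[:k]\n",
    "    return len([x for x in top_k if x in gold]) / k\n",
    "\n",
    "# the two relevant items of q1 sit at ranks 2 and 4 in model1,\n",
    "# so precision@2 = 1/2 and precision@4 = 2/4 -- the two terms AP averages\n",
    "print(precision_at_k(ground_truth['q1'], model1['q1'], 2))\n",
    "print(precision_at_k(ground_truth['q1'], model1['q1'], 4))"
   ]
  },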
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "cb2acafa-a4eb-4eaf-aa69-152bba9dc950",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Precision of ground truth: [1, 2]\n",
      "[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] =>\n",
      "Precision at rank 2 = 0.5 | Sum so far = 0.5\n",
      "Precision at rank 4 = 0.5 | Sum so far = 1.0\n",
      "1.0 / 2\n",
      "0.5\n"
     ]
    }
   ],
   "source": [
    "print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
    "print(model1['q1'], '=>')\n",
    "print(average_precision(ground_truth['q1'], model1['q1'], verbose=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7df67054-8134-448a-9a63-e7f81b261ad6",
   "metadata": {},
   "source": [
    "A desirable property of AP is that ranks matter!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "6899c3c7-8e6e-4538-9e32-31b6e5fe8112",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Precision of ground truth: [1, 2]\n",
      "[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] =>\n",
      "Precision at rank 1 = 1.0 | Sum so far = 1.0\n",
      "Precision at rank 2 = 1.0 | Sum so far = 2.0\n",
      "2.0 / 2\n",
      "1.0\n",
      "\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] =>\n",
      "Precision at rank 9 = 0.1111111111111111 | Sum so far = 0.1111111111111111\n",
      "Precision at rank 10 = 0.2 | Sum so far = 0.3111111111111111\n",
      "0.3111111111111111 / 2\n",
      "0.15555555555555556\n"
     ]
    }
   ],
   "source": [
    "print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
    "print(model2['q1'], '=>')\n",
    "print(average_precision(ground_truth['q1'], model2['q1'], verbose=True))\n",
    "print()\n",
    "print(model3['q1'], '=>')\n",
    "print(average_precision(ground_truth['q1'], model3['q1'], verbose=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcd11f07-9c84-4314-87e6-e558848e5820",
   "metadata": {
    "tags": []
   },
   "source": [
"Another desirable property of AP is that the top ranks of the retrieved result matter more than the bottom ranks. Note that the difference of getting the right item in rank 1 vs. 2 makes a 0.25 difference in score, while the difference for rank 9 vs. 10 is ~0.055." | |
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "27b52df1-0b67-48ee-ad4f-5755976c3713",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Precision of ground truth: [1, 2]\n",
      "[1, 0, 0, 0, 0, 0, 0, 0, 0, 0] => 0.5\n",
      "[0, 1, 0, 0, 0, 0, 0, 0, 0, 0] => 0.25\n",
      "[0, 0, 1, 0, 0, 0, 0, 0, 0, 0] => 0.16666666666666666\n",
      "[0, 0, 0, 1, 0, 0, 0, 0, 0, 0] => 0.125\n",
      "[0, 0, 0, 0, 1, 0, 0, 0, 0, 0] => 0.1\n",
      "[0, 0, 0, 0, 0, 1, 0, 0, 0, 0] => 0.08333333333333333\n",
      "[0, 0, 0, 0, 0, 0, 1, 0, 0, 0] => 0.07142857142857142\n",
      "[0, 0, 0, 0, 0, 0, 0, 1, 0, 0] => 0.0625\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 1, 0] => 0.05555555555555555\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 1] => 0.05\n"
     ]
    }
   ],
   "source": [
    "print(\"Average Precision of ground truth:\", ground_truth['q1'])\n",
    "\n",
    "for i in range(10):\n",
    "    model = [0]*i + [1] + [0]*(9-i)\n",
    "    print(model, '=>', average_precision(ground_truth['q1'], model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 167,
   "id": "2a163929-727a-40be-b2b5-995091471097",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Average Precision of ground truth: [1, 2]\n",
      "[1, 2, 0, 0, 0, 0, 0, 0, 0, 0] => 1.0 | delta%= 0.0\n",
      "[1, 0, 2, 0, 0, 0, 0, 0, 0, 0] => 0.833 | delta%= -16.7\n",
      "[1, 0, 0, 2, 0, 0, 0, 0, 0, 0] => 0.75 | delta%= -10.0\n",
      "[1, 0, 0, 0, 2, 0, 0, 0, 0, 0] => 0.7 | delta%= -6.7\n",
      "[1, 0, 0, 0, 0, 2, 0, 0, 0, 0] => 0.667 | delta%= -4.8\n",
      "[1, 0, 0, 0, 0, 0, 2, 0, 0, 0] => 0.643 | delta%= -3.6\n",
      "[1, 0, 0, 0, 0, 0, 0, 2, 0, 0] => 0.625 | delta%= -2.8\n",
      "[1, 0, 0, 0, 0, 0, 0, 0, 2, 0] => 0.611 | delta%= -2.2\n",
      "[1, 0, 0, 0, 0, 0, 0, 0, 0, 2] => 0.6 | delta%= -1.8\n",
      "[0, 1, 2, 0, 0, 0, 0, 0, 0, 0] => 0.583 | delta%= -2.8\n",
      "[0, 1, 0, 2, 0, 0, 0, 0, 0, 0] => 0.5 | delta%= -14.3\n",
      "[0, 1, 0, 0, 2, 0, 0, 0, 0, 0] => 0.45 | delta%= -10.0\n",
      "[0, 1, 0, 0, 0, 2, 0, 0, 0, 0] => 0.417 | delta%= -7.4\n",
      "[0, 1, 0, 0, 0, 0, 2, 0, 0, 0] => 0.393 | delta%= -5.7\n",
      "[0, 1, 0, 0, 0, 0, 0, 2, 0, 0] => 0.375 | delta%= -4.5\n",
      "[0, 1, 0, 0, 0, 0, 0, 0, 2, 0] => 0.361 | delta%= -3.7\n",
      "[0, 1, 0, 0, 0, 0, 0, 0, 0, 2] => 0.35 | delta%= -3.1\n",
      "[0, 0, 1, 2, 0, 0, 0, 0, 0, 0] => 0.417 | delta%= 19.0\n",
      "[0, 0, 1, 0, 2, 0, 0, 0, 0, 0] => 0.367 | delta%= -12.0\n",
      "[0, 0, 1, 0, 0, 2, 0, 0, 0, 0] => 0.333 | delta%= -9.1\n",
      "[0, 0, 1, 0, 0, 0, 2, 0, 0, 0] => 0.31 | delta%= -7.1\n",
      "[0, 0, 1, 0, 0, 0, 0, 2, 0, 0] => 0.292 | delta%= -5.8\n",
      "[0, 0, 1, 0, 0, 0, 0, 0, 2, 0] => 0.278 | delta%= -4.8\n",
      "[0, 0, 1, 0, 0, 0, 0, 0, 0, 2] => 0.267 | delta%= -4.0\n",
      "[0, 0, 0, 1, 2, 0, 0, 0, 0, 0] => 0.325 | delta%= 21.9\n",
      "[0, 0, 0, 1, 0, 2, 0, 0, 0, 0] => 0.292 | delta%= -10.3\n",
      "[0, 0, 0, 1, 0, 0, 2, 0, 0, 0] => 0.268 | delta%= -8.2\n",
      "[0, 0, 0, 1, 0, 0, 0, 2, 0, 0] => 0.25 | delta%= -6.7\n",
      "[0, 0, 0, 1, 0, 0, 0, 0, 2, 0] => 0.236 | delta%= -5.6\n",
      "[0, 0, 0, 1, 0, 0, 0, 0, 0, 2] => 0.225 | delta%= -4.7\n",
      "[0, 0, 0, 0, 1, 2, 0, 0, 0, 0] => 0.267 | delta%= 18.5\n",
      "[0, 0, 0, 0, 1, 0, 2, 0, 0, 0] => 0.243 | delta%= -8.9\n",
      "[0, 0, 0, 0, 1, 0, 0, 2, 0, 0] => 0.225 | delta%= -7.4\n",
      "[0, 0, 0, 0, 1, 0, 0, 0, 2, 0] => 0.211 | delta%= -6.2\n",
      "[0, 0, 0, 0, 1, 0, 0, 0, 0, 2] => 0.2 | delta%= -5.3\n",
      "[0, 0, 0, 0, 0, 1, 2, 0, 0, 0] => 0.226 | delta%= 13.1\n",
      "[0, 0, 0, 0, 0, 1, 0, 2, 0, 0] => 0.208 | delta%= -7.9\n",
      "[0, 0, 0, 0, 0, 1, 0, 0, 2, 0] => 0.194 | delta%= -6.7\n",
      "[0, 0, 0, 0, 0, 1, 0, 0, 0, 2] => 0.183 | delta%= -5.7\n",
      "[0, 0, 0, 0, 0, 0, 1, 2, 0, 0] => 0.196 | delta%= 7.1\n",
      "[0, 0, 0, 0, 0, 0, 1, 0, 2, 0] => 0.183 | delta%= -7.1\n",
      "[0, 0, 0, 0, 0, 0, 1, 0, 0, 2] => 0.171 | delta%= -6.1\n",
      "[0, 0, 0, 0, 0, 0, 0, 1, 2, 0] => 0.174 | delta%= 1.3\n",
      "[0, 0, 0, 0, 0, 0, 0, 1, 0, 2] => 0.163 | delta%= -6.4\n",
      "[0, 0, 0, 0, 0, 0, 0, 0, 1, 2] => 0.156 | delta%= -4.3\n"
     ]
    }
   ],
"source": [ | |
"print()\n", | |
"print(\"Average Precision of ground truth:\", ground_truth['q1'])\n", | |
"prev_score = 1\n", | |
"for i in range(9):\n", | |
" for j in range(i,9):\n", | |
" model = [0]*i + [1] + [0]*(j-i) + [2] + [0]*(8-j)\n", | |
" score = average_precision(ground_truth['q1'], model)\n", | |
" print(model, '=>', round(score, 3), ' | delta%= ', round((score-prev_score)/prev_score*100, 1))\n", | |
" prev_score = score" | |
   ]
  },
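  {
   "cell_type": "markdown",
   "id": "map-over-queries-note",
   "metadata": {},
   "source": [
    "Putting it together: MAP is just the mean of the per-query AP scores. A minimal sketch below; note that the `q2` ranking is made up for illustration, since only `q1` results appeared above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "map-over-queries-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MAP: the mean of average_precision across all queries\n",
    "def mean_average_precision(gold_by_query, retrieved_by_query):\n",
    "    scores = [average_precision(gold_by_query[q], retrieved_by_query[q])\n",
    "              for q in gold_by_query]\n",
    "    return sum(scores) / len(scores)\n",
    "\n",
    "# hypothetical system output: q1 as in model1, q2 invented for illustration\n",
    "results = {'q1': [0, 1, 0, 2, 0, 0, 0, 0, 0, 0],  # AP = 0.5\n",
    "           'q2': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]}  # AP = 1.0\n",
    "print(mean_average_precision(ground_truth, results))  # (0.5 + 1.0) / 2 = 0.75"
   ]
  },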
  {
   "cell_type": "markdown",
   "id": "bca2b02c-d78a-4a01-8c69-94083ef78d04",
   "metadata": {},
   "source": [
"Finally, MAP is useful because it is a _stable_ metric. That is, if system A is better than system B in MAP@100, A will likely better than B in MAP@1000, as well. This is often _not_ the case with recall at different cut-offs. Generating examples for this is left as an exercise for the reader. " | |
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}