Last active
November 7, 2020 09:30
-
-
Save goerlitz/fb1230ee820406c5a9681249345fce3f to your computer and use it in GitHub Desktop.
Bug in fastai's accuracy_thresh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**There seems to be a bug in fastai's `metrics.accuracy_thresh()`**" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"from fastai.vision import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"fastai 1.0.61\n", | |
"torch 1.5.1\n", | |
"torchvision 0.6.1\n" | |
] | |
} | |
], | |
"source": [ | |
"import fastai\n", | |
"import torch\n", | |
"import torchvision\n", | |
"print('fastai', fastai.__version__)\n", | |
"print('torch ', torch.__version__)\n", | |
"print('torchvision ', torchvision.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**Instead use these accuracy functions to get correct results**\n", | |
"\n", | |
"see also:\n", | |
"* https://forums.fast.ai/t/a-different-variant-of-accuracy-thresh/47977\n", | |
"* https://stats.stackexchange.com/questions/12702/what-are-the-measure-for-accuracy-of-multilabel-data" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def accuracy_multi_exact(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:\n", | |
" \"\"\"Compute accuracy where the predicted labels must match exactly the true labels\"\"\"\n", | |
" if sigmoid: y_pred = y_pred.sigmoid()\n", | |
" return ((y_pred>thresh)==y_true.byte()).all(1).float().mean()\n", | |
"\n", | |
"def accuracy_multi_partial(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:\n", | |
" \"\"\"Compute accuracy with partial match, e.g. one of two labels predicted -> 50% correct.\n", | |
" Mathematically, this is the intersection of predicted and true labels devided by the union of both.\n", | |
" \"\"\"\n", | |
" if sigmoid: y_pred = y_pred.sigmoid()\n", | |
" ypred_byte = (y_pred>thresh).byte()\n", | |
" ytrue_byte = y_true.byte()\n", | |
" return (ypred_byte.bitwise_and(ytrue_byte).sum(1).float() / ypred_byte.bitwise_or(ytrue_byte).sum(1).float()).mean()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multi-class\n", | |
"\n", | |
"The example data has five classes and two images labeled with the first and second class.\n", | |
"The dummy prediction always predicts the first class.\n", | |
"\n", | |
"Expectation: the accuracy should be 0.5 as only the first image is predicted correctly." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"classes=['ant', 'bee', 'cat', 'dog', 'eel']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ant</th>\n", | |
" <th>bee</th>\n", | |
" <th>cat</th>\n", | |
" <th>dog</th>\n", | |
" <th>eel</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>img_1</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>img_2</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" ant bee cat dog eel\n", | |
"img_1 1 0 0 0 0\n", | |
"img_2 0 1 0 0 0" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"target = {'img_1': [1, 0, 0, 0, 0], 'img_2': [0, 1, 0, 0, 0]}\n", | |
"df_true = pd.DataFrame.from_dict(target, orient='index', columns=classes)\n", | |
"df_true" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ant</th>\n", | |
" <th>bee</th>\n", | |
" <th>cat</th>\n", | |
" <th>dog</th>\n", | |
" <th>eel</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>img_1</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>img_2</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" ant bee cat dog eel\n", | |
"img_1 1 0 0 0 0\n", | |
"img_2 1 0 0 0 0" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pred = {'img_1': [1, 0, 0, 0, 0], 'img_2': [1, 0, 0, 0, 0]}\n", | |
"df_pred = pd.DataFrame.from_dict(pred, orient='index', columns=classes)\n", | |
"df_pred" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([[1., 0., 0., 0., 0.],\n", | |
" [0., 1., 0., 0., 0.]])\n", | |
"tensor([[1., 0., 0., 0., 0.],\n", | |
" [1., 0., 0., 0., 0.]])\n" | |
] | |
} | |
], | |
"source": [ | |
"y_true = tensor(df_true.values).float()\n", | |
"y_pred = tensor(df_pred.values).float()\n", | |
"print(y_true)\n", | |
"print(y_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.8000)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.5000)" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_multi_exact(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.5000)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_multi_partial(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multi-label\n", | |
"\n", | |
"The example data has five labels and two images each labeled with two of the five labels.\n", | |
"The dummy prediction always predicts the first label.\n", | |
"\n", | |
"Expectation: the accuracy should either be\n", | |
"\n", | |
"* 0.0 because none of images had all labels predicted correctly (exact match), r\n", | |
"* 0.25 because the first image had one of two labels predicted correctly (partial match)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ant</th>\n", | |
" <th>bee</th>\n", | |
" <th>cat</th>\n", | |
" <th>dog</th>\n", | |
" <th>eel</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>img_1</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>img_2</th>\n", | |
" <td>0</td>\n", | |
" <td>1</td>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" ant bee cat dog eel\n", | |
"img_1 1 0 0 1 0\n", | |
"img_2 0 1 1 0 0" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"target = {'img_1': [1, 0, 0, 1, 0], 'img_2': [0, 1, 1, 0, 0]}\n", | |
"df_true = pd.DataFrame.from_dict(target, orient='index', columns=classes)\n", | |
"df_true" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>ant</th>\n", | |
" <th>bee</th>\n", | |
" <th>cat</th>\n", | |
" <th>dog</th>\n", | |
" <th>eel</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>img_1</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>img_2</th>\n", | |
" <td>1</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" <td>0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" ant bee cat dog eel\n", | |
"img_1 1 0 0 0 0\n", | |
"img_2 1 0 0 0 0" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pred = {'img_1': [1, 0, 0, 0, 0], 'img_2': [1, 0, 0, 0, 0]}\n", | |
"df_pred = pd.DataFrame.from_dict(pred, orient='index', columns=classes)\n", | |
"df_pred" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([[1., 0., 0., 1., 0.],\n", | |
" [0., 1., 1., 0., 0.]])\n", | |
"tensor([[1., 0., 0., 0., 0.],\n", | |
" [1., 0., 0., 0., 0.]])\n" | |
] | |
} | |
], | |
"source": [ | |
"y_true = tensor(df_true.values).float()\n", | |
"y_pred = tensor(df_pred.values).float()\n", | |
"print(y_true)\n", | |
"print(y_pred)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.6000)" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.)" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_multi_exact(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0.2500)" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"accuracy_multi_partial(y_pred, y_true, thresh=0.5, sigmoid=False)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment