Created
January 4, 2021 07:21
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Requirement already satisfied: scikit-learn==0.24.0 in /home/penguin/.conda/lib/python3.7/site-packages (0.24.0)\n", | |
"Requirement already satisfied: joblib>=0.11 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (0.16.0)\n", | |
"Requirement already satisfied: numpy>=1.13.3 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (1.18.5)\n", | |
"Requirement already satisfied: threadpoolctl>=2.0.0 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (2.1.0)\n", | |
"Requirement already satisfied: scipy>=0.19.1 in /home/penguin/.conda/lib/python3.7/site-packages (from scikit-learn==0.24.0) (1.5.0)\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install -U scikit-learn==0.24.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.linear_model import LogisticRegression\n", | |
"from sklearn.model_selection import GridSearchCV\n", | |
"from sklearn.model_selection import KFold\n", | |
"\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### The problem with BaseSearchCV and sample_weight\n", | |
"\n", | |
"* In CV classes that inherit from [`BaseSearchCV`](https://github.com/scikit-learn/scikit-learn/blob/0.20.2/sklearn/model_selection/_search.py#L409), `sample_weight` is applied when the model is fitted on each CV fold\n", | |
"* However, `sample_weight` is not applied when the score on the validation set is computed\n", | |
"  * Specifically, see around https://github.com/scikit-learn/scikit-learn/blob/0.24.0/sklearn/model_selection/_validation.py#L620\n", | |
"  * For example, if the positive / negative samples are given weights of 1 / 999, a model that always predicts negative should get a weighted score of 999 / (1 + 999) = 0.999.\n", | |
"  * In practice the weights are not applied, so the score stays at 0.5." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"cv = np.array([0, 0, 1, 1])\n", | |
"X = np.ones(shape=(4, 1))\n", | |
"y = np.array([1, 0, 1, 0])\n", | |
"\n", | |
"fold = np.array([\n", | |
" [[0, 1], [2, 3]],\n", | |
" [[2, 3], [0, 1]],\n", | |
"])\n", | |
"\n", | |
"# Give the negatives 999 / 1000 of the total weight (positive: 1, negative: 999)\n", | |
"sample_weight_for_zeros = [\n", | |
" 1, 999, 1, 999\n", | |
"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"GridSearchCV(cv=array([[[0, 1],\n", | |
" [2, 3]],\n", | |
"\n", | |
" [[2, 3],\n", | |
" [0, 1]]]),\n", | |
" estimator=LogisticRegression(), param_grid={'random_state': [42]},\n", | |
" return_train_score=True, scoring='accuracy')" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"grid = GridSearchCV(\n", | |
" estimator=LogisticRegression(), \n", | |
" param_grid={ 'random_state': [42] }, \n", | |
" scoring='accuracy', \n", | |
" cv=fold, \n", | |
" return_train_score=True\n", | |
")\n", | |
"grid.fit(X, y, sample_weight=sample_weight_for_zeros)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.5" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Ideally this would be 0.999 (the accuracy weighted by sample_weight)\n", | |
"grid.best_score_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([[0.99900002, 0.00099998],\n", | |
" [0.99900002, 0.00099998],\n", | |
" [0.99900002, 0.00099998],\n", | |
" [0.99900002, 0.00099998]])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# The predicted probability of the negative class is 0.999 (because sample_weight is applied during training)\n", | |
"grid.predict_proba(X)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Summary\n", | |
"\n", | |
"* Training itself runs with sample_weight as expected, but the reported score is computed without sample_weight.\n", | |
"* Parameter searches that inherit from BaseSearchCV, such as GridSearchCV and RandomizedSearchCV, select the parameters that are best in terms of this score.\n", | |
"* Therefore, even when you want the parameters that are best with sample_weight taken into account, the model that is best under the unweighted score is selected, which is a problem." | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
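
The gap described in the notebook can be checked by running the two folds by hand and passing `sample_weight` to `accuracy_score` explicitly. This is only a minimal sketch on the same toy data: the manual loop and the names `folds`, `weighted_scores`, and `unweighted_scores` are illustrative assumptions, not part of the original notebook and not how `GridSearchCV` works internally.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Same toy data as in the notebook: one constant feature, balanced labels.
X = np.ones(shape=(4, 1))
y = np.array([1, 0, 1, 0])
sample_weight = np.array([1, 999, 1, 999])

# (train indices, validation indices), mirroring the notebook's `fold` array.
folds = [([0, 1], [2, 3]), ([2, 3], [0, 1])]

weighted_scores = []
unweighted_scores = []
for train_idx, valid_idx in folds:
    model = LogisticRegression(random_state=42)
    # Training uses the weights, as GridSearchCV also does.
    model.fit(X[train_idx], y[train_idx], sample_weight=sample_weight[train_idx])
    pred = model.predict(X[valid_idx])

    # Score without weights: matches what grid.best_score_ reports.
    unweighted_scores.append(accuracy_score(y[valid_idx], pred))
    # Score with the validation samples' weights applied.
    weighted_scores.append(
        accuracy_score(y[valid_idx], pred, sample_weight=sample_weight[valid_idx])
    )

print(np.mean(unweighted_scores))  # 0.5
print(np.mean(weighted_scores))    # 0.999
```

The model always predicts the negative class, so the unweighted accuracy is 1 / 2 per fold while the weighted accuracy is 999 / 1000, which is the value the notebook expected from `grid.best_score_`.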