-
-
Save marymlucas/76ed1e2cbe1b220e70c3ce6ea6fe228c to your computer and use it in GitHub Desktop.
XGBoost Incremental Learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python2.7/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n", | |
" \"This module will be removed in 0.20.\", DeprecationWarning)\n" | |
] | |
} | |
], | |
"source": [ | |
"# !conda install -yc conda-forge xgboost\n", | |
"import xgboost as xgb\n", | |
"import sklearn.datasets\n", | |
"import sklearn.metrics\n", | |
"import sklearn.feature_selection\n", | |
"import sklearn.feature_extraction\n", | |
"import sklearn.cross_validation\n", | |
"import sklearn.model_selection\n", | |
"import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'0.6'" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"xgb.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['data', 'feature_names', 'DESCR', 'target']\n", | |
"['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'\n", | |
" 'B' 'LSTAT']\n" | |
] | |
} | |
], | |
"source": [ | |
"df = sklearn.datasets.load_boston()\n", | |
"print(df.keys())\n", | |
"print(df['feature_names'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"X = df['data']\n", | |
"y = df['target']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x_tr, x_te, y_tr, y_te = sklearn.model_selection.train_test_split(X, y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 33.8, 23.7, 20.5, 12.8, 50. , 17.4, 8.8, 17.8, 26.4, 18.2])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_tr[:10]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## One shot learning\n", | |
"Train on all of the training data with a single `xgb.train` call, i.e. only one pass over the dataset." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"9.7116581444822518" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"one_shot_model = xgb.train({\n", | |
" 'update':'refresh',\n", | |
" 'process_type': 'update',\n", | |
" 'refresh_leaf': True,\n", | |
" 'silent': False,\n", | |
"}, dtrain=xgb.DMatrix(x_tr, y_tr))\n", | |
"y_pr = one_shot_model.predict(xgb.DMatrix(x_te))\n", | |
"sklearn.metrics.mean_squared_error(y_te, y_pr)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"array([ 13.41692734, 10.1039257 , 26.06602859, 14.74896526,\n", | |
" 19.46399117, 22.82827187, 21.09622765, 18.83269501,\n", | |
" 27.70256996, 34.56838226], dtype=float32)" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"y_pr[:10]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## One shot iterative training\n", | |
"Exploit the `xgb_model` parameter of `xgb.train` to iterate over the training data multiple times." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Iteration 0: 9.71165814448\n", | |
"Iteration 1: 7.90938546712\n", | |
"Iteration 2: 7.83283545287\n", | |
"Iteration 3: 7.90989805123\n", | |
"Iteration 4: 7.93978549112\n" | |
] | |
} | |
], | |
"source": [ | |
"iteration = 5\n", | |
"one_shot_model_itr = None\n", | |
"for i in range(iteration):\n", | |
" one_shot_model_itr = xgb.train({\n", | |
" 'update':'refresh',\n", | |
" 'process_type': 'update',\n", | |
" 'refresh_leaf': True,\n", | |
" 'silent': False,\n", | |
" }, dtrain=xgb.DMatrix(x_tr, y_tr), xgb_model=one_shot_model_itr)\n", | |
" y_pr = one_shot_model_itr.predict(xgb.DMatrix(x_te))\n", | |
" print('Iteration {}: {}'.format(i, sklearn.metrics.mean_squared_error(y_te, y_pr)))" | |
] | |
}, | |
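{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"So xgboost models are able to improve when you iterate over the data multiple times: the test MSE drops from 9.71 to 7.91 on the second pass and then levels off at roughly 7.8-7.9 on later passes." | |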
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Iterative Incremental Learning" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"MSE itr@0: 239.680186067\n", | |
"MSE itr@1: 111.044669451\n", | |
"MSE itr@2: 57.7185741392\n", | |
"MSE itr@3: 35.7994472176\n", | |
"MSE itr@4: 26.2178656072\n", | |
"MSE itr@5: 20.3012679934\n", | |
"MSE itr@6: 17.0486683066\n", | |
"MSE itr@7: 14.9458533528\n", | |
"MSE itr@8: 13.5863551796\n", | |
"MSE itr@9: 12.5722084078\n", | |
"MSE itr@10: 12.0621747382\n", | |
"MSE itr@11: 11.8287598733\n", | |
"MSE itr@12: 11.6878301253\n", | |
"MSE itr@13: 11.4897400114\n", | |
"MSE itr@14: 11.4627225743\n", | |
"MSE itr@15: 11.5417849176\n", | |
"MSE itr@16: 11.4022054245\n", | |
"MSE itr@17: 11.2675483456\n", | |
"MSE itr@18: 11.3866442707\n", | |
"MSE itr@19: 11.3504530668\n", | |
"MSE itr@20: 11.3818182553\n", | |
"MSE itr@21: 11.5099846894\n", | |
"MSE itr@22: 11.5365974758\n", | |
"MSE itr@23: 11.7541341329\n", | |
"MSE itr@24: 11.9677214525\n", | |
"MSE at the end: 11.9677214525\n" | |
] | |
} | |
], | |
"source": [ | |
"batch_size = 50\n", | |
"iterations = 25\n", | |
"model = None\n", | |
"for i in range(iterations):\n", | |
" for start in range(0, len(x_tr), batch_size):\n", | |
" model = xgb.train({\n", | |
" 'learning_rate': 0.007,\n", | |
" 'update':'refresh',\n", | |
" 'process_type': 'update',\n", | |
" 'refresh_leaf': True,\n", | |
" #'reg_lambda': 3, # L2\n", | |
" 'reg_alpha': 3, # L1\n", | |
" 'silent': False,\n", | |
" }, dtrain=xgb.DMatrix(x_tr[start:start+batch_size], y_tr[start:start+batch_size]), xgb_model=model)\n", | |
"\n", | |
" y_pr = model.predict(xgb.DMatrix(x_te))\n", | |
" #print(' MSE itr@{}: {}'.format(int(start/batch_size), sklearn.metrics.mean_squared_error(y_te, y_pr)))\n", | |
" print('MSE itr@{}: {}'.format(i, sklearn.metrics.mean_squared_error(y_te, y_pr)))\n", | |
"\n", | |
"y_pr = model.predict(xgb.DMatrix(x_te))\n", | |
"print('MSE at the end: {}'.format(sklearn.metrics.mean_squared_error(y_te, y_pr)))" | |
] | |
}, | |
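{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Conclusion\n", | |
"MSE falls sharply over the first passes (from 239.68 at iteration 0 to a minimum of 11.27 at iteration 17). Hence, the xgboost model is learning incrementally. After iteration 17 the MSE drifts back up (11.97 at iteration 24), so further passes over the data no longer help." | |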
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment