@stsievert
Last active June 30, 2018 01:08
XGBoost + Dask example
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dask and XGBoost\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dask and XGBoost can work together to train gradient boosted trees in parallel. This notebook shows how, and is easy to run yourself.\n",
"\n",
"XGBoost provides a powerful prediction framework that works well in practice, even though it's not well understood. It wins Kaggle contests and is popular in industry because it performs well (i.e., it produces high-accuracy models) and is easy to interpret (i.e., it's easy to find the important features from an XGBoost model).\n",
"\n",
"The goals of this notebook are to show Dask and XGBoost working together, and explain a little bit of what they do together."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"<img src=\"http://dask.readthedocs.io/en/latest/_images/dask_horizontal.svg\" width=\"30%\" alt=\"Dask logo\"> <img src=\"https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/logo-m/xgboost.png\" width=\"25%\" alt=\"XGBoost logo\">"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table style=\"border: 2px solid white;\">\n",
"<tr>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Client</h3>\n",
"<ul>\n",
" <li><b>Scheduler: </b>tcp://127.0.0.1:60822\n",
" <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
"</ul>\n",
"</td>\n",
"<td style=\"vertical-align: top; border: 0px solid white\">\n",
"<h3>Cluster</h3>\n",
"<ul>\n",
" <li><b>Workers: </b>8</li>\n",
" <li><b>Cores: </b>8</li>\n",
" <li><b>Memory: </b>17.18 GB</li>\n",
"</ul>\n",
"</td>\n",
"</tr>\n",
"</table>"
],
"text/plain": [
"<Client: scheduler='tcp://127.0.0.1:60822' processes=8 cores=8>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from dask import compute, persist\n",
"from dask.distributed import Client, progress\n",
"\n",
"client = Client()\n",
"client"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data creation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we create a bunch of synthetic data, with realistic sizes:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import dask.array as da\n",
"from dask_ml.model_selection import train_test_split\n",
"import numpy as np\n",
"\n",
"num_samples, num_features = int(100e3), 20\n",
"da.random.seed(42) # so results are deterministic; not important\n",
"\n",
"X = da.random.normal(size=(num_samples, num_features),\n",
" chunks=num_samples // 100)\n",
"w_star = da.random.uniform(size=num_features,\n",
" chunks=num_features)**3\n",
"y = da.sign(X @ w_star)\n",
"\n",
"# for binary:logistic objective, only [0, 1] labels allowed\n",
"y = (y + 1) / 2"
]
},
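{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the label construction above concrete, here is a minimal NumPy-only sketch on a tiny array. The names `X_small` and `w_small` and the sizes are illustrative assumptions, not part of the example itself. It shows that taking the sign of a linear score gives labels of -1 or +1, and that `(y + 1) / 2` maps them to 0 or 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(42)\n",
"X_small = rng.normal(size=(6, 3))    # 6 samples, 3 features (toy sizes)\n",
"w_small = rng.uniform(size=3) ** 3   # cubing pushes most weights toward 0\n",
"\n",
"y_sign = np.sign(X_small @ w_small)  # -1 or +1 (0 only for an exactly zero score)\n",
"y_01 = (y_sign + 1) / 2              # -1 -> 0 and +1 -> 1\n",
"\n",
"print(y_sign)\n",
"print(y_01)"
]
},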
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's separate into a training set and testing set, which will allow for good evaluation after training."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's try to do something with this data using [dask-xgboost][dxgb].\n",
"\n",
"[dxgb]:https://github.com/dask/dask-xgboost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import dask_xgboost\n",
"import xgboost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"dask-xgboost is a small wrapper around xgboost, and behaves the same as xgboost.\n",
"\n",
"During training, Dask takes care of loading, cleaning, and pre-processing the data, while XGBoost uses its own distributed training system. Dask sets XGBoost up, hands it the data, and lets XGBoost do its training in the background using all the workers Dask has available."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's do some training:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.04 s, sys: 58.6 ms, total: 1.09 s\n",
"Wall time: 2.47 s\n"
]
}
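],
"source": [
"%%time\n",
"# Note: XGBoost's Python API takes the number of boosting rounds from the\n",
"# `num_boost_round` argument of `train`, not from `params`, so the 'nround'\n",
"# entry below is likely ignored and the default number of rounds is used.\n",
"params = {'objective': 'binary:logistic', 'nround': 1000,\n",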
" 'max_depth': 5, 'eta': 0.01, 'subsample': 0.5, \n",
" 'min_child_weight': 0.5}\n",
"\n",
"bst = dask_xgboost.train(client, params, X_train, y_train)\n",
"bst"
]
},
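{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sketch below uses plain `xgboost` on a small in-memory array to show where the number of boosting rounds is set in the Python API: the `num_boost_round` argument of `xgboost.train`. The data and the names `X_small`, `y_small`, `small_params`, and `small_bst` are illustrative assumptions. Because dask-xgboost is a thin wrapper, the same keyword can presumably be passed through `dask_xgboost.train`, but treat that as an assumption to check against its documentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import xgboost\n",
"\n",
"rng = np.random.RandomState(0)\n",
"X_small = rng.normal(size=(500, 5))\n",
"# toy binary labels that depend on the first two features only\n",
"y_small = (X_small[:, 0] + 0.5 * X_small[:, 1] > 0).astype(float)\n",
"\n",
"dtrain = xgboost.DMatrix(X_small, label=y_small)\n",
"small_params = {'objective': 'binary:logistic', 'max_depth': 3, 'eta': 0.1}\n",
"\n",
"# the number of boosting rounds is an argument to train, not a params entry\n",
"small_bst = xgboost.train(small_params, dtrain, num_boost_round=20)\n",
"\n",
"# get_dump() returns one text dump per tree in the trained model\n",
"print(len(small_bst.get_dump()))"
]
},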
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Result visualization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `bst` object is a regular `xgboost.Booster` object. This means all the methods mentioned in the [XGBoost documentation][2] are available. Here's one example with plotting:\n",
"\n",
"[2]:https://xgboost.readthedocs.io/en/latest/python/python_intro.html#"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x576 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import dask\n",
"\n",
"w_star = dask.compute(w_star)[0]\n",
"idx_top = np.argsort(w_star)[-7:]\n",
"top = w_star[idx_top]\n",
"\n",
"fig, axs = plt.subplots(figsize=(12, 8), ncols=2)\n",
"\n",
"axs[0] = xgboost.plot_importance(bst, ax=axs[0], height=0.8, max_num_features=9)\n",
"\n",
"axs[0].grid(False, axis=\"y\")\n",
"axs[0].set_title('Estimated feature importance')\n",
"\n",
"df = pd.DataFrame({'feature_importance': top}, index=idx_top)\n",
"df.plot.bar(ax=axs[1])\n",
"axs[1].set_title('Ground truth feature importance')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see from this example that XGBoost does a good job at feature *support recovery*, or figuring out which features are important. Notice that feature `9` shows up in both plots (in the left plot as `f9`). XGBoost misorders the features a bit, but it still does a pretty decent job: the top 5 *estimated* most important features are *actually* the top 5 most important features."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also use a fancier metric to see how well our classifier is doing by plotting the Receiver Operating Characteristic (ROC) curve:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dask.array<_predict_part, shape=(15000,), dtype=float32, chunksize=(150,)>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_hat = dask_xgboost.predict(client, bst, X_test).persist()\n",
"y_hat"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import roc_curve, auc\n",
"fig, ax = plt.subplots(figsize=(5, 5))\n",
"fpr, tpr, _ = roc_curve(y_test, y_hat)\n",
"ax.plot(fpr, tpr, lw=3,\n",
" label='ROC Curve (area = {:.2f})'.format(auc(fpr, tpr)))\n",
"ax.plot([0, 1], [0, 1], 'k--', lw=2)\n",
"\n",
"ax.set(\n",
" xlim=(0, 1),\n",
" ylim=(0, 1),\n",
" title=\"ROC Curve\",\n",
" xlabel=\"False Positive Rate\",\n",
" ylabel=\"True Positive Rate\",\n",
")\n",
"ax.legend();\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This Receiver Operating Characteristic (ROC) curve tells how well our classifier is doing. We can tell it's doing well by how far the curve bends toward the upper-left. A perfect classifier's curve would pass through the upper-left corner, and a random classifier would follow the dashed diagonal line.\n",
"\n",
"The area under this curve is `area = 0.89`. This is the probability that the classifier gives a higher score to a randomly chosen positive instance than to a randomly chosen negative one."
]
},
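{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to read the area under the ROC curve is as a ranking probability: the chance that a randomly chosen positive example gets a higher score than a randomly chosen negative one. The sketch below checks this on made-up scores, not on the classifier above. The names `labels`, `scores`, and `pairwise` are illustrative assumptions. It compares the fraction of correctly ordered (positive, negative) pairs with `sklearn.metrics.roc_auc_score`, and the two printed numbers should agree up to floating point error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"rng = np.random.RandomState(0)\n",
"labels = rng.randint(0, 2, size=2000)\n",
"# toy scores: positives are shifted up by 1 on average\n",
"scores = rng.normal(loc=labels.astype(float), scale=1.0)\n",
"\n",
"pos = scores[labels == 1]\n",
"neg = scores[labels == 0]\n",
"\n",
"# fraction of (positive, negative) pairs ranked correctly; ties count as half\n",
"diff = pos[:, None] - neg[None, :]\n",
"pairwise = (diff > 0).mean() + 0.5 * (diff == 0).mean()\n",
"\n",
"print(pairwise)\n",
"print(roc_auc_score(labels, scores))"
]
},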
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learn more\n",
"* XGBoost documentation: https://xgboost.readthedocs.io/en/latest/python/python_intro.html#\n",
"* A blog post on dask-xgboost: http://matthewrocklin.com/blog/work/2017/03/28/dask-xgboost\n",
"* Similar example with a real-world dataset: https://dask-ml.readthedocs.io/en/latest/examples/xgboost.html\n",
"* Recorded screencast of a core Dask developer stepping through the example above: https://www.youtube.com/watch?v=Cc4E-PdDSro\n",
"* ROC curve: https://en.wikipedia.org/wiki/Receiver_operating_characteristic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}