Created
June 16, 2022 04:44
-
-
Save jakirkham/a597ca5929679f6ccecdb1f03bab62f1 to your computer and use it in GitHub Desktop.
Notebook loading Higgs dataset for use with Dask + XGBoost
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "19f93a66-538c-4f3e-9e74-079a556716e9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"os.environ[\"DASK_DISTRIBUTED__SCHEDULER__WORK_STEALING\"] = \"False\" # needs to be added to dask-scheduler\n", | |
"\n", | |
"from functools import partial\n", | |
"from itertools import starmap\n", | |
"from operator import attrgetter, getitem\n", | |
"from math import ceil\n", | |
"\n", | |
"from tlz import sliding_window\n", | |
"\n", | |
"from dask_cuda import LocalCUDACluster\n", | |
"from dask.distributed import Client, wait\n", | |
"from dask import delayed\n", | |
"import dask_cudf\n", | |
"import cudf\n", | |
"import distributed\n", | |
"import xgboost as xgb\n", | |
"import time\n", | |
"from dask.utils import stringify\n", | |
"\n", | |
"from sklearn.metrics import mean_absolute_error" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "b1a1e2ff-be48-44a6-8767-265d417de1b5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def reproducible_persist_per_worker(df, client):\n", | |
" # Query workers\n", | |
" n_workers = len(client.cluster.workers)\n", | |
" workers = map(attrgetter(\"worker_address\"), client.cluster.workers.values())\n", | |
"\n", | |
" # Slice data into roughly equal partitions\n", | |
" subpartition_size = ceil(df.npartitions / n_workers)\n", | |
" subpartition_divisions = range(0, df.npartitions + subpartition_size, subpartition_size)\n", | |
" subpartition_slices = starmap(slice, sliding_window(2, subpartition_divisions))\n", | |
" subpartitions = map(partial(getitem, df.partitions), subpartition_slices)\n", | |
"\n", | |
" # Persist each subpartition on each worker\n", | |
" # Rebuild dataframe from persisted subpartitions\n", | |
" df2 = dask_cudf.concat([sp.persist(workers=w, allow_other_workers=False) for sp, w in zip(subpartitions, workers)])\n", | |
"\n", | |
" return df2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ed98e366-1450-45d6-9c2f-a3897a36cdae", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Start a local CUDA cluster with a fixed number of workers, connect a client\n", | |
"# to it, and block until that many workers are available.\n", | |
"n_workers = 8\n", | |
"cluster = LocalCUDACluster(n_workers=n_workers)\n", | |
"client = Client(cluster)\n", | |
"client.wait_for_workers(n_workers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c9dcf799-e12f-4704-b928-70e824536b99", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"fname = 'HIGGS.csv'\n", | |
"colnames = ['label'] + ['feature-%02d' % i for i in range(1, 29)]\n", | |
"df = dask_cudf.read_csv(fname, header=None, names=colnames)\n", | |
"df = reproducible_persist_per_worker2(df, client)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2c96374b-9c16-4543-a4eb-61600894f638", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Split the frame into the feature columns and the 'label' column.\n", | |
"df_features = df.drop(columns=['label'])\n", | |
"df_labels = df['label']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "c9b00774-c794-492c-8b0d-cf2d60f0b97a", | |
"metadata": { | |
"tags": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"# Build the training matrix from the feature and label collections.\n", | |
"dmatrix = xgb.dask.DaskDeviceQuantileDMatrix(client=client,\n", | |
" data=df_features,\n", | |
" label=df_labels)\n", | |
"\n", | |
"# Train with the 'gpu_hist' tree method and a fixed seed for 3000 boosting\n", | |
"# rounds, evaluating on the training matrix itself under the name 'dtrain'.\n", | |
"model = xgb.dask.train(client,\n", | |
" {'verbosity': 0,\n", | |
" 'tree_method': 'gpu_hist',\n", | |
" 'seed': 123},\n", | |
" dtrain=dmatrix,\n", | |
" num_boost_round=3000,\n", | |
" evals=[(dmatrix,'dtrain')])\n", | |
"\n", | |
"# Report the last recorded 'rmse' entry of the 'dtrain' evaluation history.\n", | |
"print(\"Final train loss: \", model['history']['dtrain']['rmse'][-1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "72ec3c6b-3536-49b0-aab7-5128479135ba", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Predict on the training features, convert the result to a frame and\n", | |
"# compute it, then rename column 0 to 'score'.\n", | |
"y_pred = xgb.dask.predict(client, model, df_features).to_frame().compute()\n", | |
"y_pred = y_pred.rename({0: 'score'}, axis=1)\n", | |
"\n", | |
"# Write the scores to a Parquet file in the working directory.\n", | |
"y_pred.to_parquet(\"result.parquet\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "2bc2d19e-15b1-4bcc-94ab-f45d08cbd776", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.13" | |
}, | |
"widgets": { | |
"application/vnd.jupyter.widget-state+json": { | |
"state": {}, | |
"version_major": 2, | |
"version_minor": 0 | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment