Last active October 23, 2020 11:59
stsievert/94cecc2679ddedc07a1887080d4afa19
Hyperband demo with sklearn
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Hyperband\n",
"Hyperband is useful when computational resources are limited. Some example cases are when\n",
"\n",
"* there are many parameters to search over\n",
"* models take a long time to train\n",
"\n",
"Hyperband requires only *one* input: the computational budget. For more information, see the documentation: https://dask-ml.readthedocs.io/en/latest/hyper-parameter-search.html\n",
"\n",
"Hyperband is an *adaptive* algorithm: it spends as much time as possible on high-performing models by \"killing\" off the lower-performing ones. More detail is given in the `HyperbandCV` class description: https://dask-ml.readthedocs.io/en/latest/modules/generated/dask_ml.model_selection.GridSearchCV.html#dask_ml.model_selection.HyperbandCV\n",
"\n",
"Below, we simulate a large search space with a single parameter, `alpha`, sampled at many values. We could search over two parameters, but one parameter gives an easy-to-interpret visualization at the end.\n",
"\n",
"Hyperband is very similar to `RandomizedSearchCV` and works best with continuous random variables. We simulate a log-uniform random variable with many samples: `np.logspace(-4, 1, num=1000)`."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"\n",
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"import dask_ml\n",
"from dask_ml.datasets import make_classification\n",
"from dask_ml.wrappers import Incremental\n",
"# Later dask-ml releases provide this class as HyperbandSearchCV instead of HyperbandCV.\n",
"from dask_ml.model_selection import HyperbandCV, GridSearchCV, train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from distributed import Client, LocalCluster\n",
"client = Client()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"n, d = int(10e3), int(100)\n",
"X, y = make_classification(n_features=d, n_samples=n,\n",
" n_informative=d // 10,\n",
" chunks=(n // 10, d))\n",
"classes = da.unique(y)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
"\n",
"kwargs = dict(penalty='elasticnet', max_iter=1.0, warm_start=True, loss='log')\n",
"model = Incremental(SGDClassifier(**kwargs))\n",
"params = {'alpha': np.logspace(-4, 1, num=1000)}\n",
"# 'loss': ['hinge', 'log', 'modified_huber', 'squared_hinge']}"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"alg = HyperbandCV(model, params, max_iter=81)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.21 s, sys: 223 ms, total: 2.44 s\n",
"Wall time: 2.68 s\n"
]
},
{
"data": {
"text/plain": [
"HyperbandCV(asynchronous=True, eta=3, max_iter=81,\n",
" model=Incremental(SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n",
" eta0=0.0, fit_intercept=True, l1_ratio=0.15,\n",
" learning_rate='optimal', loss='log', max_iter=1.0, n_iter=None,\n",
" n_jobs=1, penalty='elasticnet', power_t=0.5, random_state=None,\n",
" shuffle=True, tol=None, verbose=0, warm_start=True)),\n",
" params={'alpha': array([1.00000e-04, 1.01159e-04, ..., 9.88542e+00, 1.00000e+01])},\n",
" random_state=None, scoring=None, test_size=0.15)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%time\n",
"alg.fit(X_train, y_train, classes=da.unique(y))"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.774"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"alg.score(X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'alpha': 0.11562801312073753}"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hyperband_alpha = alg.best_params_['alpha']\n",
"# hyperband_loss = alg.best_params_['loss']\n",
"alg.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we will compare with an exhaustive evaluation, which we can do because we're only simulating being computationally limited.\n",
"\n",
"We will use `GridSearchCV`. If Hyperband also searched over the loss function (the commented-out `'loss'` entry above), we would set the model's loss to the one Hyperband found, because that is really the only way to show a fair visualization: otherwise we would be comparing `alpha`s across loss functions, which doesn't make sense. In this run only `alpha` is searched, so the loss stays fixed at `'log'`.\n",
"\n",
"Note that in that case this visualization would hide the fact that Hyperband was also searching over several loss functions."
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.42 s, sys: 2.45 s, total: 6.86 s\n",
"Wall time: 1.16 s\n"
]
}
],
"source": [
"%%time\n",
"# model.set_params(loss=hyperband_loss)\n",
"params = {'alpha': np.logspace(-4, 1, num=50)}\n",
"grid = GridSearchCV(model.estimator, params, return_train_score=False)\n",
"grid.fit(X, y)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'alpha': 0.03556480306223129}"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"grid.best_params_"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"df = pd.DataFrame(grid.cv_results_)\n",
"\n",
"fig, ax = plt.subplots()\n",
"df.plot(x='param_alpha', y='mean_test_score',\n",
" yerr='std_test_score',\n",
" logx=True, ax=ax)\n",
"ax.plot(2 * [hyperband_alpha], plt.ylim(), 'r--',\n",
" label=\"Hyperband's chosen alpha\")\n",
"plt.legend(loc='lower left')\n",
"plt.ylabel('mean_test_score')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
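The `HyperbandCV` repr in the notebook shows `eta=3, max_iter=81`. To make "killing off the lower-performing models" concrete, here is a minimal pure-Python sketch of the bracket schedule from the original Hyperband paper (Li et al.) for those two settings. The helpers `ceil_div`, `hyperband_brackets` and `partial_fit_calls` are hypothetical illustrations, not dask-ml API, and dask-ml's own implementation may round or schedule brackets differently. The work count additionally assumes that surviving models resume training (as with `warm_start=True` and `partial_fit`) rather than restarting from scratch.

```python
def ceil_div(a, b):
    """Integer ceiling division, avoiding floating-point rounding."""
    return -(-a // b)


def hyperband_brackets(max_iter, eta=3):
    """Sketch of the Hyperband schedule from Li et al. (not dask-ml's code).

    Returns {bracket s: [(n_models, iters_per_model), ...]}, where each list
    is one run of successive halving.
    """
    # s_max = floor(log_eta(max_iter)), computed with integers so that
    # e.g. log_3(81) cannot come out as 3.9999...
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1

    brackets = {}
    for s in range(s_max, -1, -1):
        # Models started in this bracket: ceil((s_max + 1) * eta**s / (s + 1))
        n = ceil_div((s_max + 1) * eta ** s, s + 1)
        # Iterations each model gets in the first round
        r = max_iter / eta ** s
        # Each round keeps the top 1/eta of models and gives them eta times
        # more iterations, until the survivors reach max_iter.
        brackets[s] = [(n // eta ** i, r * eta ** i) for i in range(s + 1)]
    return brackets


def partial_fit_calls(rounds):
    """Total training iterations in one bracket, ASSUMING survivors resume
    training from where they stopped instead of restarting."""
    total, prev_r = 0.0, 0.0
    for n_i, r_i in rounds:
        total += n_i * (r_i - prev_r)  # only the extra iterations this round
        prev_r = r_i
    return total


if __name__ == "__main__":
    max_iter, eta = 81, 3
    brackets = hyperband_brackets(max_iter, eta)
    for s, rounds in brackets.items():
        print(f"bracket {s}: " + ", ".join(f"{n} models x {r:g} iters" for n, r in rounds))
    n_models = sum(rounds[0][0] for rounds in brackets.values())
    work = sum(partial_fit_calls(rounds) for rounds in brackets.values())
    print(f"{n_models} models, {work:g} training iterations in total, "
          f"vs {n_models * max_iter} if every model ran for max_iter={max_iter}")
```

Each bracket trades breadth for depth: bracket 4 starts 81 models with a single iteration each, while bracket 0 trains 5 models for the full 81 iterations. Under the stated assumptions this sketch starts 143 models for 1581 training iterations in total, compared with 11583 if every model were trained to `max_iter`.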
Scott Sievert, Oct 22, 2020: I'm unsure what you're asking about. I think the HyperbandSearchCV docs (https://ml.dask.org/modules/generated/dask_ml.model_selection.HyperbandSearchCV.html#dask_ml.model_selection.HyperbandSearchCV), user guide (https://ml.dask.org/hyper-parameter-search.html#adaptive-hyperparameter-optimization) and blog post (https://blog.dask.org/2019/09/30/dask-hyperparam-opt) are better resources than this gist.

Just commenting in case other people stumble upon it too. I wanted to use it as an example, but it ended up being deprecated.
Sincerely yours,
Dave Liu
It's `HyperbandSearchCV` now.
`algo.fit` gives a "Too many values to unpack" error.