Created
March 18, 2020 10:12
-
-
Save toshihikoyanase/6253fe69f055e69d299ea5d4042331bd to your computer and use it in GitHub Desktop.
Examples for Optuna #1039: Support pruning/resume/parallelization for LightGBMTuner
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
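Optuna example that optimizes a classifier configuration for the breast cancer dataset using the LightGBM tuner.

In this example, we optimize the validation log loss of breast cancer detection. Two tuning sessions
run in parallel (via joblib) and share a single study stored in SQLite.

You can execute this code directly. Pass -p to enable pruning with `MedianPruner`.
    $ python lightgbm_tuner_parallel.py [-p]
"""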
import argparse
import json
import os
import pickle

from joblib import Parallel, delayed
import numpy as np
import sklearn.datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import optuna
import optuna.integration.lightgbm as lgb


if __name__ == '__main__':
    # Parse CLI flags first so `--help` (or a bad flag) exits before any data is loaded.
    parser = argparse.ArgumentParser(description='LightGBM tuner example.')
    parser.add_argument('--pruning', '-p', action='store_true',
                        help='Activate the pruning feature. `MedianPruner` stops unpromising '
                        'trials at the early stages of training.')
    args = parser.parse_args()

    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, val_x, train_y, val_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)
    dval = lgb.Dataset(val_x, label=val_y)

    params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbosity': -1,
        'boosting_type': 'gbdt',
    }

    pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    # SQLite storage lets the parallel workers share one study. No study_name is passed,
    # so each run of this script creates a new, auto-named study inside lgbtuner.db.
    study = optuna.create_study(storage="sqlite:///lgbtuner.db", pruner=pruner)

    # Note: cloudpickle cannot dump LightGBM Booster instances. Do not return them.
    def train():
        """Run one LightGBM tuner session whose trials are recorded in ``study``.

        Deliberately returns None: the Booster produced by ``lgb.train`` cannot be
        serialized by cloudpickle, so it must not be sent back from a joblib worker.
        """
        lgb.train(
            params,
            dtrain,
            valid_sets=[dtrain, dval],
            verbose_eval=100,
            early_stopping_rounds=100,
            study=study,
        )

    # Launch two tuning sessions concurrently; both write to the shared study.
    Parallel(n_jobs=2)(delayed(train)() for _ in range(2))

    best_trial = study.best_trial

    # NOTE(review): nothing in this script writes /tmp/<trial number>.pkl; presumably the
    # LightGBM tuner saves each trial's booster there -- confirm before relying on it.
    # SECURITY: pickle.load can execute arbitrary code, and /tmp is typically shared and
    # writable by other users, so only load files you trust.
    path = os.path.join("/tmp", "{}.pkl".format(best_trial.number))
    with open(path, "rb") as fin:
        model = pickle.load(fin)

    prediction = np.rint(model.predict(val_x, num_iteration=model.best_iteration))
    accuracy = accuracy_score(val_y, prediction)

    # The tuner stores the winning parameters as a JSON string in the trial's user attributes.
    best_params = json.loads(best_trial.user_attrs['lgbm_params'])

    print('Number of finished trials: {}'.format(len(study.trials)))
    print('Best params:', best_params)
    print('  Accuracy = {}'.format(accuracy))
    print('  Params: ')
    for key, value in best_params.items():
        print('    {}: {}'.format(key, value))

    study.trials_dataframe().to_csv('parallel-result.csv')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Optuna example that optimizes a classifier configuration for the breast cancer dataset using the LightGBM tuner.

In this example, we optimize the validation log loss of breast cancer detection.

You can execute this code directly. Pass -p to enable pruning with `MedianPruner`.
    $ python lightgbm_tuner_pruning.py [-p]
"""
import argparse
import json

import numpy as np
import sklearn.datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import optuna
import optuna.integration.lightgbm as lgb


if __name__ == '__main__':
    # Parse CLI flags first so `--help` (or a bad flag) exits before any data is loaded.
    parser = argparse.ArgumentParser(description='LightGBM tuner example.')
    parser.add_argument('--pruning', '-p', action='store_true',
                        help='Activate the pruning feature. `MedianPruner` stops unpromising '
                        'trials at the early stages of training.')
    args = parser.parse_args()

    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, val_x, train_y, val_y = train_test_split(data, target, test_size=0.25)
    dtrain = lgb.Dataset(train_x, label=train_y)
    dval = lgb.Dataset(val_x, label=val_y)

    params = {
        'objective': 'binary',
        'metric': 'binary_logloss',
        'verbosity': -1,
        'boosting_type': 'gbdt',
    }

    pruner = optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
    # No storage is given, so the study lives only in memory for this run.
    study = optuna.create_study(pruner=pruner)

    # Optuna's lgb.train runs the tuner, recording its trials in `study`, and returns
    # the trained model used for prediction below.
    model = lgb.train(
        params,
        dtrain,
        valid_sets=[dtrain, dval],
        verbose_eval=100,
        early_stopping_rounds=100,
        study=study,
    )

    prediction = np.rint(model.predict(val_x, num_iteration=model.best_iteration))
    accuracy = accuracy_score(val_y, prediction)

    best_trial = study.best_trial
    # The tuner stores the winning parameters as a JSON string in the trial's user attributes.
    best_params = json.loads(best_trial.user_attrs['lgbm_params'])

    print('Number of finished trials: {}'.format(len(study.trials)))
    print('Best params:', best_params)
    print('  Accuracy = {}'.format(accuracy))
    print('  Params: ')
    for key, value in best_params.items():
        print('    {}: {}'.format(key, value))

    study.trials_dataframe().to_csv('pruning-result.csv')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment