Last active December 7, 2020 04:45
Gist hvy/86c7d33ebda94f5ad74169ba39a3b52d
Optuna example that optimizes a simple quadratic function in parallel using joblib, passing arbitrary arguments to the objective function.
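As a quick check on what the script should report: the objective is a one-dimensional quadratic whose minimizer lies inside the search range, and each of the parallel `optimize` calls runs `n_trials` trials against the same study. With the values set in the script's `__main__` block:

```latex
f(x) = (x - 2)^2, \qquad x \in [-100,\ 100]
\arg\min_x f(x) = 2, \qquad \min_x f(x) = 0
\text{total trials} = n_{\text{iterables}} \times n_{\text{trials}} = 3 \times 10 = 30
```

So the best trial's `x` can only approach 2 and its value can only approach 0; with 30 trials the reported best value will usually sit somewhat above 0. The total of 30 is the count the script's final `assert` expects for a freshly created study.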
"""
Optuna example that optimizes a simple quadratic function in parallel using `joblib`, passing
arbitrary arguments to the objective function.

Run the example as follows.

    $ python quadratic_joblib_simple.py

To rerun the example, first delete the previous study with the Optuna CLI.

    $ optuna delete-study --study-name joblib-quadratic --storage "sqlite:///example.db"

See also
https://optuna.readthedocs.io/en/latest/faq.html#how-to-define-objective-functions-that-have-own-arguments
"""

from joblib import Parallel, delayed
import optuna


def print_study(study):
    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    trial = study.best_trial

    print('  Value: ', trial.value)

    print('  Params: ')
    for key, value in trial.params.items():
        print('    {}: {}'.format(key, value))


def optimize(n_trials, min_x, max_x):
    # Each joblib job loads the same study from the shared SQLite storage.
    study = optuna.load_study(study_name='joblib-quadratic', storage='sqlite:///example.db')

    # You can either use a lambda (as shown here) or define a class that holds the arguments and
    # implements `__call__`.
    study.optimize(lambda trial: objective(trial, min_x, max_x), n_trials=n_trials)


# The objective function takes not only the trial but also additional arguments.
def objective(trial, min_x, max_x):
    # `suggest_float` replaces the deprecated `suggest_uniform`.
    x = trial.suggest_float('x', min_x, max_x)
    return (x - 2) ** 2


if __name__ == '__main__':
    study = optuna.create_study(study_name='joblib-quadratic', storage='sqlite:///example.db')

    # `Study.optimize` arguments.
    n_trials = 10

    # Arbitrary arguments to the objective function.
    min_x = -100
    max_x = 100

    # `joblib` arguments.
    n_iterables = 3

    # Run `n_iterables` calls to `optimize`, spread over all available CPU cores (`n_jobs=-1`).
    Parallel(n_jobs=-1)(
        [delayed(optimize)(n_trials, min_x, max_x) for _ in range(n_iterables)])

    # Every call adds `n_trials` trials to the same study.
    assert len(study.trials) == n_trials * n_iterables

    print_study(study)
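The comment in `optimize` mentions defining a class that holds the arguments and implements `__call__` as an alternative to the lambda. A minimal sketch of that alternative follows, assuming an Optuna version that provides `Trial.suggest_float`. The class name `Objective` is illustrative and not part of the original gist. The block needs no imports of its own because it only calls `trial.suggest_float` on whatever trial object it receives.

```python
class Objective:
    """Hypothetical callable objective (illustrative name, not from the gist).

    It stores the extra arguments as attributes so that Optuna can call it
    with a single `trial` argument, the same way it calls the lambda.
    """

    def __init__(self, min_x, max_x):
        # Bounds of the search range for `x`, fixed when the object is created.
        self.min_x = min_x
        self.max_x = max_x

    def __call__(self, trial):
        # Same search space and quadratic as the `objective` function.
        x = trial.suggest_float('x', self.min_x, self.max_x)
        return (x - 2) ** 2
```

To use it, the lambda line in `optimize` would become `study.optimize(Objective(min_x, max_x), n_trials=n_trials)`. Holding the bounds as attributes keeps the call signature down to the single `trial` argument that `Study.optimize` passes, which is the same problem the lambda solves.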