Last active
September 8, 2021 09:23
-
-
Save belltailjp/e7a33d985ade8cefcb939212d39002fe to your computer and use it in GitHub Desktop.
Optuna hyperparameter search driven by pytorch-pfn-extras (ppe) configs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import optuna | |
# https://optuna.org/ | |
def objective(trial):
    """Optuna objective: sample ``x`` in [-10, 10] and return ``(x - 2) ** 2``.

    Uses ``trial.suggest_float`` instead of the deprecated
    ``trial.suggest_uniform``; both sample ``x`` uniformly from [-10, 10].
    The returned value is 0 at ``x == 2``.
    """
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2
def main():
    """Run a default Optuna study over ``objective`` and print the best params.

    ``optuna.create_study()`` is called with no arguments, so the study uses
    Optuna's default direction (minimize) and an in-memory storage.
    """
    trial_budget = 100
    study = optuna.create_study()
    study.optimize(objective, n_trials=trial_budget)
    best = study.best_params
    print(best)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import optuna | |
import pytorch_pfn_extras as ppe | |
class Net:
    """Toy stand-in for a model: holds one scalar ``value``.

    ``get_value`` returns the squared distance of ``value`` from 2.
    """

    def __init__(self, value):
        # Kept as a public attribute; the config builds Net(value=...).
        self.value = value

    def get_value(self):
        """Return ``(self.value - 2) ** 2`` (zero when ``value == 2``)."""
        offset = self.value - 2
        return pow(offset, 2)
# ppe config for a single run: ``net`` is built from the ``types`` entry
# named by its ``type`` key, with the remaining keys as constructor kwargs.
# Indentation is significant: ``type``/``value`` must be nested under
# ``net:``; if flattened, ``net`` parses as null and ``cfg['/net']`` breaks.
cfg_yml = """
net:
  type: Net
  value: 0.4
"""
# Name -> class registry passed to ppe.config.Config; the ``type:`` fields
# in cfg_yml are looked up here.
types = dict(Net=Net)
def run_train(net):
    """Stand-in for a training run: return the score reported by ``net``.

    The score is ``net.get_value()``. For the toy ``Net`` in this script that
    is a squared error (zero at ``value == 2``), not a true accuracy.
    """
    score = net.get_value()
    return score
def main():
    """Build ``Net`` from ``cfg_yml`` via a ppe config, run it once, print the score."""
    config = ppe.config.Config(yaml.safe_load(cfg_yml), types)
    score = run_train(config['/net'])
    print(score)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import optuna | |
import pytorch_pfn_extras as ppe | |
import pytorch_pfn_extras.config_types as config_types | |
class Net:
    """Toy stand-in for a model: holds one scalar ``value``.

    ``get_value`` returns the squared distance of ``value`` from 2, so the
    Optuna search in this script has its optimum at ``value == 2``.
    """

    def __init__(self, value):
        # Kept as a public attribute; the config supplies ``value``.
        self.value = value

    def get_value(self):
        """Return ``(self.value - 2) ** 2`` (zero when ``value == 2``)."""
        offset = self.value - 2
        return pow(offset, 2)
# Search-space config: ``net.value`` is an ``optuna_suggest_float`` node that
# is resolved per trial by the types from ``config_types.optuna_types(trial)``
# (presumably a ``trial.suggest_float('x', -10.0, 10.0)`` call -- confirm in
# the ppe docs). Indentation is significant: everything below ``net:`` and
# ``value:`` must stay nested, otherwise ``cfg['/net']`` resolves to null.
cfg_yml = """
net:
  type: Net
  value:
    type: optuna_suggest_float
    name: x
    low: -10.0
    high: 10.0
"""
# Project classes the config may instantiate by name; main() merges in the
# per-trial Optuna types from config_types.optuna_types(trial).
types = dict(Net=Net)
def run_train(net):
    """Stand-in for a training run: return the objective value for ``net``.

    The value is ``net.get_value()``. For the toy ``Net`` in this script that
    is a squared error (zero at ``value == 2``), which is what the study in
    ``main`` optimizes.
    """
    return net.get_value()
def main(*, n_trials=100, storage='sqlite:///my_study.db'):
    """Run an Optuna study whose search space is declared in ``cfg_yml``.

    Every trial builds a fresh ppe config whose type registry includes the
    trial-bound Optuna types, so the Optuna nodes in the YAML are resolved
    by that trial. The best parameters are printed at the end.

    Args:
        n_trials: Number of trials passed to ``study.optimize``.
        storage: Optuna storage URL. The default uses SQLite and creates
            (or opens, if it already exists) ``my_study.db`` in the current
            working directory. Pass ``None`` to keep the study in memory.

    NOTE(review): no ``study_name``/``load_if_exists`` is passed to
    ``create_study``, so presumably each run starts a new study inside the
    DB rather than resuming an earlier one -- confirm against Optuna docs.
    """
    # Study-level settings: these params cannot be included in the config yaml.
    sampler_type = optuna.samplers.TPESampler
    sampler_kwargs = {}

    def objective(trial):
        # Build a fresh config per trial so the Optuna types bind to *this* trial.
        raw_cfg = yaml.safe_load(cfg_yml)
        cfg = ppe.config.Config(
            raw_cfg, {**types, **config_types.optuna_types(trial)})
        net = cfg['/net']
        return run_train(net)

    study = optuna.create_study(
        sampler=sampler_type(**sampler_kwargs), storage=storage)
    study.optimize(objective, n_trials=n_trials)
    print(study.best_params)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment