Skip to content

Instantly share code, notes, and snippets.

@hvy
Last active January 15, 2021 07:12
Show Gist options
  • Save hvy/67a279e843a43a1cfc04c2eb92ea605d to your computer and use it in GitHub Desktop.
Save hvy/67a279e843a43a1cfc04c2eb92ea605d to your computer and use it in GitHub Desktop.
Optuna storage benchmark script.
import argparse
import math
import time
import sqlalchemy
import optuna
class Profile:
    """Context manager that measures the time spent inside its ``with`` body.

    Timing uses :func:`time.perf_counter`, a monotonic, high-resolution clock
    intended for measuring durations. :func:`time.time` is a wall clock that
    can jump when the system time is adjusted and is coarser on some platforms,
    which skews short benchmark runs.

    Example::

        with Profile() as prof:
            work()
        print(prof.get())
    """

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.perf_counter()
        # Implicitly returns None (falsy): exceptions raised in the body propagate.

    def get(self):
        """Return the elapsed seconds between entering and leaving the ``with`` block.

        Only valid after the block has exited (``self.end`` is set in ``__exit__``).
        """
        return self.end - self.start
def build_objective_func(n_param):
    """Build an Optuna objective that uses ``n_param`` float parameters.

    The returned callable asks the trial for floats named ``param-0`` ..
    ``param-<n_param - 1>``, each in the range ``[0, 2*pi]``. It returns the sum
    of their sines. With ``n_param <= 0`` no parameters are suggested and the
    result is ``0``.
    """
    upper = math.pi * 2

    def objective(trial):
        total = 0
        for index in range(n_param):
            value = trial.suggest_float("param-{}".format(index), 0, upper)
            total += math.sin(value)
        return total

    return objective
def _positive_int(value):
    """argparse ``type`` callable: parse *value* as an int and require it to be >= 1.

    Used for ``--n-trials`` because the per-trial time is divided by it.
    """
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError("must be a positive integer: {}".format(value))
    return number


def _recreate_database(engine, database_name):
    """Drop *database_name* if it exists, then create it again, empty.

    ``DROP DATABASE IF EXISTS`` (MySQL) handles the "not there yet" case, so
    genuine failures such as missing privileges are raised instead of being
    silently ignored. ``engine.begin()`` checks out a connection, commits at
    the end of the block and always returns the connection to the pool.
    """
    with engine.begin() as conn:
        conn.execute(sqlalchemy.text("DROP DATABASE IF EXISTS {}".format(database_name)))
        conn.execute(sqlalchemy.text("CREATE DATABASE {}".format(database_name)))


def main():
    """Benchmark Optuna's RDB storage on MySQL and print a markdown results table.

    For every (#params, #trials) pair, a fresh database is created and a study
    with a ``RandomSampler`` is optimized on it. One table row is printed with
    the total time and the time per trial.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("mysql_user", type=str)
    parser.add_argument("mysql_host", type=str)
    parser.add_argument("--n-params", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32])
    parser.add_argument(
        "--n-trials", type=_positive_int, nargs="+", default=[1, 10, 100, 1000]
    )
    args = parser.parse_args()

    # Server-level URL without a database name; each run appends its own database.
    storage_str = "mysql+pymysql://{}@{}/".format(args.mysql_user, args.mysql_host)

    optuna.logging.set_verbosity(optuna.logging.CRITICAL)

    print("| #params | #trials | time(sec) | time/trial(sec) |")
    print("| ------- | ------- | --------- | --------------- |")

    # A single admin engine, reused for all DROP/CREATE statements and disposed
    # at the end, instead of leaking one connection pool per configuration.
    engine = sqlalchemy.create_engine(storage_str)
    try:
        for n_param in args.n_params:
            for n_trial in args.n_trials:
                database_str = "profile_storage_t{}_p{}".format(n_trial, n_param)
                _recreate_database(engine, database_str)

                storage = optuna.storages.get_storage(storage_str + database_str)
                study = optuna.create_study(
                    storage=storage, sampler=optuna.samplers.RandomSampler()
                )
                with Profile() as prof:
                    study.optimize(
                        build_objective_func(n_param),
                        n_trials=n_trial,
                        gc_after_trial=False,
                    )
                elapsed = prof.get()
                print(f"| {n_param} | {n_trial} | {elapsed:.2f} | {elapsed / n_trial:.3f} |")
    finally:
        engine.dispose()


if __name__ == "__main__":
    main()
@hvy
Copy link
Author

hvy commented Jan 15, 2021

`python optuna_storage_benchmark.py root localhost --n-trials 1 10 100 500`

Results on `master`:

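| #params | #trials | time(sec) | time/trial(sec) |
| ------- | ------- | --------- | --------------- |
| 1 | 1 | 0.02 | 0.022 |
| 1 | 10 | 0.15 | 0.015 |
| 1 | 100 | 1.58 | 0.016 |
| 1 | 500 | 8.32 | 0.017 |
| 2 | 1 | 0.02 | 0.023 |
| 2 | 10 | 0.15 | 0.015 |
| 2 | 100 | 1.50 | 0.015 |
| 2 | 500 | 8.47 | 0.017 |
| 4 | 1 | 0.03 | 0.031 |
| 4 | 10 | 0.17 | 0.017 |
| 4 | 100 | 1.59 | 0.016 |
| 4 | 500 | 9.23 | 0.018 |
| 8 | 1 | 0.05 | 0.046 |
| 8 | 10 | 0.33 | 0.033 |
| 8 | 100 | 1.83 | 0.018 |
| 8 | 500 | 9.52 | 0.019 |
| 16 | 1 | 0.08 | 0.077 |
| 16 | 10 | 0.27 | 0.027 |
| 16 | 100 | 2.04 | 0.020 |
| 16 | 500 | 11.72 | 0.023 |
| 32 | 1 | 0.15 | 0.148 |
| 32 | 10 | 0.42 | 0.042 |
| 32 | 100 | 2.98 | 0.030 |
| 32 | 500 | 13.92 | 0.028 |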

Results with a flush after each `_CachedStorage.set_trial_param` call:

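| #params | #trials | time(sec) | time/trial(sec) |
| ------- | ------- | --------- | --------------- |
| 1 | 1 | 0.02 | 0.021 |
| 1 | 10 | 0.15 | 0.015 |
| 1 | 100 | 1.52 | 0.015 |
| 1 | 500 | 9.01 | 0.018 |
| 2 | 1 | 0.02 | 0.023 |
| 2 | 10 | 0.20 | 0.020 |
| 2 | 100 | 1.90 | 0.019 |
| 2 | 500 | 10.60 | 0.021 |
| 4 | 1 | 0.03 | 0.031 |
| 4 | 10 | 0.27 | 0.027 |
| 4 | 100 | 2.69 | 0.027 |
| 4 | 500 | 16.20 | 0.032 |
| 8 | 1 | 0.05 | 0.046 |
| 8 | 10 | 0.53 | 0.053 |
| 8 | 100 | 4.84 | 0.048 |
| 8 | 500 | 26.64 | 0.053 |
| 16 | 1 | 0.08 | 0.075 |
| 16 | 10 | 0.76 | 0.076 |
| 16 | 100 | 8.28 | 0.083 |
| 16 | 500 | 41.02 | 0.082 |
| 32 | 1 | 0.14 | 0.140 |
| 32 | 10 | 1.59 | 0.159 |
| 32 | 100 | 16.26 | 0.163 |
| 32 | 500 | 79.58 | 0.159 |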

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment