Benchmark script for batched sampling with BoTorch in Optuna
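The script minimizes a 10-dimensional sphere function with Optuna's BoTorchSampler, running trials in N_WORKERS parallel processes against a shared JournalStorage, and compares the best objective value reached with consider_running_trials disabled and enabled.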
import os
import time
from concurrent.futures import ProcessPoolExecutor

import fire
import numpy as np
import optuna
import torch

N_WORKERS = 10  # parallel worker processes
N_TRIALS = 40  # trials per benchmark run
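
# Objective: 10-dimensional sphere function. The random sleep of up to one
# minute emulates an expensive evaluation so that several trials are running
# concurrently.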
def func(trial):
    X = [trial.suggest_int(f"x_{i}", -5, 5) for i in range(10)]
    time.sleep(60 * np.random.rand())
    return sum(x ** 2 for x in X)
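
# Build (or re-attach to) the shared study backed by a journal file, using
# BoTorchSampler with the given consider_running_trials setting and seed.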
def get_study(consider_running_trials=False, seed=None):
    # torch.cuda.set_per_process_memory_fraction(fraction=0.8 / N_WORKERS, device='cuda:0')
    storage = optuna.storages.JournalStorage(
        optuna.storages.JournalFileStorage("./journal.log"),
    )
    pruner = optuna.pruners.NopPruner()
    sampler = optuna.integration.BoTorchSampler(
        consider_running_trials=consider_running_trials, seed=seed,  # device="cuda:0"
    )
    study = optuna.create_study(
        study_name="study_0",
        storage=storage,
        sampler=sampler,
        pruner=pruner,
        load_if_exists=True,
    )
    return study
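
# Worker entry point: re-attach to the shared study and run a single trial.
# Each worker gets its own sampler seed so that the workers do not propose
# identical points.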
def optimize(i, consider_running_trials):
    study = get_study(consider_running_trials, seed=i)
    study.optimize(func, n_trials=1)
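
# One benchmark run: start from a clean journal file, then execute N_TRIALS
# trials spread over N_WORKERS processes and return the resulting study.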
def run(consider_running_trials):
    try:
        os.remove("./journal.log")
    except OSError:
        pass
    study = get_study()
    with ProcessPoolExecutor(max_workers=N_WORKERS) as executor:
        futures = executor.map(optimize, range(N_TRIALS), [consider_running_trials] * N_TRIALS)
        list(futures)  # consume the iterator to wait for all trials to finish
    return study
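
# Repeat the run n_iter times for each setting of consider_running_trials and
# report the mean and standard error of the best objective values.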
def benchmark(n_iter=10):
    torch.multiprocessing.set_start_method('spawn')
    results = {
        consider_running_trials: [run(consider_running_trials).best_value for _ in range(n_iter)]
        for consider_running_trials in (False, True)
    }
    print(results)
    print("mean:")
    print({k: np.mean(v) for k, v in results.items()})
    print("standard error:")
    print({k: np.std(v) / np.sqrt(n_iter) for k, v in results.items()})

if __name__ == "__main__":
    # fire.Fire(benchmark)
    benchmark(20)
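
# Usage note: enabling the fire entry point (fire.Fire(benchmark)) instead of
# the hardcoded call lets the repetition count be passed from the shell, e.g.
# --n_iter=5. With N_WORKERS = 10, N_TRIALS = 40, and up to a 60-second sleep
# per trial, one run takes several minutes, so benchmark(20) (40 runs in
# total) can take on the order of hours.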