Created
May 17, 2023 09:36
-
-
Save c-bata/0e739d661e21a5bc4c2ddf2141bf6a9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy as np | |
from kurobako import problem | |
from kurobako.problem import Problem | |
from typing import List | |
from typing import Optional | |
class RastriginEvaluator(problem.Evaluator):
    """Evaluate the Rastrigin function at one point of the search space.

    f(x) = 10 * n + sum_i (x_i ** 2 - 10 * cos(2 * pi * x_i))
    """

    def __init__(self, params: List[Optional[float]]):
        # `n` is the dimensionality; `x` holds the point as a float array.
        self.n = len(params)
        self.x = np.array(params, dtype=float)
        # No step has been evaluated yet.
        self._current_step = 0

    def evaluate(self, next_step: int) -> List[float]:
        """Return ``[f(x)]`` for the stored point.

        ``next_step`` is not consulted: the whole evaluation happens at once
        and the step counter is set straight to 1.
        """
        self._current_step = 1
        per_coordinate = self.x**2 - 10 * np.cos(2 * np.pi * self.x)
        return [10 * self.n + np.sum(per_coordinate)]

    def current_step(self) -> int:
        """Report the step reached: 0 before ``evaluate`` runs, 1 afterwards."""
        return self._current_step
class RastriginProblem(problem.Problem):
    """Problem instance that builds a Rastrigin evaluator for each parameter set."""

    def create_evaluator(
        self, params: List[Optional[float]]
    ) -> Optional[problem.Evaluator]:
        """Return an evaluator bound to the point described by ``params``."""
        evaluator = RastriginEvaluator(params)
        return evaluator
class RastriginProblemFactory(problem.ProblemFactory):
    """Factory that describes an n-dimensional Rastrigin problem to kurobako."""

    def __init__(self, dim):
        # Number of search-space dimensions; variables are named x1 .. x{dim}.
        self.dim = dim

    def create_problem(self, seed: int) -> Problem:
        """Return a fresh problem instance.

        ``seed`` is ignored: the evaluator contains no randomness.
        """
        return RastriginProblem()

    def specification(self) -> problem.ProblemSpec:
        """Describe the problem: ``dim`` continuous inputs and one objective.

        Every input ranges continuously from -5.12 to 5.12.
        """
        search_space = []
        for index in range(1, self.dim + 1):
            search_space.append(
                problem.Var(f"x{index}", problem.ContinuousRange(-5.12, 5.12))
            )
        return problem.ProblemSpec(
            name=f"Rastrigin (dim={self.dim})",
            params=search_space,
            values=[problem.Var("Rastrigin")],
        )
if __name__ == "__main__":
    # Usage: python problem_rastrigin.py [DIM]
    # DIM is used only when it is the sole argument; otherwise it defaults to 2.
    dim = 2
    if len(sys.argv) == 2:
        dim = int(sys.argv[1])
    runner = problem.ProblemRunner(RastriginProblemFactory(dim))
    runner.run()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import subprocess | |
def _tee_append(argv: list, filename: str) -> None:
    """Run ``argv``, echo its stdout, and append that stdout to ``filename``.

    Behaves like the shell idiom ``argv | tee -a filename``, except that a
    non-zero exit status of ``argv`` raises ``subprocess.CalledProcessError``
    instead of being hidden behind ``tee``'s own exit status.
    """
    completed = subprocess.run(
        argv, stdout=subprocess.PIPE, universal_newlines=True, check=True
    )
    print(completed.stdout, end="", flush=True)
    with open(filename, "a") as f:
        f.write(completed.stdout)


def _read_nonblank_lines(filename: str) -> list:
    """Return the stripped, non-empty lines of ``filename``.

    Each line becomes one command-line argument. This assumes kurobako prints
    every solver/problem spec as a single JSON line (TODO confirm).
    """
    with open(filename) as f:
        return [line.strip() for line in f if line.strip()]


def run(args: argparse.Namespace) -> None:
    """Build, execute, and report a kurobako benchmark of Optuna samplers.

    Writes problems.json, solvers.json, studies.json, results.json and
    report.md into ``args.out_dir``. It then tries to render plots into
    ``args.out_dir/images`` via the ``sile/kurobako`` Docker image.

    Raises:
        subprocess.CalledProcessError: if any kurobako step exits non-zero.
        FileNotFoundError: if the kurobako binary cannot be executed.
    """
    kurobako_cmd = os.path.join(args.path_to_kurobako, "kurobako")
    # Print the version; check=True also fails fast if the binary is broken.
    subprocess.run([kurobako_cmd, "--version"], check=True)

    os.makedirs(args.out_dir, exist_ok=True)
    study_json_fn = os.path.join(args.out_dir, "studies.json")
    solvers_filename = os.path.join(args.out_dir, "solvers.json")
    problems_filename = os.path.join(args.out_dir, "problems.json")

    # Truncate the spec files: later steps append to them.
    for filename in [study_json_fn, solvers_filename, problems_filename]:
        with open(filename, "w"):
            pass

    # Create Rastrigin-2D bench problem. The script path is relative, so this
    # assumes the benchmark is launched from the directory containing it.
    _tee_append(
        [kurobako_cmd, "problem", "command", "python", "problem_rastrigin.py", "2"],
        problems_filename,
    )

    # Create Optuna solvers. No shell is involved, so kwargs are plain JSON.
    for name, sampler, sampler_kwargs in [
        ("random", "RandomSampler", "{}"),
        ("vanilla-cma-es", "CmaEsSampler", "{}"),
        ("bipop-cma-es", "CmaEsSampler", '{"restart_strategy":"bipop"}'),
        ("ipop-cma-es", "CmaEsSampler", '{"restart_strategy":"ipop"}'),
    ]:
        _tee_append(
            [
                kurobako_cmd, "solver", "--name", name,
                "optuna", "--loglevel", "debug",
                "--sampler", sampler, "--sampler-kwargs", sampler_kwargs,
                "--pruner", "NopPruner", "--pruner-kwargs", "{}",
            ],
            solvers_filename,
        )

    # Create studies: each recorded spec line is passed as a separate argument.
    studies_argv = [
        kurobako_cmd, "studies", "--budget", str(args.budget),
        "--solvers", *_read_nonblank_lines(solvers_filename),
        "--problems", *_read_nonblank_lines(problems_filename),
        "--repeats", str(args.n_runs),
        "--seed", str(args.seed),
        "--concurrency", str(args.n_concurrency),
    ]
    with open(study_json_fn, "w") as out:
        subprocess.run(studies_argv, stdout=out, check=True)

    # Execute the studies.
    result_filename = os.path.join(args.out_dir, "results.json")
    with open(study_json_fn) as fin, open(result_filename, "w") as fout:
        subprocess.run(
            [kurobako_cmd, "run", "--parallelism", str(args.n_jobs)],
            stdin=fin, stdout=fout, check=True,
        )

    # Render the markdown report.
    report_filename = os.path.join(args.out_dir, "report.md")
    with open(result_filename) as fin, open(report_filename, "w") as fout:
        subprocess.run([kurobako_cmd, "report"], stdin=fin, stdout=fout, check=True)

    # Plotting is best-effort: it needs Docker and the sile/kurobako image.
    # Docker bind mounts need an absolute host path; abspath handles both
    # relative and absolute --out-dir values.
    images_dir = os.path.join(os.path.abspath(args.out_dir), "images")
    try:
        with open(result_filename) as fin:
            subprocess.run(
                [
                    "docker", "run", "-v", images_dir + ":/images/",
                    "--rm", "-i", "sile/kurobako",
                    "plot", "curve", "--errorbar", "--xmin", "10",
                ],
                stdin=fin,
            )
    except FileNotFoundError:
        print("docker was not found; skipping plot generation.", flush=True)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Optuna samplers on the Rastrigin problem with kurobako."
    )
    parser.add_argument(
        "--path-to-kurobako", type=str, default="",
        help="Directory containing the kurobako binary "
        "(default: empty, so 'kurobako' is looked up on PATH).",
    )
    parser.add_argument(
        "--budget", type=int, default=5000,
        help="Budget passed to 'kurobako studies --budget'.",
    )
    parser.add_argument(
        "--n-runs", type=int, default=10,
        help="Repeats per study, passed to 'kurobako studies --repeats'.",
    )
    parser.add_argument(
        "--n-jobs", type=int, default=10,
        help="Parallelism passed to 'kurobako run --parallelism'.",
    )
    parser.add_argument(
        "--n-concurrency", type=int, default=1,
        help="Concurrency passed to 'kurobako studies --concurrency'.",
    )
    parser.add_argument(
        "--seed", type=int, default=0,
        help="Seed passed to 'kurobako studies --seed'.",
    )
    parser.add_argument(
        "--out-dir", type=str, default="tmp/benchmark_report",
        help="Directory for the generated JSON files, report.md and plot images.",
    )
    args = parser.parse_args()
    run(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
c-bata/goptuna#136