Last active
January 27, 2020 15:36
-
-
Save sile/47a4c257da31a76c6e8460f2747afb71 to your computer and use it in GitHub Desktop.
Kurobako blog: random.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A solver implementation based on the Random Search algorithm.
from kurobako import problem | |
from kurobako import solver | |
import numpy as np | |
class RandomSolverFactory(solver.SolverFactory):
    """Factory that describes and builds the Random Search solver."""

    def specification(self):
        """Return the solver specification, named 'Random Search'."""
        spec = solver.SolverSpec(name='Random Search')
        return spec

    def create_solver(self, seed, problem):
        """Build a ``RandomSolver`` for ``problem``, seeded with ``seed``."""
        return RandomSolver(seed=seed, problem=problem)
class RandomSolver(solver.Solver):
    """Solver that draws every parameter independently at random.

    Each call to ``ask`` samples one value per problem parameter from the
    parameter's range; evaluation results passed to ``tell`` are ignored.
    """

    def __init__(self, seed, problem):
        """Seed a private random generator and remember the problem.

        Args:
            seed: Seed handed to ``np.random.RandomState``.
            problem: Problem specification; ``ask`` reads its ``params``
                and ``steps.last_step`` attributes.
        """
        self._rng = np.random.RandomState(seed)
        self._problem = problem

    def _sample(self, p):
        """Draw one value for parameter ``p`` from its range.

        Log-uniform parameters are sampled uniformly in log space and
        mapped back with ``exp``; every other distribution is sampled
        uniformly over ``[p.range.low, p.range.high)``.
        """
        if p.distribution == problem.Distribution.LOG_UNIFORM:
            low = np.log(p.range.low)
            high = np.log(p.range.high)
            # Bug fix: the lower bound is ``low``; the original passed the
            # undefined name ``log`` here, raising NameError.
            return float(np.exp(self._rng.uniform(low, high)))
        return self._rng.uniform(p.range.low, p.range.high)

    def ask(self, idg):
        """Propose the next trial with randomly sampled parameters.

        Args:
            idg: ID generator; ``idg.generate()`` supplies the trial ID.

        Returns:
            A ``solver.NextTrial`` built from the new trial ID, one sampled
            value per problem parameter (in ``problem.params`` order), and
            the problem's last step as the next step.
        """
        trial_id = idg.generate()
        next_step = self._problem.steps.last_step
        params = [self._sample(p) for p in self._problem.params]
        return solver.NextTrial(trial_id, params, next_step)

    def tell(self, trial):
        """Ignore ``trial``: sampling never depends on earlier results."""
        pass
if __name__ == '__main__':
    # Script entry point: wrap the Random Search factory in a solver runner
    # and start it (presumably serving kurobako's solver protocol; see the
    # kurobako documentation for what ``run`` does).
    solver.SolverRunner(RandomSolverFactory()).run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment