Created
January 28, 2022 08:50
-
-
Save cschell/f4d2de50c9f34fddf38acef4b070ea92 to your computer and use it in GitHub Desktop.
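Custom version of Optuna's GridSampler that changes the default behaviour for failed trials: only completed or pruned trials count as visited, so the grid point of a failed trial remains a candidate for evaluation. The code changes only one line from the original (L23).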
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
from optuna.samplers import GridSampler | |
from optuna.study import Study | |
from optuna.trial import TrialState | |
class CustomGridSampler(GridSampler):
    """Grid sampler that only treats COMPLETE and PRUNED trials as visited.

    Trials in any other state (apart from RUNNING, which is tracked
    separately) do not mark their grid point as visited, so that point stays
    among the candidates returned by ``_get_unvisited_grid_ids``.
    """

    def _get_unvisited_grid_ids(self, study: Study) -> List[int]:
        """Return the grid ids that still need to be evaluated.

        Args:
            study: Study whose stored trials are inspected.

        Returns:
            The grid ids that are neither visited nor currently running.
            When that set is empty, every grid id that is not yet visited
            is returned instead (this includes the running ones).
        """
        finished_ids = []
        in_progress_ids = []

        # Read the trials straight from the storage rather than through
        # `study.get_trials`: some pruners such as `HyperbandPruner` use a
        # transformed study to filter trials.
        # See https://github.com/optuna/optuna/issues/2327 for details.
        all_trials = study._storage.get_all_trials(study._study_id, deepcopy=False)

        for trial in all_trials:
            attrs = trial.system_attrs
            # Skip trials that carry no grid id or whose recorded search
            # space is rejected by `_same_search_space`.
            if "grid_id" not in attrs or not self._same_search_space(
                attrs["search_space"]
            ):
                continue

            if trial.state in (TrialState.COMPLETE, TrialState.PRUNED):
                finished_ids.append(attrs["grid_id"])
            elif trial.state == TrialState.RUNNING:
                in_progress_ids.append(attrs["grid_id"])

        not_visited = set(range(self._n_min_trials)) - set(finished_ids)
        candidates = not_visited - set(in_progress_ids)

        # Once every grid has at least been started, fall back to the grids
        # that have not finished yet: all grids should be evaluated before
        # the optimization stops.
        if not candidates:
            candidates = not_visited

        return list(candidates)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment