Last active
November 5, 2024 12:15
-
-
Save aisu-wata0/24045ab5e8ad007b4a09f708adfe359f to your computer and use it in GitHub Desktop.
For optuna/optuna#2021: a class that prunes trials whose parameters match those of past trials. You can also allow a limited number of repeats and, instead of pruning, return the average value of the matching past trials.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import optuna | |
from optuna.trial import TrialState | |
import numpy as np | |
from typing import Dict, List, Optional | |
from collections import defaultdict | |
class ParamRepeatPruner:
    """Prunes repeated trials, so trials with the same parameters don't waste time/resources.

    This class does not subclass an optuna pruner; call :meth:`check_params`
    (or :meth:`prune`) yourself from inside the objective, after all
    parameters have been suggested. A trial whose ``params`` equal those of an
    earlier trial is recorded as a repeat of that trial, and once more than
    ``repeats_max`` repeats have been recorded the trial is pruned by raising
    ``optuna.exceptions.TrialPruned``.
    """

    def __init__(
        self,
        study: optuna.study.Study,
        repeats_max: int = 0,
        should_compare_states: Optional[List[TrialState]] = None,
        compare_unfinished: bool = True,
    ):
        """
        Args:
            study (optuna.study.Study): Study of the trials. Trials already in
                the study are registered immediately, so repeats among them
                are counted too.
            repeats_max (int, optional): Instead of pruning every repeat (not
                repeating trials at all, ``repeats_max=0``) you can allow up to
                this many repeats per parameter set. This is useful if your
                objective is not deterministic and gives slightly different
                results for the same params. Defaults to 0.
            should_compare_states (List[TrialState], optional): States a past
                trial must be in to count as a match. By default a trial is
                only skipped if its parameters equal those of an existing
                COMPLETE trial, so parameters of FAILed and PRUNED trials may
                be tried again. To skip those as well, use e.g.
                ``[TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED]``.
                ``None`` (the default) means ``[TrialState.COMPLETE]``.
            compare_unfinished (bool, optional): Also match against unfinished
                (e.g. ``RUNNING``) trials. Such a match does not prune the
                current trial; it is recorded and added to the past trial's
                repeat count once that trial finishes in one of
                ``should_compare_states``. Set to False to ignore unfinished
                trials. Defaults to True.
        """
        if should_compare_states is None:
            should_compare_states = [TrialState.COMPLETE]
        # Own copy: a caller mutating its list later must not change our behavior.
        self.should_compare_states = list(should_compare_states)
        self.repeats_max = repeats_max
        self.compare_unfinished = compare_unfinished
        # Index in study.trials / number of the last match found by check_params (-1: none).
        self.repeated_idx = -1
        self.repeated_number = -1
        # Assigning through the property resets the counters and registers
        # the study's existing trials.
        self.study = study

    @property
    def study(self) -> Optional[optuna.study.Study]:
        """The study whose trials are compared (None disables all checks)."""
        return self._study

    @study.setter
    def study(self, study):
        self._study = study
        # Start from a clean slate: re-assigning a study (even the same one)
        # must not double-count repeats or mix trial numbers of two studies.
        # past trial number -> numbers of later trials with the same params
        self.repeats: Dict[int, List[int]] = defaultdict(list)
        # unfinished trial number -> numbers of later trials with the same
        # params; moved into `repeats` by clean_unfinised_trials when it finishes
        self.unfinished_repeats: Dict[int, List[int]] = defaultdict(list)
        if study is not None:
            self.register_existing_trials()

    def register_existing_trials(self):
        """Count repeats among the trials already stored in the study.

        Each trial is compared only against the trials that precede it, so
        the resulting counts match what live ``check_params`` calls would
        have recorded. Nothing is pruned here.
        """
        if self.study is None:
            return
        for trial_idx, trial_past in enumerate(self.study.trials):
            # ignore_last_trial=trial_idx -> compare against trials[:trial_idx].
            self.check_params(trial_past, False, trial_idx)

    def prune(self):
        """Check the study's most recent trial and prune it if it is a repeat."""
        self.check_params()

    def should_compare(self, state):
        """Return True if a past trial in ``state`` counts as a match."""
        return state in self.should_compare_states

    def clean_unfinised_trials(self):
        """Resolve pending repeats of trials that were unfinished when compared.

        Once such a trial is in one of ``should_compare_states``, its
        recorded repeats are moved into ``repeats``. If it finished in any
        other state, those repeats can never count, so they are discarded.
        """
        if not self.unfinished_repeats:
            return  # avoid fetching (copying) the trial list for nothing
        trials = self.study.trials
        resolved = []
        for number, repeat_numbers in self.unfinished_repeats.items():
            state = trials[number].state
            if self.should_compare(state):
                self.repeats[number].extend(repeat_numbers)
                resolved.append(number)
            elif state.is_finished():
                resolved.append(number)
        # Delete after iterating: the dict must not change size mid-loop.
        for number in resolved:
            del self.unfinished_repeats[number]

    def check_params(
        self,
        trial: Optional[optuna.trial.BaseTrial] = None,
        prune_existing=True,
        ignore_last_trial: Optional[int] = None,
    ):
        """Check whether ``trial`` repeats the parameters of an earlier trial.

        Args:
            trial: Trial to check. ``None`` (the default) means the most recent
                trial in ``study.trials`` (in a sequential ``optimize`` this is
                the trial currently running). It is compared against all
                trials before it.
            prune_existing: If True, raise ``optuna.exceptions.TrialPruned``
                when the matched past trial has more than ``repeats_max``
                recorded repeats (including this one).
            ignore_last_trial: Slice end applied to ``study.trials`` to select
                the trials compared against (``trials[:ignore_last_trial]``).
                ``None`` compares against all trials, which includes ``trial``
                itself if it is already in the study. Overridden to ``-1``
                when ``trial`` is None.

        Returns:
            The number of the first matching finished past trial, or -1 if
            there is none. Returns None if no study is registered.

        Raises:
            optuna.exceptions.TrialPruned: See ``prune_existing``.
        """
        if self.study is None:
            return None
        self.clean_unfinised_trials()
        self.repeated_idx = -1
        self.repeated_number = -1
        trials = self.study.trials
        if trial is None:
            if not trials:
                # Empty study (e.g. called outside of optimize()): nothing to check.
                return self.repeated_number
            trial = trials[-1]
            ignore_last_trial = -1
        for idx_p, trial_past in enumerate(trials[:ignore_last_trial]):
            unfinished = not trial_past.state.is_finished()
            should_compare = self.should_compare(trial_past.state) or (
                self.compare_unfinished and unfinished
            )
            if not should_compare or trial.params != trial_past.params:
                continue
            if unfinished:
                # Its outcome is unknown yet: remember the repeat and let
                # clean_unfinised_trials count it once the trial finishes.
                self.unfinished_repeats[trial_past.number].append(trial.number)
                continue
            self.repeated_idx = idx_p
            self.repeated_number = trial_past.number
            break
        if self.repeated_number > -1:
            repeats = self.repeats[self.repeated_number]
            repeats.append(trial.number)
            if prune_existing and len(repeats) > self.repeats_max:
                raise optuna.exceptions.TrialPruned()
        return self.repeated_number

    def get_value_of_repeats(
        self, repeated_number: int, func=lambda value_list: np.mean(value_list)
    ):
        """Aggregate the values of trial ``repeated_number`` and its repeats.

        Args:
            repeated_number: Number of the past trial, as returned by
                ``check_params``.
            func: Aggregator applied to the tuple of available values
                (default: ``np.mean``). Trials without a value (``None``,
                e.g. still running, failed or pruned) are left out. This
                includes the base trial, so the tuple may be empty.

        Returns:
            Whatever ``func`` returns for the collected values.

        Raises:
            ValueError: If no study is registered.
        """
        if self.study is None:
            raise ValueError("No study registered.")
        trials = self.study.trials
        numbers = (repeated_number, *self.repeats.get(repeated_number, ()))
        candidates = (trials[number].value for number in numbers)
        values = tuple(value for value in candidates if value is not None)
        return func(values)
if __name__ == "__main__":
    # Demo: a single integer parameter with only 20 possible values (1..20),
    # so over 40 trials the sampler is likely to propose duplicates, which
    # the repeat pruner then skips.
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(seed=42),
    )

    # The "pruner" is invoked manually from inside the objective.
    # With the defaults it only compares against COMPLETE trials, so
    # parameters of FAILed or PRUNED trials may be tried again. To skip those
    # as well, construct it like this instead:
    #   prune_params = ParamRepeatPruner(
    #       study,
    #       should_compare_states=[TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED],
    #   )
    # See the constructor docstring for all options.
    prune_params = ParamRepeatPruner(study)

    def dummy_objective(trial: optuna.trial.Trial):
        param_value = trial.suggest_int("dummy_param-0", 1, 20)

        # Alternative for non-deterministic objectives: allow repeats and
        # return the mean of the earlier results instead of pruning:
        #   repeated = prune_params.check_params(prune_existing=False)
        #   if repeated > -1:
        #       print("repeated")
        #       return prune_params.get_value_of_repeats(repeated)

        # With the default repeats_max=0 this raises TrialPruned on any repeat.
        prune_params.check_params()
        return param_value

    study.optimize(dummy_objective, n_trials=40)
    study.trials_dataframe().to_csv("tmp_trials.csv", index=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
number | value | datetime_start | datetime_complete | duration | params_dummy_param-0 | state | |
---|---|---|---|---|---|---|---|
0 | 7.0 | 2020-12-20 23:23:29.965239 | 2020-12-20 23:23:29.965614 | 0 days 00:00:00.000375 | 7 | COMPLETE | |
1 | 20.0 | 2020-12-20 23:23:29.970232 | 2020-12-20 23:23:29.970573 | 0 days 00:00:00.000341 | 20 | COMPLETE | |
2 | 15.0 | 2020-12-20 23:23:29.974000 | 2020-12-20 23:23:29.974358 | 0 days 00:00:00.000358 | 15 | COMPLETE | |
3 | 11.0 | 2020-12-20 23:23:29.976618 | 2020-12-20 23:23:29.976979 | 0 days 00:00:00.000361 | 11 | COMPLETE | |
4 | 8.0 | 2020-12-20 23:23:29.985587 | 2020-12-20 23:23:29.986141 | 0 days 00:00:00.000554 | 8 | COMPLETE | |
5 | 2020-12-20 23:23:29.990534 | 2020-12-20 23:23:29.991239 | 0 days 00:00:00.000705 | 7 | PRUNED | ||
6 | 19.0 | 2020-12-20 23:23:29.994696 | 2020-12-20 23:23:29.995418 | 0 days 00:00:00.000722 | 19 | COMPLETE | |
7 | 2020-12-20 23:23:29.998019 | 2020-12-20 23:23:29.998734 | 0 days 00:00:00.000715 | 11 | PRUNED | ||
8 | 2020-12-20 23:23:30.002512 | 2020-12-20 23:23:30.003233 | 0 days 00:00:00.000721 | 11 | PRUNED | ||
9 | 4.0 | 2020-12-20 23:23:30.010106 | 2020-12-20 23:23:30.011037 | 0 days 00:00:00.000931 | 4 | COMPLETE | |
10 | 1.0 | 2020-12-20 23:23:30.014068 | 2020-12-20 23:23:30.015827 | 0 days 00:00:00.001759 | 1 | COMPLETE | |
11 | 2020-12-20 23:23:30.019763 | 2020-12-20 23:23:30.022023 | 0 days 00:00:00.002260 | 1 | PRUNED | ||
12 | 2020-12-20 23:23:30.026310 | 2020-12-20 23:23:30.028946 | 0 days 00:00:00.002636 | 1 | PRUNED | ||
13 | 3.0 | 2020-12-20 23:23:30.034922 | 2020-12-20 23:23:30.037205 | 0 days 00:00:00.002283 | 3 | COMPLETE | |
14 | 2020-12-20 23:23:30.040748 | 2020-12-20 23:23:30.043007 | 0 days 00:00:00.002259 | 3 | PRUNED | ||
15 | 2020-12-20 23:23:30.047450 | 2020-12-20 23:23:30.050039 | 0 days 00:00:00.002589 | 4 | PRUNED | ||
16 | 2020-12-20 23:23:30.054470 | 2020-12-20 23:23:30.056865 | 0 days 00:00:00.002395 | 1 | PRUNED | ||
17 | 5.0 | 2020-12-20 23:23:30.063282 | 2020-12-20 23:23:30.065678 | 0 days 00:00:00.002396 | 5 | COMPLETE | |
18 | 2.0 | 2020-12-20 23:23:30.078418 | 2020-12-20 23:23:30.080678 | 0 days 00:00:00.002260 | 2 | COMPLETE | |
19 | 14.0 | 2020-12-20 23:23:30.086380 | 2020-12-20 23:23:30.088794 | 0 days 00:00:00.002414 | 14 | COMPLETE | |
20 | 2020-12-20 23:23:30.103851 | 2020-12-20 23:23:30.114583 | 0 days 00:00:00.010732 | 1 | PRUNED | ||
21 | 2020-12-20 23:23:30.117459 | 2020-12-20 23:23:30.122207 | 0 days 00:00:00.004748 | 3 | PRUNED | ||
22 | 6.0 | 2020-12-20 23:23:30.124915 | 2020-12-20 23:23:30.129720 | 0 days 00:00:00.004805 | 6 | COMPLETE | |
23 | 2020-12-20 23:23:30.133369 | 2020-12-20 23:23:30.138773 | 0 days 00:00:00.005404 | 2 | PRUNED | ||
24 | 9.0 | 2020-12-20 23:23:30.141172 | 2020-12-20 23:23:30.146572 | 0 days 00:00:00.005400 | 9 | COMPLETE | |
25 | 2020-12-20 23:23:30.151306 | 2020-12-20 23:23:30.154078 | 0 days 00:00:00.002772 | 3 | PRUNED | ||
26 | 2020-12-20 23:23:30.158375 | 2020-12-20 23:23:30.163539 | 0 days 00:00:00.005164 | 5 | PRUNED | ||
27 | 2020-12-20 23:23:30.166948 | 2020-12-20 23:23:30.170811 | 0 days 00:00:00.003863 | 1 | PRUNED | ||
28 | 2020-12-20 23:23:30.174584 | 2020-12-20 23:23:30.180510 | 0 days 00:00:00.005926 | 5 | PRUNED | ||
29 | 2020-12-20 23:23:30.185130 | 2020-12-20 23:23:30.188175 | 0 days 00:00:00.003045 | 9 | PRUNED | ||
30 | 2020-12-20 23:23:30.194511 | 2020-12-20 23:23:30.200854 | 0 days 00:00:00.006343 | 2 | PRUNED | ||
31 | 2020-12-20 23:23:30.205900 | 2020-12-20 23:23:30.210159 | 0 days 00:00:00.004259 | 4 | PRUNED | ||
32 | 2020-12-20 23:23:30.216352 | 2020-12-20 23:23:30.219556 | 0 days 00:00:00.003204 | 3 | PRUNED | ||
33 | 2020-12-20 23:23:30.221033 | 2020-12-20 23:23:30.227350 | 0 days 00:00:00.006317 | 2 | PRUNED | ||
34 | 2020-12-20 23:23:30.230396 | 2020-12-20 23:23:30.234887 | 0 days 00:00:00.004491 | 6 | PRUNED | ||
35 | 2020-12-20 23:23:30.245149 | 2020-12-20 23:23:30.256016 | 0 days 00:00:00.010867 | 4 | PRUNED | ||
36 | 2020-12-20 23:23:30.261829 | 2020-12-20 23:23:30.267921 | 0 days 00:00:00.006092 | 7 | PRUNED | ||
37 | 13.0 | 2020-12-20 23:23:30.271507 | 2020-12-20 23:23:30.277886 | 0 days 00:00:00.006379 | 13 | COMPLETE | |
38 | 2020-12-20 23:23:30.283320 | 2020-12-20 23:23:30.318621 | 0 days 00:00:00.035301 | 6 | PRUNED | ||
39 | 17.0 | 2020-12-20 23:23:30.322163 | 2020-12-20 23:23:30.330106 | 0 days 00:00:00.007943 | 17 | COMPLETE |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
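In `register_existing_trials`, line 47 should be: "trials = self.study.trials"
and line 49 should be: "for trial_idx, trial_past in enumerate(self.study.trials[1:]):" (the method currently refers to the global `study` instead of `self.study`, so it only works when a module-level `study` variable happens to exist).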