Skip to content

Instantly share code, notes, and snippets.

@aisu-wata0
Last active November 5, 2024 12:15
Show Gist options
  • Save aisu-wata0/24045ab5e8ad007b4a09f708adfe359f to your computer and use it in GitHub Desktop.
Save aisu-wata0/24045ab5e8ad007b4a09f708adfe359f to your computer and use it in GitHub Desktop.
For optuna/optuna#2021: a class that prunes trials whose parameters match those of past trials. You can also allow a limited number of repeats and, instead of pruning, return the average value of the matching past trials.
import optuna
from optuna.trial import TrialState
import numpy as np
from typing import Dict, List, Optional
from collections import defaultdict
class ParamRepeatPruner:
    """Prune repeated trials, so trials with the same parameters don't waste time/resources.

    A trial's parameters are compared with those of earlier trials of the same
    study. When they match a past trial whose state is in
    ``should_compare_states``, the trial is recorded as a repeat of that past
    trial. Once more than ``repeats_max`` repeats have been recorded for it,
    ``optuna.exceptions.TrialPruned`` is raised (unless pruning is disabled for
    the call).

    Attributes:
        repeats: Maps the number of a past trial to the numbers of the trials
            recorded as repeats of it.
        unfinished_repeats: Same mapping for past trials that were still
            unfinished when the match was found; settled by
            ``clean_unfinised_trials``.
        repeated_idx: Index in ``study.trials`` of the past trial matched by the
            most recent check, or -1 if there was no match.
        repeated_number: Number of that past trial, or -1.
    """

    def __init__(
        self,
        study: optuna.study.Study,
        repeats_max: int = 0,
        should_compare_states: Optional[List[TrialState]] = None,
        compare_unfinished: bool = True,
    ):
        """
        Args:
            study (optuna.study.Study): Study of the trials. Its existing trials
                are scanned immediately, so repeats already present are counted.
            repeats_max (int, optional): Instead of pruning every repeat
                (repeats_max=0) you can allow up to this many repeats of the
                same parameters. This is useful if your objective is not
                deterministic and gives slightly different results for the
                same params. Defaults to 0.
            should_compare_states (List[TrialState], optional): States a past
                trial must be in for a parameter match to count as a repeat.
                Defaults to [TrialState.COMPLETE], so parameters of FAILed and
                PRUNED trials are tried again. Use e.g. [TrialState.COMPLETE,
                TrialState.FAIL, TrialState.PRUNED] to skip those as well.
            compare_unfinished (bool, optional): If True, a match against an
                unfinished past trial (e.g. RUNNING) is remembered. It is
                counted as a repeat once that trial finishes in one of
                ``should_compare_states``. Set to False to ignore unfinished
                trials. Defaults to True.
        """
        # None instead of a list-literal default avoids a mutable default shared
        # between calls; the stored copy belongs to this instance only.
        if should_compare_states is None:
            should_compare_states = [TrialState.COMPLETE]
        self.should_compare_states = list(should_compare_states)
        self.repeats_max = repeats_max
        self.compare_unfinished = compare_unfinished
        self.repeated_idx = -1
        self.repeated_number = -1
        # Assigned last: the ``study`` setter initialises ``repeats`` and
        # ``unfinished_repeats``, then scans the existing trials, which relies
        # on the attributes above.
        self.study = study

    @property
    def study(self) -> Optional[optuna.study.Study]:
        """The study being checked, or None when checks are disabled."""
        return self._study

    @study.setter
    def study(self, study: Optional[optuna.study.Study]):
        self._study = study
        # Start the bookkeeping from scratch. Without this, (re)assigning a
        # study, even the same one, would register its trials twice.
        self.repeats: Dict[int, List[int]] = defaultdict(list)
        self.unfinished_repeats: Dict[int, List[int]] = defaultdict(list)
        if study is not None:
            self.register_existing_trials()

    def register_existing_trials(self):
        """Count the repeats already present among the study's trials.

        Every trial is compared with the trials before it, as if
        ``check_params`` had been called for each one in order. Nothing is
        pruned. Called automatically whenever ``study`` is assigned.
        """
        if self.study is None:
            return
        # Fetch the trial list once instead of once per trial
        # (``Study.trials`` returns copies of the trials).
        trials = self.study.trials
        # Trial 0 has no predecessors; trial ``idx`` is compared with every
        # trial before it, i.e. ``trials[:idx]``.
        for idx in range(1, len(trials)):
            self._record_repeat(trials[idx], trials[:idx], prune_existing=False)
        self.clean_unfinised_trials()

    def prune(self):
        """Check the study's latest trial; raise ``TrialPruned`` if it repeats too often."""
        self.check_params()

    def should_compare(self, state) -> bool:
        """Return True if a past trial in ``state`` counts for repeat detection."""
        return state in self.should_compare_states

    def clean_unfinised_trials(self):
        """Settle repeats that were recorded against then-unfinished trials.

        If the past trial has since reached one of ``should_compare_states``,
        its pending repeats are moved into ``repeats``. If it finished in any
        other state, they are discarded, because such a trial is never compared.
        """
        if self.study is None or not self.unfinished_repeats:
            return
        trials = self.study.trials
        # Iterate over a snapshot of the keys because entries are removed.
        for number in list(self.unfinished_repeats):
            state = trials[number].state
            if self.should_compare(state):
                self.repeats[number].extend(self.unfinished_repeats.pop(number))
            elif state.is_finished():
                del self.unfinished_repeats[number]

    # Correctly spelled alias; the original name is kept for existing callers.
    clean_unfinished_trials = clean_unfinised_trials

    def check_params(
        self,
        trial: Optional[optuna.trial.BaseTrial] = None,
        prune_existing=True,
        ignore_last_trial: Optional[int] = None,
    ):
        """Check whether ``trial`` repeats the parameters of an earlier trial.

        Args:
            trial: Trial to check. Defaults to the study's latest trial, which
                is then compared with every trial before it.
            prune_existing: If True, raise ``optuna.exceptions.TrialPruned``
                when the matched past trial has more than ``repeats_max``
                recorded repeats.
            ignore_last_trial: Slice bound selecting the trials compared with
                ``trial`` (``study.trials[:ignore_last_trial]``). None compares
                with all trials. Ignored when ``trial`` is None.

        Returns:
            The number of the past trial that ``trial`` repeats, -1 if there is
            none, or None when no study is registered.

        Raises:
            optuna.exceptions.TrialPruned: See ``prune_existing``.
        """
        if self.study is None:
            return None
        trials = self.study.trials
        if trial is None:
            if not trials:
                # No trial has been created yet: nothing to check.
                self.repeated_idx = -1
                self.repeated_number = -1
                return -1
            trial = trials[-1]
            ignore_last_trial = -1
        self.clean_unfinised_trials()
        return self._record_repeat(trial, trials[:ignore_last_trial], prune_existing)

    def _record_repeat(self, trial, candidates, prune_existing):
        """Record ``trial`` as a repeat of the first matching finished trial in ``candidates``; prune if over the limit."""
        self.repeated_idx = -1
        self.repeated_number = -1
        for idx_p, trial_past in enumerate(candidates):
            if trial_past.number == trial.number:
                # A trial passed explicitly is also part of ``study.trials``;
                # never count it as a repeat of itself.
                continue
            finished = trial_past.state.is_finished()
            comparable = self.should_compare(trial_past.state) or (
                self.compare_unfinished and not finished
            )
            if not comparable or trial.params != trial_past.params:
                continue
            if not finished:
                # The past trial's outcome is still unknown. Park the match
                # until clean_unfinised_trials sees it finish, and keep looking.
                self.unfinished_repeats[trial_past.number].append(trial.number)
                continue
            self.repeated_idx = idx_p
            self.repeated_number = trial_past.number
            break

        if self.repeated_number > -1:
            repeat_list = self.repeats[self.repeated_number]
            repeat_list.append(trial.number)
            if prune_existing and len(repeat_list) > self.repeats_max:
                raise optuna.exceptions.TrialPruned()
        return self.repeated_number

    def get_value_of_repeats(self, repeated_number: int, func=np.mean):
        """Aggregate the values of a past trial and of the trials repeating it.

        Args:
            repeated_number: Number of the past trial, e.g. the value returned
                by ``check_params``.
            func: Aggregation applied to the tuple of values. Defaults to
                ``np.mean``.

        Returns:
            ``func`` applied to the values of trial ``repeated_number`` and of
            its recorded repeats. Trials whose ``value`` is None are skipped.
            With ``np.mean``, an empty tuple yields nan.

        Raises:
            ValueError: If no study is registered.
        """
        if self.study is None:
            raise ValueError("No study registered.")
        trials = self.study.trials
        # ``.get`` avoids inserting an empty entry for unknown numbers.
        numbers = (repeated_number, *self.repeats.get(repeated_number, ()))
        values = tuple(
            trials[number].value
            for number in numbers
            if trials[number].value is not None
        )
        return func(values)
if __name__ == "__main__":
    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(seed=42),
    )

    # Attach the repeat "pruner" to the study. By default, only COMPLETE past
    # trials are compared, so parameter sets that previously FAILed or were
    # PRUNED are evaluated again. To skip those as well, construct it as:
    #   pruner = ParamRepeatPruner(
    #       study,
    #       should_compare_states=[TrialState.COMPLETE, TrialState.FAIL, TrialState.PRUNED],
    #   )
    # See the ParamRepeatPruner constructor docstring for the other options.
    pruner = ParamRepeatPruner(study)

    def objective(trial: optuna.trial.Trial):
        x = trial.suggest_int("dummy_param-0", 1, 20)
        # Raises TrialPruned when this parameter set was already evaluated.
        pruner.check_params()
        # Alternative for non-deterministic objectives: allow some repeats
        # (repeats_max > 0) and return the mean of the earlier values instead:
        #   repeated = pruner.check_params(prune_existing=False)
        #   if repeated > -1:
        #       return pruner.get_value_of_repeats(repeated)
        return x

    study.optimize(objective, n_trials=40)
    study.trials_dataframe().to_csv("tmp_trials.csv", index=False)
number value datetime_start datetime_complete duration params_dummy_param-0 state
0 7.0 2020-12-20 23:23:29.965239 2020-12-20 23:23:29.965614 0 days 00:00:00.000375 7 COMPLETE
1 20.0 2020-12-20 23:23:29.970232 2020-12-20 23:23:29.970573 0 days 00:00:00.000341 20 COMPLETE
2 15.0 2020-12-20 23:23:29.974000 2020-12-20 23:23:29.974358 0 days 00:00:00.000358 15 COMPLETE
3 11.0 2020-12-20 23:23:29.976618 2020-12-20 23:23:29.976979 0 days 00:00:00.000361 11 COMPLETE
4 8.0 2020-12-20 23:23:29.985587 2020-12-20 23:23:29.986141 0 days 00:00:00.000554 8 COMPLETE
5 2020-12-20 23:23:29.990534 2020-12-20 23:23:29.991239 0 days 00:00:00.000705 7 PRUNED
6 19.0 2020-12-20 23:23:29.994696 2020-12-20 23:23:29.995418 0 days 00:00:00.000722 19 COMPLETE
7 2020-12-20 23:23:29.998019 2020-12-20 23:23:29.998734 0 days 00:00:00.000715 11 PRUNED
8 2020-12-20 23:23:30.002512 2020-12-20 23:23:30.003233 0 days 00:00:00.000721 11 PRUNED
9 4.0 2020-12-20 23:23:30.010106 2020-12-20 23:23:30.011037 0 days 00:00:00.000931 4 COMPLETE
10 1.0 2020-12-20 23:23:30.014068 2020-12-20 23:23:30.015827 0 days 00:00:00.001759 1 COMPLETE
11 2020-12-20 23:23:30.019763 2020-12-20 23:23:30.022023 0 days 00:00:00.002260 1 PRUNED
12 2020-12-20 23:23:30.026310 2020-12-20 23:23:30.028946 0 days 00:00:00.002636 1 PRUNED
13 3.0 2020-12-20 23:23:30.034922 2020-12-20 23:23:30.037205 0 days 00:00:00.002283 3 COMPLETE
14 2020-12-20 23:23:30.040748 2020-12-20 23:23:30.043007 0 days 00:00:00.002259 3 PRUNED
15 2020-12-20 23:23:30.047450 2020-12-20 23:23:30.050039 0 days 00:00:00.002589 4 PRUNED
16 2020-12-20 23:23:30.054470 2020-12-20 23:23:30.056865 0 days 00:00:00.002395 1 PRUNED
17 5.0 2020-12-20 23:23:30.063282 2020-12-20 23:23:30.065678 0 days 00:00:00.002396 5 COMPLETE
18 2.0 2020-12-20 23:23:30.078418 2020-12-20 23:23:30.080678 0 days 00:00:00.002260 2 COMPLETE
19 14.0 2020-12-20 23:23:30.086380 2020-12-20 23:23:30.088794 0 days 00:00:00.002414 14 COMPLETE
20 2020-12-20 23:23:30.103851 2020-12-20 23:23:30.114583 0 days 00:00:00.010732 1 PRUNED
21 2020-12-20 23:23:30.117459 2020-12-20 23:23:30.122207 0 days 00:00:00.004748 3 PRUNED
22 6.0 2020-12-20 23:23:30.124915 2020-12-20 23:23:30.129720 0 days 00:00:00.004805 6 COMPLETE
23 2020-12-20 23:23:30.133369 2020-12-20 23:23:30.138773 0 days 00:00:00.005404 2 PRUNED
24 9.0 2020-12-20 23:23:30.141172 2020-12-20 23:23:30.146572 0 days 00:00:00.005400 9 COMPLETE
25 2020-12-20 23:23:30.151306 2020-12-20 23:23:30.154078 0 days 00:00:00.002772 3 PRUNED
26 2020-12-20 23:23:30.158375 2020-12-20 23:23:30.163539 0 days 00:00:00.005164 5 PRUNED
27 2020-12-20 23:23:30.166948 2020-12-20 23:23:30.170811 0 days 00:00:00.003863 1 PRUNED
28 2020-12-20 23:23:30.174584 2020-12-20 23:23:30.180510 0 days 00:00:00.005926 5 PRUNED
29 2020-12-20 23:23:30.185130 2020-12-20 23:23:30.188175 0 days 00:00:00.003045 9 PRUNED
30 2020-12-20 23:23:30.194511 2020-12-20 23:23:30.200854 0 days 00:00:00.006343 2 PRUNED
31 2020-12-20 23:23:30.205900 2020-12-20 23:23:30.210159 0 days 00:00:00.004259 4 PRUNED
32 2020-12-20 23:23:30.216352 2020-12-20 23:23:30.219556 0 days 00:00:00.003204 3 PRUNED
33 2020-12-20 23:23:30.221033 2020-12-20 23:23:30.227350 0 days 00:00:00.006317 2 PRUNED
34 2020-12-20 23:23:30.230396 2020-12-20 23:23:30.234887 0 days 00:00:00.004491 6 PRUNED
35 2020-12-20 23:23:30.245149 2020-12-20 23:23:30.256016 0 days 00:00:00.010867 4 PRUNED
36 2020-12-20 23:23:30.261829 2020-12-20 23:23:30.267921 0 days 00:00:00.006092 7 PRUNED
37 13.0 2020-12-20 23:23:30.271507 2020-12-20 23:23:30.277886 0 days 00:00:00.006379 13 COMPLETE
38 2020-12-20 23:23:30.283320 2020-12-20 23:23:30.318621 0 days 00:00:00.035301 6 PRUNED
39 17.0 2020-12-20 23:23:30.322163 2020-12-20 23:23:30.330106 0 days 00:00:00.007943 17 COMPLETE
@bl4val
Copy link

bl4val commented Sep 23, 2022

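Line 47 should be: "trials = self.study.trials"
and line 49 should be: "for trial_idx, trial_past in enumerate(self.study.trials[1:]):" (register_existing_trials currently reads the module-level `study` instead of `self.study`, so it only works when a global named `study` exists).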

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment