Last active
December 27, 2018 05:59
-
-
Save sfujiwara/cd838c122f344c934ab543ae0bad0011 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TensorFlowPruningHook(tf.train.SessionRunHook):
    """Session hook that reports Estimator eval metrics to an Optuna trial and prunes it.

    Every ``run_every_steps`` global steps the hook reads the evaluation
    summaries written under ``estimator.eval_dir()``. If a summary newer than
    the last reported one exists, its ``metrics_name`` value is reported to
    ``trial``. After each read that found summaries, ``trial.should_prune`` is
    consulted and ``optuna.structs.TrialPruned`` is raised if it returns true.

    Args:
        trial: Optuna trial that receives intermediate values via ``trial.report``.
        estimator: Estimator whose ``eval_dir()`` holds the evaluation summaries.
        metrics_name: Key of the metric to report from each evaluation summary.
        is_higher_better: If true, ``1.0 - value`` is reported instead of
            ``value``. NOTE(review): presumably this lets the study minimize.
            The flip only makes sense for metrics bounded above by 1 (e.g.
            accuracy), so confirm for other metrics.
        run_every_steps: Number of global steps between reads of the
            evaluation directory.
    """

    def __init__(self, trial, estimator, metrics_name, is_higher_better, run_every_steps):
        self.trial = trial
        self.estimator = estimator
        # Step of the last evaluation summary reported to the trial; -1 = none yet.
        self.current_step = -1
        self.metrics_name = metrics_name
        self.is_higher_better = is_higher_better
        self._global_step_tensor = None
        self._timer = tf.train.SecondOrStepTimer(every_secs=None, every_steps=run_every_steps)

    def begin(self):
        """Look up the graph's global step tensor before the session starts running."""
        self._global_step_tensor = tf.train.get_global_step()

    def before_run(self, run_context):
        """Ask the session to also fetch the global step on every ``run`` call."""
        del run_context  # Unused.
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        """Report the newest evaluation metric and prune the trial if requested.

        Raises:
            optuna.structs.TrialPruned: If ``trial.should_prune`` returns true
                after evaluation metrics were read.
        """
        del run_context  # Unused.
        global_step = run_values.results

        # Only read the eval directory once every ``run_every_steps`` steps.
        if not self._timer.should_trigger_for_step(global_step):
            return
        # The timer does not advance by itself: without this call
        # should_trigger_for_step keeps returning True after the first trigger,
        # and the eval directory would be re-read on every training step.
        self._timer.update_last_triggered_step(global_step)

        eval_metrics = tf.contrib.estimator.read_eval_metrics(self.estimator.eval_dir())
        if not eval_metrics:
            # No evaluation has written summaries yet.
            return

        # Keys are global steps; max() picks the most recent evaluation without
        # depending on the mapping's iteration order.
        step = max(eval_metrics)
        latest_eval_metrics = eval_metrics[step]

        # Report only when an evaluation newer than the last reported one exists.
        if step > self.current_step:
            value = latest_eval_metrics[self.metrics_name]
            current_score = 1.0 - value if self.is_higher_better else value
            self.trial.report(current_score, step=step)
            self.current_step = step

        if self.trial.should_prune(self.current_step):
            message = "Trial was pruned at iteration {}.".format(self.current_step)
            raise optuna.structs.TrialPruned(message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment