Skip to content

Instantly share code, notes, and snippets.

@sfujiwara
Last active December 27, 2018 05:59
Show Gist options
  • Save sfujiwara/cd838c122f344c934ab543ae0bad0011 to your computer and use it in GitHub Desktop.
Save sfujiwara/cd838c122f344c934ab543ae0bad0011 to your computer and use it in GitHub Desktop.
class TensorFlowPruningHook(tf.train.SessionRunHook):
    """Session hook that reports Estimator eval metrics to an Optuna trial and prunes it.

    At most once every ``run_every_steps`` global steps, the hook reads the
    evaluation summaries found in ``estimator.eval_dir()``. If the newest
    summary is more recent than the last one reported, its ``metrics_name``
    value is reported to ``trial`` as an intermediate value. The hook then
    asks ``trial.should_prune``; if it returns True, training is aborted by
    raising ``optuna.structs.TrialPruned``.

    Args:
        trial: Optuna trial that receives intermediate values.
        estimator: Estimator whose ``eval_dir()`` holds evaluation summaries.
        metrics_name: Key of the metric inside each evaluation summary.
        is_higher_better: If True, the metric is reported as ``1.0 - value``
            so that, for the reported value, lower is better in both cases.
        run_every_steps: Number of global steps between two reads of the
            eval directory.
    """

    def __init__(self, trial, estimator, metrics_name, is_higher_better, run_every_steps):
        self.trial = trial
        self.estimator = estimator
        # Global step of the most recent evaluation summary already reported;
        # -1 means nothing has been reported yet.
        self.current_step = -1
        self.metrics_name = metrics_name
        self.is_higher_better = is_higher_better
        self._global_step_tensor = None
        self._timer = tf.train.SecondOrStepTimer(every_secs=None, every_steps=run_every_steps)

    def begin(self):
        """Look up the graph's global step tensor.

        Raises:
            RuntimeError: If the graph has no global step.
        """
        self._global_step_tensor = tf.train.get_global_step()
        if self._global_step_tensor is None:
            # Without a global step, before_run would ask the session to fetch None.
            raise RuntimeError(
                "Global step should be created to use TensorFlowPruningHook.")

    def before_run(self, run_context):
        """Request the current global step alongside each training step."""
        del run_context  # Unused.
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        """Report the newest eval metric (if any) and prune the trial if advised.

        Raises:
            optuna.structs.TrialPruned: If ``trial.should_prune`` returns True.
        """
        del run_context  # Unused.
        global_step = run_values.results

        # Only poll the eval directory every run_every_steps steps. The timer
        # must be told that it fired: until update_last_triggered_step is
        # called, should_trigger_for_step returns True on every step.
        if not self._timer.should_trigger_for_step(global_step):
            return
        self._timer.update_last_triggered_step(global_step)

        eval_metrics = tf.contrib.estimator.read_eval_metrics(self.estimator.eval_dir())
        if not eval_metrics:
            # No evaluation summaries have been written yet.
            return

        # The last key of the returned mapping is taken as the newest
        # evaluation step (read_eval_metrics sorts by global step).
        step = next(reversed(eval_metrics))
        if step > self.current_step:
            # A new evaluation summary exists: report it.
            value = eval_metrics[step][self.metrics_name]
            # 1 - x is monotonically decreasing, so the ranking between trials
            # is preserved while a lower reported value always means better.
            current_score = 1.0 - value if self.is_higher_better else value
            self.trial.report(current_score, step=step)
            self.current_step = step

        if self.trial.should_prune(self.current_step):
            message = "Trial was pruned at iteration {}.".format(self.current_step)
            raise optuna.structs.TrialPruned(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment