@sfujiwara
Last active October 8, 2020 04:17
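This gist shows how to prune unpromising Optuna trials while training a tf.estimator model: a custom SessionRunHook reads the estimator's evaluation metrics, reports them to the Optuna trial, and raises TrialPruned when the pruner decides the trial should stop early.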
import tensorflow as tf
import optuna
import sklearn.datasets
from sklearn.model_selection import train_test_split


class TensorFlowPruningHook(tf.train.SessionRunHook):
    """SessionRunHook that reports eval metrics to Optuna and prunes unpromising trials."""

    def __init__(self, trial, estimator, metric, is_higher_better, run_every_steps):
        self.trial = trial
        self.estimator = estimator
        self.current_step = -1
        self.metric = metric
        self.is_higher_better = is_higher_better
        self._global_step_tensor = None
        self._timer = tf.train.SecondOrStepTimer(every_secs=None, every_steps=run_every_steps)

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()

    def before_run(self, run_context):
        del run_context
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        global_step = run_values.results
        # Get eval metrics every n steps
        if self._timer.should_trigger_for_step(global_step):
            eval_metrics = tf.contrib.estimator.read_eval_metrics(self.estimator.eval_dir())
        else:
            eval_metrics = None
        if eval_metrics:
            step = next(reversed(eval_metrics))
            latest_eval_metrics = eval_metrics[step]
            # If there exists a new evaluation summary
            if step > self.current_step:
                # Optuna minimizes the objective, so flip higher-is-better metrics
                if self.is_higher_better:
                    current_score = 1.0 - latest_eval_metrics[self.metric]
                else:
                    current_score = latest_eval_metrics[self.metric]
                self.trial.report(current_score, step=step)
                self.current_step = step
            if self.trial.should_prune(self.current_step):
                message = "Trial was pruned at iteration {}.".format(self.current_step)
                raise optuna.structs.TrialPruned(message)
def create_input_fn():
    # Split the Iris dataset in half for training and evaluation
    iris = sklearn.datasets.load_iris()
    x, y = iris.data, iris.target
    x_train, x_eval, y_train, y_eval = train_test_split(x, y, test_size=0.5, random_state=42)

    def _train_input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        dataset = dataset.shuffle(128).repeat().batch(16)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return {"x": features}, labels

    def _eval_input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
        dataset = dataset.batch(32)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return {"x": features}, labels

    return _train_input_fn, _eval_input_fn
def objective(trial):
    save_steps = 50
    # Create input functions for train and eval
    train_input_fn, eval_input_fn = create_input_fn()
    # Hyperparameters to be tuned with Optuna
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-2)
    # Create Estimator config
    config = tf.estimator.RunConfig(save_summary_steps=save_steps, save_checkpoints_steps=save_steps)
    # Create Estimator
    clf = tf.estimator.DNNClassifier(
        feature_columns=[tf.feature_column.numeric_column(key="x", shape=[4])],
        n_classes=3,
        hidden_units=[],
        optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate),
        model_dir="outputs_pruning/lr_{}".format(learning_rate),
        config=config
    )
    # Create hooks; accuracy is higher-is-better, so stop when it fails to increase
    early_stopping_hook = tf.contrib.estimator.stop_if_no_increase_hook(clf, "accuracy", save_steps)
    optuna_pruning_hook = TensorFlowPruningHook(
        trial=trial,
        estimator=clf,
        metric="accuracy",
        is_higher_better=True,
        run_every_steps=10,
    )
    hooks = [early_stopping_hook, optuna_pruning_hook]
    # Create TrainSpec and EvalSpec
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=500, hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=1000, start_delay_secs=0, throttle_secs=0)
    # Run training and evaluation
    tf.estimator.train_and_evaluate(clf, train_spec, eval_spec)
    result = clf.evaluate(input_fn=eval_input_fn, steps=100)
    accuracy = result["accuracy"]
    # Optuna minimizes the objective, so return 1 - accuracy
    return 1.0 - accuracy
if __name__ == "__main__":
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=100))
    study.optimize(objective, n_trials=15)
    print(study.best_trial)
    print([t.state for t in study.trials])
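A small follow-up sketch (not part of the original gist): once study.optimize() finishes, pruned and completed trials can be counted by filtering on trial state. TrialState is assumed to live in optuna.structs, matching the Optuna versions this script targets.

# Follow-up sketch; assumes the `study` object from the __main__ block above
pruned = [t for t in study.trials if t.state == optuna.structs.TrialState.PRUNED]
complete = [t for t in study.trials if t.state == optuna.structs.TrialState.COMPLETE]
print("pruned {} / {} trials".format(len(pruned), len(study.trials)))
print("best value (1.0 - accuracy): {}".format(study.best_value))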
@kination

Thanks for the great work!
But I got an error:

[W 2019-01-28 18:33:49,614] Setting trial status as TrialState.FAIL because of the following error: AttributeError("'TensorFlowPruningHook' object has no attribute 'metrics_name'",)

It seems this code

...
                if self.is_higher_better:
                    current_score = 1.0 - latest_eval_metrics[self.metrics_name]
                else:
                    current_score = latest_eval_metrics[self.metrics_name]
...

needs to be changed to

                if self.is_higher_better:
                    current_score = 1.0 - latest_eval_metrics[self.metric]
                else:
                    current_score = latest_eval_metrics[self.metric]

Am I right?
