Last active: October 8, 2020 04:17
import tensorflow as tf
import optuna
import sklearn.datasets
from sklearn.model_selection import train_test_split


class TensorFlowPruningHook(tf.train.SessionRunHook):

    def __init__(self, trial, estimator, metric, is_higher_better, run_every_steps):
        self.trial = trial
        self.estimator = estimator
        self.current_step = -1
        self.metric = metric
        self.is_higher_better = is_higher_better
        self._global_step_tensor = None
        self._timer = tf.train.SecondOrStepTimer(every_secs=None, every_steps=run_every_steps)

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()

    def before_run(self, run_context):
        del run_context
        return tf.train.SessionRunArgs(self._global_step_tensor)

    def after_run(self, run_context, run_values):
        global_step = run_values.results
        # Get eval metrics every n steps
        if self._timer.should_trigger_for_step(global_step):
            # Record the trigger so metrics are only read once per run_every_steps steps
            self._timer.update_last_triggered_step(global_step)
            eval_metrics = tf.contrib.estimator.read_eval_metrics(self.estimator.eval_dir())
        else:
            eval_metrics = None
        if eval_metrics:
            step = next(reversed(eval_metrics))
            latest_eval_metrics = eval_metrics[step]
            # If there exists a new evaluation summary
            if step > self.current_step:
                if self.is_higher_better:
                    current_score = 1.0 - latest_eval_metrics[self.metric]
                else:
                    current_score = latest_eval_metrics[self.metric]
                self.trial.report(current_score, step=step)
                self.current_step = step
            if self.trial.should_prune(self.current_step):
                message = "Trial was pruned at iteration {}.".format(self.current_step)
                raise optuna.structs.TrialPruned(message)


def create_input_fn():
    iris = sklearn.datasets.load_iris()
    x, y = iris.data, iris.target
    x_train, x_eval, y_train, y_eval = train_test_split(x, y, test_size=0.5, random_state=42)

    def _train_input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        dataset = dataset.shuffle(128).repeat().batch(16)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return {"x": features}, labels

    def _eval_input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((x_eval, y_eval))
        dataset = dataset.batch(32)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return {"x": features}, labels

    return _train_input_fn, _eval_input_fn


def objective(trial):
    save_steps = 50
    # Create input functions for train and eval
    train_input_fn, eval_input_fn = create_input_fn()
    # Hyperparameters to be tuned with Optuna
    learning_rate = trial.suggest_loguniform("learning_rate", 1e-5, 1e-2)
    # Create Estimator config
    config = tf.estimator.RunConfig(save_summary_steps=save_steps, save_checkpoints_steps=save_steps)
    # Create Estimator
    clf = tf.estimator.DNNClassifier(
        feature_columns=[tf.feature_column.numeric_column(key="x", shape=[4])],
        n_classes=3,
        hidden_units=[],
        optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate),
        model_dir="outputs_pruning/lr_{}".format(learning_rate),
        config=config
    )
    # Create hooks
    # Stop training early when accuracy shows no increase within save_steps steps
    early_stopping_hook = tf.contrib.estimator.stop_if_no_increase_hook(clf, "accuracy", save_steps)
    optuna_pruning_hook = TensorFlowPruningHook(
        trial=trial,
        estimator=clf,
        metric="accuracy",
        is_higher_better=True,
        run_every_steps=10,
    )
    hooks = [early_stopping_hook, optuna_pruning_hook]
    # Create TrainSpec and EvalSpec
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=500, hooks=hooks)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=1000, start_delay_secs=0, throttle_secs=0)
    # Run training and evaluation
    tf.estimator.train_and_evaluate(clf, train_spec, eval_spec)
    result = clf.evaluate(input_fn=eval_input_fn, steps=100)
    accuracy = result["accuracy"]
    return 1.0 - accuracy


if __name__ == "__main__":
    study = optuna.create_study(pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=100))
    study.optimize(objective, n_trials=15)
    print(study.best_trial)
    print([t.state for t in study.trials])
Thanks for the great work! But I got an error: it seems self.metrics_name in after_run needs to be changed to self.metric. Am I right?
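Presumably this refers to the attribute-name mismatch in the pruning hook: __init__ stores the metric name as self.metric, so the score lookup in after_run would read as below. This is a minimal sketch of just the affected lines, assuming that is the change the comment means.

# Presumed change inside after_run (fragment, not a full hook):
# index the eval-metrics dict with the attribute set in __init__ (self.metric),
# not self.metrics_name.
if self.is_higher_better:
    current_score = 1.0 - latest_eval_metrics[self.metric]
else:
    current_score = latest_eval_metrics[self.metric]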