Skip to content

Instantly share code, notes, and snippets.

@sseveran
Created June 15, 2018 15:42
Show Gist options
  • Save sseveran/e27045c5fdb2d2f836ca63e13755665f to your computer and use it in GitHub Desktop.
Save sseveran/e27045c5fdb2d2f836ca63e13755665f to your computer and use it in GitHub Desktop.
A Tensorflow hook for reporting state to ray-tune
import six
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
class RayTuneReportingHook(tf.train.SessionRunHook):
    """Session hook that forwards tensor values to a ray-tune reporter.

    After every ``session.run`` call the hook is attached to, the values of
    the configured tensors are passed to ``reporter`` as keyword arguments,
    together with the current global step as ``timesteps_total``.
    """

    def __init__(self, params, reporter):
        """Store the reporter and the tensors to report.

        Args:
          params: Either a dict mapping the keyword name to report to a
            tensor (or tensor name), or an iterable of items where each item
            is used both as the keyword name and as the tensor to fetch.
            Because results are passed to ``reporter`` as ``**kwargs``, the
            keyword names need to be strings.
          reporter: Callable invoked as ``reporter(**results)`` after each
            run.
        """
        self.reporter = reporter
        if not isinstance(params, dict):
            # Materialize into our own list: a one-shot iterable (e.g. a
            # generator) would otherwise be exhausted by the dict
            # comprehension below, and a caller-owned list could be mutated
            # after construction.
            self._tag_order = list(params)
            params = {item: item for item in self._tag_order}
        else:
            self._tag_order = list(params.keys())
        self._tensors = params

    def begin(self):
        """Resolve the global step and the requested tensors in the graph.

        Raises:
          RuntimeError: If the graph has no global step.
          ValueError: If a requested tensor cannot be resolved unambiguously
            in the default graph (raised by ``_as_graph_element``).
        """
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        if self._global_step_tensor is None:
            # Fail fast here instead of with an obscure error on the first
            # after_run() call.
            raise RuntimeError(
                "Global step should be created to use RayTuneReportingHook.")
        self._current_tensors = {
            tag: _as_graph_element(tensor)
            for (tag, tensor) in self._tensors.items()
        }

    def before_run(self, run_context):
        """Request the configured tensors to be fetched with this run."""
        return SessionRunArgs(self._current_tensors)

    def after_run(self,
                  run_context,
                  run_values):
        """Report the fetched values and the global step to the reporter."""
        # The global step is read with a separate run call after the step
        # has completed.
        global_step = run_context.session.run(self._global_step_tensor)
        results = {tag: run_values.results[tag] for tag in self._tag_order}
        # NOTE: a user tag named 'timesteps_total' is overwritten here.
        results['timesteps_total'] = global_step
        self.reporter(**results)
# Copied from TensorFlow (tensorflow/python/training/basic_session_run_hooks.py).
def _as_graph_element(obj):
    """Resolve ``obj`` to an element of the current default graph.

    Args:
      obj: Either a string name or an object carrying a ``graph`` attribute.

    Returns:
      For a non-string ``obj``, the object itself. For a string containing
      ``":"``, the element looked up under that exact name. For any other
      string, the element looked up under ``obj + ":0"``.

    Raises:
      ValueError: If a non-string ``obj`` has no ``graph`` attribute or its
        graph is not the default graph, or if a name without ``":"`` also
        resolves under ``obj + ":1"`` (more than one output).
    """
    current_graph = ops.get_default_graph()

    if isinstance(obj, six.string_types):
        # An explicit output index needs no disambiguation.
        if ":" in obj:
            return current_graph.as_graph_element(obj)

        first_output = current_graph.as_graph_element(obj + ":0")
        # Probe for a second output; a failed lookup means the bare name is
        # unambiguous.
        try:
            current_graph.as_graph_element(obj + ":1")
        except (KeyError, ValueError):
            return first_output
        raise ValueError("Name %s is ambiguous, "
                         "as this `Operation` has multiple outputs "
                         "(at least 2)." % obj)

    # Non-string objects must already belong to the default graph.
    if not hasattr(obj, "graph") or obj.graph != current_graph:
        raise ValueError("Passed %s should have graph attribute that is equal "
                         "to current graph %s." % (obj, current_graph))
    return obj
# Example usage. `reporter`, `my_class`, `cross_validator`, `split` and
# `params` are not defined in this snippet; they must be supplied by the
# surrounding code.
#
# Each key is the keyword name passed to the reporter; each value is the name
# of the graph tensor whose value is reported under that keyword.
ray_hook = RayTuneReportingHook(params={'mean_loss': 'sparse_softmax_cross_entropy_loss/value',
'mean_validation_accuracy': 'classification_accuracy/Mean'},
reporter=reporter)
# The hook is handed over via `eval_hooks`. The two lambdas supply the train
# and eval iterators for the given split.
# NOTE(review): `my_class.estimator` is project code not shown here; presumably
# it forwards `eval_hooks` to the evaluation run -- confirm.
my_class.estimator(lambda: cross_validator.get_train_iterator(split, lambda x: my_class.parse_example(x)),
lambda: cross_validator.get_eval_iterator(split, lambda x: my_class.parse_example(x)), params,
max_steps=100000, eval_hooks=[ray_hook])
# Notes:
# Set the `params` argument of RayTuneReportingHook to a dict mapping TrainableResult field names to either tensors or
# tensor names; the hook resolves tensor names to graph elements itself. This will report a value to ray every time
# eval is run. I have not figured out how to aggregate things like averages across batches in a single evaluation run.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment