Created
June 15, 2018 15:42
-
-
Save sseveran/e27045c5fdb2d2f836ca63e13755665f to your computer and use it in GitHub Desktop.
A TensorFlow `SessionRunHook` for reporting tensor values to Ray Tune
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import six
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
class RayTuneReportingHook(tf.train.SessionRunHook):
    """Session hook that forwards fetched tensor values to a Ray Tune reporter.

    On every ``session.run`` the hook fetches the configured tensors. It then
    calls ``reporter(**results)`` with one keyword argument per tag, plus
    ``timesteps_total`` set to the current global step.

    Args:
        params: Either a dict mapping result names (the keyword names passed
            to ``reporter``) to tensors or tensor names, or an iterable of
            tensor names. In the iterable case each name is also used as its
            result name. A single string is treated as one tensor name.
        reporter: Callable invoked as ``reporter(**results)`` after each run.
    """

    def __init__(self, params, reporter):
        self.reporter = reporter
        if isinstance(params, six.string_types):
            # Without this, a lone name would be iterated character by character.
            params = [params]
        # Materialize the tag order exactly once. Assigning `params` directly
        # aliased the caller's list. For a generator/iterator it also left the
        # tag order exhausted by the dict construction, so nothing was reported.
        self._tag_order = list(params)
        if isinstance(params, dict):
            # Snapshot the mapping so later edits to the caller's dict cannot
            # desynchronize the fetched tensors from `_tag_order`.
            self._tensors = dict(params)
        else:
            self._tensors = {tag: tag for tag in self._tag_order}

    def begin(self):
        """Resolve the global-step read and the tensors to fetch in the default graph."""
        self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
        self._current_tensors = {
            tag: _as_graph_element(tensor) for tag, tensor in self._tensors.items()
        }

    def before_run(self, run_context):
        """Request the configured tensors as extra fetches for this run."""
        return SessionRunArgs(self._current_tensors)

    def after_run(self, run_context, run_values):
        """Report this run's fetched values, plus the global step, to ``self.reporter``."""
        # Read in a separate run, i.e. after this step's fetches have completed.
        global_step = run_context.session.run(self._global_step_tensor)
        results = {tag: run_values.results[tag] for tag in self._tag_order}
        # NOTE: this overrides any user tag that is also named 'timesteps_total'.
        results['timesteps_total'] = global_step
        self.reporter(**results)
# Copied from TensorFlow, where `_as_graph_element` is a private helper.
def _as_graph_element(obj):
    """Resolve ``obj`` to an element of the current default graph.

    Args:
        obj: Either a graph element (an object with a ``graph`` attribute,
            such as a tensor) or a string name. A name containing ``":"`` is
            looked up as-is. A bare operation name is looked up as output
            ``":0"`` of that operation.

    Returns:
        ``obj`` itself when it is not a string. Otherwise, the element that
        the name refers to in the default graph.

    Raises:
        ValueError: if a non-string ``obj`` has no ``graph`` attribute or
            belongs to a different graph than the default one, or if a bare
            operation name is ambiguous because ``":1"`` also resolves.
            Lookup errors from ``graph.as_graph_element`` propagate unchanged.
    """
    default_graph = ops.get_default_graph()

    if isinstance(obj, six.string_types):
        if ":" in obj:
            return default_graph.as_graph_element(obj)
        first_output = default_graph.as_graph_element(obj + ":0")
        # A bare operation name is only unambiguous when the operation has a
        # single output, i.e. when looking up its ":1" output fails.
        try:
            default_graph.as_graph_element(obj + ":1")
        except (KeyError, ValueError):
            return first_output
        raise ValueError("Name %s is ambiguous, "
                         "as this `Operation` has multiple outputs "
                         "(at least 2)." % obj)

    if not hasattr(obj, "graph") or obj.graph != default_graph:
        raise ValueError("Passed %s should have graph attribute that is equal "
                         "to current graph %s." % (obj, default_graph))
    return obj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Result name -> tensor name. The tensor names must exist in the graph the
# hook is attached to.
metric_tensors = {
    'mean_loss': 'sparse_softmax_cross_entropy_loss/value',
    'mean_validation_accuracy': 'classification_accuracy/Mean',
}
ray_hook = RayTuneReportingHook(params=metric_tensors, reporter=reporter)


def _parse(serialized):
    """Parse a single example with the model class's parser."""
    return my_class.parse_example(serialized)


def train_input_fn():
    """Build the training iterator for the current cross-validation split."""
    return cross_validator.get_train_iterator(split, _parse)


def eval_input_fn():
    """Build the evaluation iterator for the current cross-validation split."""
    return cross_validator.get_eval_iterator(split, _parse)


my_class.estimator(train_input_fn, eval_input_fn, params,
                   max_steps=100000, eval_hooks=[ray_hook])
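# Notes:
# - Set the RayTuneReportingHook `params` to a dict that maps Ray Tune result names
#   (e.g. `mean_loss`, `mean_validation_accuracy`) to either tensors or tensor
#   names. The hook resolves tensor names against the default graph.
# - The hook reports a value to Ray after every `session.run` call in which it
#   is active. Used as an eval hook, it therefore reports once per evaluation
#   batch, not once per evaluation run.
# - I have not yet figured out how to aggregate values, such as averages,
#   across batches within a single evaluation run.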
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment