Skip to content

Instantly share code, notes, and snippets.

@ian-plosker
Created May 21, 2016 02:31
Show Gist options
  • Save ian-plosker/995b522e0bb66a2d85421e2de5e61224 to your computer and use it in GitHub Desktop.
Save ian-plosker/995b522e0bb66a2d85421e2de5e61224 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.ops.dropout_ops import DROPOUTS
class PeriodicValidationMonitor(tf.contrib.learn.python.learn.monitors.ValidationMonitor):
    """ValidationMonitor that recomputes validation loss only every N steps
    and disables dropout while doing so.

    Differences from the parent ``ValidationMonitor``:

    * ``create_val_feed_dict`` additionally feeds 1.0 for every tensor in the
      graph's ``DROPOUTS`` collection, to disable dropout on validation
      (workaround for https://github.com/tensorflow/tensorflow/issues/2171).
    * ``_set_last_loss_seen`` delegates to the parent only when the current
      step is a multiple of ``validate_steps``; on other steps it does nothing.
    """

    def __init__(self,
                 val_X,
                 val_y,
                 n_classes=0,
                 validate_steps=100,
                 print_steps=100,
                 early_stopping_rounds=None):
        """Create the monitor.

        Args:
            val_X: validation features, forwarded to the parent monitor.
            val_y: validation labels, forwarded to the parent monitor.
            n_classes: forwarded to the parent monitor.
            validate_steps: validation loss is recomputed only on steps that
                are a multiple of this value. Must be positive.
            print_steps: forwarded to the parent monitor.
            early_stopping_rounds: forwarded to the parent monitor.

        Raises:
            ValueError: if ``validate_steps`` is not positive. It is used as a
                modulus in ``_set_last_loss_seen``, so 0 would otherwise only
                surface later, mid-training, as a ZeroDivisionError.
        """
        if validate_steps <= 0:
            raise ValueError(
                'validate_steps must be positive, got %r' % (validate_steps,))
        super(PeriodicValidationMonitor, self).__init__(
            val_X,
            val_y,
            n_classes=n_classes,
            print_steps=print_steps,
            early_stopping_rounds=early_stopping_rounds)
        self.validate_steps = validate_steps
        # Not read anywhere in this class; kept from the original.
        # NOTE(review): presumably used/assigned by the parent monitor or the
        # estimator -- verify before removing.
        self._estimator = None

    def create_val_feed_dict(self, inp, out):
        """Build ``self.val_dict`` for validation with dropout disabled.

        Args:
            inp: input placeholder, passed to ``val_feeder.set_placeholders``.
            out: output placeholder, passed to ``val_feeder.set_placeholders``.

        Side effects:
            Sets ``self._graph`` to the graph whose ``DROPOUTS`` collection is
            read. Sets ``self.val_dict`` to the validation feed dict, with
            every dropout tensor additionally mapped to 1.0.

        Raises:
            RuntimeError: if ``inp`` carries no ``.graph`` and the caller's
                stack frame cannot be inspected.
        """
        # Prefer the graph that owns the input placeholder. The model's ops
        # (its dropout tensors included) consume this placeholder, so they
        # must live in the same graph. This avoids frame inspection whenever
        # ``inp`` is a tf.Tensor.
        graph = getattr(inp, 'graph', None)
        if graph is None:
            # Fallback, the original dirty hack (tensorflow issue #2171):
            # read ``_graph`` from the ``self`` local of the calling frame --
            # presumably the estimator invoking this method.
            import inspect
            frame = inspect.currentframe()
            try:
                if frame is None or frame.f_back is None:
                    raise RuntimeError(
                        'Cannot locate the TensorFlow graph: inp has no '
                        '.graph attribute and stack frames are unavailable.')
                graph = frame.f_back.f_locals['self']._graph
            finally:
                # Release the frame reference deterministically instead of
                # leaving a reference cycle for the garbage collector.
                del frame
        self._graph = graph
        dropouts = {prob: 1.0 for prob in self._graph.get_collection(DROPOUTS)}
        self.val_feeder.set_placeholders(inp, out)
        self.val_dict = self.val_feeder.get_feed_dict_fn()()
        # Override the dropout tensors in the feed so validation runs with
        # dropout disabled.
        self.val_dict.update(dropouts)

    def _set_last_loss_seen(self):
        """Recompute the validation loss only every ``validate_steps`` steps.

        On steps that are not a multiple of ``validate_steps`` this is a
        no-op, so whatever the parent stored on the last validation run is
        left unchanged.
        """
        # NOTE(review): ``self.steps`` is not set in this class; presumably
        # the parent monitor updates it each step -- confirm.
        if self.steps % self.validate_steps == 0:
            super(PeriodicValidationMonitor, self)._set_last_loss_seen()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment