Skip to content

Instantly share code, notes, and snippets.

@peterroelants
Created August 8, 2017 08:32
Show Gist options
  • Save peterroelants/6a7b3cc802f7f855744e3a74a1fab354 to your computer and use it in GitHub Desktop.
Save peterroelants/6a7b3cc802f7f855744e3a74a1fab354 to your computer and use it in GitHub Desktop.
Tensorflow Dataset API initialiser hook fix
from __future__ import division, print_function, absolute_import, \
unicode_literals
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
from tensorflow.contrib import slim
from tensorflow.contrib.learn import ModeKeys
from tensorflow.contrib.learn.python.learn import learn_runner
# Emit TensorFlow log messages down to DEBUG level (the model code below logs
# at debug level).
tf.logging.set_verbosity(tf.logging.DEBUG)
# Define data loaders #####################################
# Folder handed to the MNIST reader as its data location.
MNIST_DATA_FOLDER = './MNIST_data'
# Loaded once at import time; `mnist.train` and `mnist.test` are read by the
# input functions defined below. one_hot=False is requested, so labels are
# presumably plain class ids rather than one-hot vectors (they are later cast
# to int32 and fed to a sparse softmax loss).
mnist = mnist_data.read_data_sets(MNIST_DATA_FOLDER, one_hot=False)
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Session hook that initialises a dataset iterator after session creation.

    The code that builds the input pipeline is expected to assign a callable
    accepting a single ``session`` argument to ``iterator_initiliser_func``.
    The hook invokes that callable from ``after_create_session``.
    """

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        # Stays None until the input function has built its iterator and
        # registered the initialiser callable here.
        self.iterator_initiliser_func = None

    def after_create_session(self, session, coord):
        """Run the registered iterator initialiser against ``session``.

        Args:
            session: the session that has just been created.
            coord: coordinator passed in by the hook machinery (unused).

        Raises:
            RuntimeError: if no initialiser has been registered, i.e. the
                input function paired with this hook was never called.
        """
        if self.iterator_initiliser_func is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not callable" TypeError.
            raise RuntimeError(
                'IteratorInitializerHook.iterator_initiliser_func was never '
                'set; call the input function that owns this hook before '
                'the session is created.')
        self.iterator_initiliser_func(session)
# Define the training inputs
def get_train_inputs(batch_size):
    """Create the training input function and its iterator-initialiser hook.

    Args:
        batch_size: number of examples per batch.

    Returns:
        A ``(train_inputs, hook)`` pair. ``train_inputs`` takes no arguments,
        builds the input pipeline and returns the ``(next_example,
        next_label)`` tensors. ``hook`` is the ``IteratorInitializerHook``
        that ``train_inputs`` wires up to initialise the iterator.
    """
    init_hook = IteratorInitializerHook()

    def train_inputs():
        with tf.name_scope('Training_data'):
            train_images = mnist.train.images.reshape([-1, 28, 28, 1])
            train_labels = mnist.train.labels
            # The arrays are fed through placeholders when the iterator is
            # initialised rather than being passed to the dataset directly.
            images_placeholder = tf.placeholder(
                train_images.dtype, train_images.shape)
            labels_placeholder = tf.placeholder(
                train_labels.dtype, train_labels.shape)
            slices = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            # NOTE(review): repeat is applied before shuffle, so the shuffle
            # buffer presumably mixes examples across epoch boundaries;
            # confirm this ordering is intended before changing it.
            batches = slices.repeat(None)  # Infinite
            batches = batches.shuffle(buffer_size=10000)
            batches = batches.batch(batch_size)
            iterator = batches.make_initializable_iterator()
            next_example, next_label = iterator.get_next()

            def initialise_iterator(sess):
                # Feed the full training arrays while running the initializer.
                return sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: train_images,
                               labels_placeholder: train_labels})

            # Register the initialiser so the hook runs it once a session
            # exists.
            init_hook.iterator_initiliser_func = initialise_iterator
            return next_example, next_label

    return train_inputs, init_hook
def get_test_inputs(batch_size):
    """Create the evaluation input function and its iterator-initialiser hook.

    Unlike the training pipeline, the dataset is only batched: it is neither
    repeated nor shuffled.

    Args:
        batch_size: number of examples per batch.

    Returns:
        A ``(test_inputs, hook)`` pair. ``test_inputs`` takes no arguments,
        builds the input pipeline and returns the ``(next_example,
        next_label)`` tensors. ``hook`` is the ``IteratorInitializerHook``
        that ``test_inputs`` wires up to initialise the iterator.
    """
    init_hook = IteratorInitializerHook()

    def test_inputs():
        with tf.name_scope('Test_data'):
            test_images = mnist.test.images.reshape([-1, 28, 28, 1])
            test_labels = mnist.test.labels
            # The arrays are fed through placeholders when the iterator is
            # initialised rather than being passed to the dataset directly.
            images_placeholder = tf.placeholder(
                test_images.dtype, test_images.shape)
            labels_placeholder = tf.placeholder(
                test_labels.dtype, test_labels.shape)
            batches = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder)).batch(batch_size)
            iterator = batches.make_initializable_iterator()
            next_example, next_label = iterator.get_next()

            def initialise_iterator(sess):
                # Feed the full test arrays while running the initializer.
                return sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: test_images,
                               labels_placeholder: test_labels})

            # Register the initialiser so the hook runs it once a session
            # exists.
            init_hook.iterator_initiliser_func = initialise_iterator
            return next_example, next_label

    return test_inputs, init_hook
# Define model ############################################
def get_estimator(run_config, hparams):
    """Build the contrib.learn Estimator around ``get_model_fn``.

    Args:
        run_config: run configuration passed to the estimator as ``config``.
        hparams: hyper-parameters passed to the estimator as ``params``.

    Returns:
        A ``tf.contrib.learn.Estimator`` instance.
    """
    estimator_kwargs = {
        'model_fn': get_model_fn,
        'params': hparams,
        'config': run_config,
    }
    return tf.contrib.learn.Estimator(**estimator_kwargs)
def get_model_fn(inputs, targets, mode, params):
    """Model function: build the network and attach a multi-class head.

    Args:
        inputs: input tensor passed to ``architecture``.
        targets: label tensor; cast to int32 before reaching the head.
        mode: mode key; compared against ``ModeKeys.TRAIN`` to decide
            whether the network is built in training mode.
        params: hyper-parameters; ``n_classes`` is read here and the whole
            object is forwarded to ``get_train_op_fn``.

    Returns:
        The result of the head's ``create_model_fn_ops``.
    """
    # Define the model's architecture.
    training = mode == ModeKeys.TRAIN
    logits = architecture(inputs, is_training=training)

    classification_head = tf.contrib.learn.multi_class_head(
        n_classes=params.n_classes,
        loss_fn=tf.losses.sparse_softmax_cross_entropy)

    int_labels = tf.cast(targets, tf.int32)
    train_op_fn = get_train_op_fn(params)
    return classification_head.create_model_fn_ops(
        features={'inputs': inputs},
        labels=int_labels,
        mode=mode,
        train_op_fn=train_op_fn,
        logits=logits
    )
def get_train_op_fn(params):
    """Return a function that builds the training op for a given loss.

    Args:
        params: hyper-parameters; ``learning_rate`` is read when the returned
            function is called.

    Returns:
        A function ``train_op_fn(loss)`` that returns the result of
        ``tf.contrib.layers.optimize_loss`` using ``tf.train.AdamOptimizer``.
    """
    def train_op_fn(loss):
        # The global step is looked up at call time, not when this closure
        # is created.
        step = tf.contrib.framework.get_global_step()
        return tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=step,
            optimizer=tf.train.AdamOptimizer,
            learning_rate=params.learning_rate
        )

    return train_op_fn
def architecture(inputs, is_training, scope='MnistConvNet'):
    """Build the convolutional network and return its output-layer tensor.

    Layers, in order: conv (20 filters, 5x5, VALID) -> 2x2 max-pool ->
    conv (40 filters, 5x5, VALID) -> 2x2 max-pool -> flatten ->
    fully connected (256) -> dropout -> fully connected (256) -> dropout ->
    fully connected (10, no activation).

    Args:
        inputs: input tensor fed to the first convolution.
        is_training: forwarded to both dropout layers.
        scope: variable scope under which the layers are created.

    Returns:
        The tensor produced by the final 10-unit layer.
    """
    tf.logging.debug('is_training: {}, {}'.format(type(is_training), is_training))
    # Conv and fully connected layers share the Xavier weight initializer.
    with tf.variable_scope(scope), slim.arg_scope(
            [slim.conv2d, slim.fully_connected],
            weights_initializer=tf.contrib.layers.xavier_initializer()):
        hidden = slim.conv2d(
            inputs, 20, [5, 5], padding='VALID', scope='layer1-conv')
        hidden = slim.max_pool2d(
            hidden, 2, stride=2, scope='layer2-max-pool')
        hidden = slim.conv2d(
            hidden, 40, [5, 5], padding='VALID', scope='layer3-conv')
        hidden = slim.max_pool2d(
            hidden, 2, stride=2, scope='layer4-max-pool')
        # Flatten to 4 * 4 * 40 features per example; for 28x28 inputs the
        # two VALID 5x5 convs and two 2x2 pools leave a 4x4 map of 40
        # channels (28 -> 24 -> 12 -> 8 -> 4).
        hidden = tf.reshape(hidden, [-1, 4 * 4 * 40])
        hidden = slim.fully_connected(hidden, 256, scope='layer5')
        hidden = slim.dropout(
            hidden, is_training=is_training, scope='layer5-dropout')
        hidden = slim.fully_connected(hidden, 256, scope='layer6')
        hidden = slim.dropout(
            hidden, is_training=is_training, scope='layer6-dropout')
        # No activation on the output layer: its values are used as logits
        # by the caller.
        output = slim.fully_connected(
            hidden, 10, scope='output', activation_fn=None)
    return output
def create_experiment(run_config, hparams):
    """Assemble the Experiment: estimator, input functions and hooks.

    Args:
        run_config: base run configuration; a copy with
            ``save_checkpoints_steps=500`` is what the estimator receives.
        hparams: hyper-parameters forwarded to the estimator.

    Returns:
        A ``tf.contrib.learn.Experiment`` instance.
    """
    # A subset of the run_config properties can be overridden via replace().
    checkpointing_config = run_config.replace(save_checkpoints_steps=500)
    # Define the MNIST classifier.
    estimator = get_estimator(checkpointing_config, hparams)
    # Set up the data loaders; each comes with the hook that initialises its
    # iterator.
    train_input_fn, train_input_hook = get_train_inputs(batch_size=128)
    eval_input_fn, eval_input_hook = get_test_inputs(batch_size=128)
    # Define the experiment.
    return tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=5000,
        min_eval_frequency=500,
        train_monitors=[train_input_hook],
        eval_hooks=[eval_input_hook],
        eval_steps=None  # Use the evaluation feeder until it is empty
    )
def main(unused_argv=None):
    """Run the "train_and_evaluate" schedule via ``learn_runner``.

    Args:
        unused_argv: ignored.
    """
    # Define model parameters.
    hparams = tf.contrib.training.HParams(
        learning_rate=0.002,
        n_classes=10
    )
    # The run configuration's model directory.
    run_config = tf.contrib.learn.RunConfig(
        model_dir='./mnist_training_new_dataset_hook')
    learn_runner.run(
        experiment_fn=create_experiment,
        run_config=run_config,
        schedule="train_and_evaluate",
        hparams=hparams)
if __name__ == "__main__":
    # Script entry point: hand control to TensorFlow's app runner, naming
    # main() explicitly rather than relying on its default lookup.
    tf.app.run(main=main)
@wohlbier
Copy link

@yzhangswingman I hit this exact problem. Did you manage to solve it?

@yzhang5471
Copy link

@wohlbier oh hey, I actually cannot recall at all. Sorry about that. I don't know if this is helpful, but I switched to the tf.keras API altogether as of TF2, since it seems to be more stable and better maintained.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment