Created August 8, 2017 08:32
TensorFlow Dataset API initialiser hook fix
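The script below trains a small MNIST convnet with tf.contrib.learn.Experiment while feeding the data through the tf.contrib.data Dataset API. The catch it works around: an iterator created with make_initializable_iterator() must have its initializer op run inside the session that the Estimator creates internally. The IteratorInitializerHook below does exactly that from after_create_session(), feeding the NumPy arrays in through placeholders.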
from __future__ import division, print_function, absolute_import, \
    unicode_literals

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
from tensorflow.contrib import slim
from tensorflow.contrib.learn import ModeKeys
from tensorflow.contrib.learn.python.learn import learn_runner

tf.logging.set_verbosity(tf.logging.DEBUG)


# Define data loaders #####################################
MNIST_DATA_FOLDER = './MNIST_data'
mnist = mnist_data.read_data_sets(MNIST_DATA_FOLDER, one_hot=False)
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise the Dataset iterator after session creation."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        # Run the initializer op inside the session the Estimator created.
        self.iterator_initializer_func(session)

# Define the training inputs
def get_train_inputs(batch_size):
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        with tf.name_scope('Training_data'):
            images = mnist.train.images.reshape([-1, 28, 28, 1])
            labels = mnist.train.labels
            # Feed the arrays through placeholders so they are not
            # embedded as constants in the graph definition.
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite repetitions
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set the run-hook to initialise the iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            return next_example, next_label

    return train_inputs, iterator_initializer_hook

def get_test_inputs(batch_size):
    iterator_initializer_hook = IteratorInitializerHook()

    def test_inputs():
        with tf.name_scope('Test_data'):
            images = mnist.test.images.reshape([-1, 28, 28, 1])
            labels = mnist.test.labels
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            # No repeat/shuffle: iterate over the test set exactly once
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set the run-hook to initialise the iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            return next_example, next_label

    return test_inputs, iterator_initializer_hook

# Define model ############################################
def get_estimator(run_config, hparams):
    return tf.contrib.learn.Estimator(
        model_fn=get_model_fn,
        params=hparams,
        config=run_config
    )


def get_model_fn(inputs, targets, mode, params):
    # Define the model's architecture
    logits = architecture(inputs, is_training=mode == ModeKeys.TRAIN)
    head = tf.contrib.learn.multi_class_head(
        n_classes=params.n_classes,
        loss_fn=tf.losses.sparse_softmax_cross_entropy)
    return head.create_model_fn_ops(
        features={'inputs': inputs},
        labels=tf.cast(targets, tf.int32),
        mode=mode,
        train_op_fn=get_train_op_fn(params),
        logits=logits
    )


def get_train_op_fn(params):
    def train_op_fn(loss):
        return tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=tf.contrib.framework.get_global_step(),
            optimizer=tf.train.AdamOptimizer,
            learning_rate=params.learning_rate
        )
    return train_op_fn

def architecture(inputs, is_training, scope='MnistConvNet'):
    tf.logging.debug('is_training: {}, {}'.format(type(is_training), is_training))
    with tf.variable_scope(scope):
        with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                weights_initializer=tf.contrib.layers.xavier_initializer()):
            net = slim.conv2d(inputs, 20, [5, 5], padding='VALID',
                              scope='layer1-conv')
            net = slim.max_pool2d(net, 2, stride=2, scope='layer2-max-pool')
            net = slim.conv2d(net, 40, [5, 5], padding='VALID',
                              scope='layer3-conv')
            net = slim.max_pool2d(net, 2, stride=2, scope='layer4-max-pool')
            net = tf.reshape(net, [-1, 4 * 4 * 40])
            net = slim.fully_connected(net, 256, scope='layer5')
            net = slim.dropout(net, is_training=is_training,
                               scope='layer5-dropout')
            net = slim.fully_connected(net, 256, scope='layer6')
            net = slim.dropout(net, is_training=is_training,
                               scope='layer6-dropout')
            net = slim.fully_connected(net, 10, scope='output',
                                       activation_fn=None)
    return net

def create_experiment(run_config, hparams):
    # You can change a subset of the run_config properties as follows:
    run_config = run_config.replace(save_checkpoints_steps=500)
    # Define the MNIST classifier
    estimator = get_estimator(run_config, hparams)
    # Set up the data loaders
    train_input_fn, train_input_hook = get_train_inputs(batch_size=128)
    eval_input_fn, eval_input_hook = get_test_inputs(batch_size=128)
    # Define the experiment
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        train_steps=5000,
        min_eval_frequency=500,
        train_monitors=[train_input_hook],  # Initialises the train iterator
        eval_hooks=[eval_input_hook],       # Initialises the eval iterator
        eval_steps=None  # Use the evaluation feeder until it is empty
    )
    return experiment

def main(unused_argv=None):
    model_dir = './mnist_training_new_dataset_hook'
    # Define model parameters
    hparams = tf.contrib.training.HParams(
        learning_rate=0.002,
        n_classes=10
    )
    learn_runner.run(
        experiment_fn=create_experiment,
        run_config=tf.contrib.learn.RunConfig(model_dir=model_dir),
        schedule="train_and_evaluate",
        hparams=hparams)


if __name__ == "__main__":
    tf.app.run()
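The hook is not tied to Experiment: any tf.train.MonitoredSession calls after_create_session() on its hooks. Below is a minimal standalone sketch of the same pattern; the toy data and every name other than IteratorInitializerHook are illustrative, and it assumes TF 1.2+ where tf.contrib.data is available.

# Standalone sketch: the same hook with a MonitoredTrainingSession.
# The toy data here is illustrative, not part of the original gist.
import numpy as np
import tensorflow as tf

data = np.arange(10, dtype=np.float32)
data_placeholder = tf.placeholder(data.dtype, data.shape)
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_placeholder)
dataset = dataset.batch(2)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

hook = IteratorInitializerHook()
hook.iterator_initializer_func = lambda sess: sess.run(
    iterator.initializer, feed_dict={data_placeholder: data})

# after_create_session() fires once the session exists, so the iterator
# is initialised before the first run() call below.
with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
    while not sess.should_stop():
        print(sess.run(next_batch))  # Stops on tf.errors.OutOfRangeError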
@wohlbier Oh hey, I actually cannot recall at all, sorry about that. Don't know if this is helpful, but I switched to the tf.keras API altogether since TF2, since it seems to be more stable and better maintained.
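For anyone following that pointer: under TF2 the whole initialiser-hook dance disappears, because tf.keras consumes tf.data datasets directly and iterator creation happens implicitly. A rough sketch of the equivalent pipeline, assuming tensorflow>=2.x, with layer sizes mirroring the slim architecture above:

import tensorflow as tf

# Load MNIST via tf.keras and scale pixel values to [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# tf.data pipelines are passed straight to Model.fit();
# no placeholders, initialisable iterators, or session hooks needed.
train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
            .shuffle(10000).batch(128).repeat())
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(128)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(20, 5, activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPool2D(2),
    tf.keras.layers.Conv2D(40, 5, activation='relu'),
    tf.keras.layers.MaxPool2D(2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),  # Logits, as in the slim architecture above
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.002),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# 10 epochs x 500 steps matches the 5000 train_steps of the Experiment.
model.fit(train_ds, steps_per_epoch=500, epochs=10, validation_data=test_ds)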