Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created October 28, 2017 21:36
Show Gist options
  • Save bzamecnik/b520e2b1e199b193b715477929e39b22 to your computer and use it in GitHub Desktop.
Save bzamecnik/b520e2b1e199b193b715477929e39b22 to your computer and use it in GitHub Desktop.
# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put(): the two steps run in
# separate Session.run() calls, so Keras inputs need very little hacking.
# In one run() we assign a numpy array to a Variable (via feed_dict),
# and in another run() we perform StagingArea.put().
#
# We make a callback, StagingAreaCallback, which performs the initial assign
# and put() in its on_epoch_begin() method. Then in each on_batch_begin() it
# just runs an assign. get() and put() are then run by Keras in the training
# function.
#
# It slices the input arrays into batches, and for the last batch(es) it
# provides a dummy (empty) value which is discarded, so that get() + put()
# can stay uniform over all batches.
#
# It supports a variable batch size, so it does not matter if the dataset size
# is not evenly divisible by the batch size.
#
# This is a more cleaned-up version of https://gist.github.com/bzamecnik/43cc3c682bf60b8d13bef7aa518b3dd4
import math
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback
class StagingAreaCallback(Callback):
    """Keras callback that feeds training batches through a TF StagingArea.

    Build the model on ``input_tensor`` / ``target_tensor`` (the outputs of
    ``StagingArea.get()``) and pass ``extra_ops`` (the ``StagingArea.put()``
    op) as extra fetches of the training function. Each training step then
    dequeues one batch and enqueues another.

    The numpy data is never fed to the training step itself. Instead, in a
    separate ``Session.run()`` the callback assigns the slice for an upcoming
    batch to a non-trainable Variable; the training step's ``put()`` then
    copies that Variable's current value into the StagingArea.

    Args:
        x: Feature array, sliced into batches along axis 0.
        y: Label array, sliced in lockstep with ``x`` along axis 0.
        batch_size: Samples per batch. The last batch is smaller when
            ``len(x)`` is not a multiple of ``batch_size``.
        prefetch_count: Number of batches put into the StagingArea in
            ``on_epoch_begin`` before the first training step.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``prefetch_count`` < 0.
    """

    def __init__(self, x, y, batch_size, prefetch_count=1):
        # Let the Keras base class initialize its own attributes.
        super().__init__()
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
        if prefetch_count < 0:
            raise ValueError(
                'prefetch_count must be >= 0, got %r' % (prefetch_count,))
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.prefetch_count = prefetch_count

        # A None leading dimension accepts batches of any size, including
        # the smaller last batch and the empty dummy batches.
        features_shape = (None,) + x.shape[1:]
        labels_shape = (None,) + y.shape[1:]

        with tf.device('/cpu:0'):
            # Staging variables for feeding the StagingArea. Feeding numpy
            # data (assign via feed_dict) is decoupled from StagingArea.put(),
            # which runs inside the training-step Session.run().
            # https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
            self.features_batch_next_value = tf.placeholder(
                dtype=x.dtype, shape=features_shape)
            # - trainable=False, collections=[]: keep the variable out of the
            #   model parameters and out of the global-variables collection.
            # - validate_shape=False: allow a dynamic shape (last batch).
            features_batch_next = tf.Variable(
                self.features_batch_next_value, trainable=False,
                collections=[], validate_shape=False)
            self.labels_batch_next_value = tf.placeholder(
                dtype=y.dtype, shape=labels_shape)
            labels_batch_next = tf.Variable(
                self.labels_batch_next_value, trainable=False,
                collections=[], validate_shape=False)
        # Running these initializers (with the placeholders fed) is how a new
        # batch gets assigned to the staging variables.
        self.assign_next_batch = tf.group(
            features_batch_next.initializer, labels_batch_next.initializer)

        # Will be used for prefetching to GPU, so it is created outside the
        # CPU device scope. NOTE(review): placement chosen to match that
        # intent — confirm against the intended device layout.
        area = tf.contrib.staging.StagingArea(
            dtypes=[x.dtype, y.dtype],
            shapes=[features_shape, labels_shape])
        self.area_put = area.put(
            [features_batch_next.value(), labels_batch_next.value()])
        area_get_features, area_get_labels = area.get()
        self.area_size = area.size()
        self.area_clear = area.clear()

        # Model input/target tensors and the extra ops to run in every
        # training step.
        self.input_tensor = area_get_features
        self.target_tensor = area_get_labels
        self.extra_ops = [self.area_put]

    def set_params(self, params):
        """Store the Keras training params and remember the steps per epoch."""
        super().set_params(params)
        # .get(): presumably 'steps' may be absent when fit() is driven by
        # sample counts rather than steps; avoid a KeyError in that case.
        self.steps_per_epoch = self.params.get('steps')

    def _slice_batch(self, i):
        """Return the ``(x, y)`` slices for batch ``i``.

        Past the end of the data the slices are empty; those act as the dummy
        batches enqueued during the last ``prefetch_count`` steps.
        """
        start = i * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]

    def _assign_batch(self, session, data):
        """Assign an ``(x_batch, y_batch)`` pair to the staging variables."""
        x_batch, y_batch = data
        session.run(self.assign_next_batch, feed_dict={
            self.features_batch_next_value: x_batch,
            self.labels_batch_next_value: y_batch,
        })

    def on_epoch_begin(self, epoch, logs=None):
        """Prefill the StagingArea with the first ``prefetch_count`` batches."""
        sess = K.get_session()
        for i in range(self.prefetch_count):
            self._assign_batch(sess, self._slice_batch(i))
            sess.run(self.area_put)

    def on_batch_begin(self, batch, logs=None):
        """Stage the batch that this step's ``put()`` will enqueue.

        Assuming the training function fetches both the ``get()`` outputs and
        ``extra_ops``, step ``batch`` consumes batch ``batch`` and enqueues
        batch ``batch + prefetch_count``, which is assigned here.
        """
        sess = K.get_session()
        # For the last `prefetch_count` steps this slice is empty. It serves
        # as a dummy value which is put into the StagingArea but never read.
        data = self._slice_batch(batch + self.prefetch_count)
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        """Clear the StagingArea, dropping the leftover dummy batches."""
        sess = K.get_session()
        sess.run(self.area_clear)
np.random.seed(42)
# Drop any existing Keras session and TF graph so the callback's ops and the
# model are built together in a fresh default graph.
K.set_session(None)
tf.reset_default_graph()

num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Flatten each image into a vector. The sample count and flattened size come
# from the data itself instead of being hard-coded.
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

# For the 60000 MNIST training images: 937 full batches plus a last batch of 32.
batch_size = 64
# Integer ceiling division: the last, possibly smaller, batch is one more step.
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size

staging_area_callback = StagingAreaCallback(x_train, y_train, batch_size)

# The model reads inputs and targets straight from the StagingArea tensors.
image = Input(tensor=staging_area_callback.input_tensor)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
# `fetches` adds StagingArea.put() to every training step.
# NOTE(review): relies on a Keras version whose compile() forwards extra
# session kwargs such as `fetches` to the backend function — confirm.
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[staging_area_callback.target_tensor],
              fetches=staging_area_callback.extra_ops)
# No x/y arguments: the data is supplied by the callback via the StagingArea.
model.fit(steps_per_epoch=steps_per_epoch, epochs=2,
          callbacks=[staging_area_callback])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment