Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save bzamecnik/43cc3c682bf60b8d13bef7aa518b3dd4 to your computer and use it in GitHub Desktop.
Save bzamecnik/43cc3c682bf60b8d13bef7aa518b3dd4 to your computer and use it in GitHub Desktop.
# Proof of concept of using Keras with a StagingArea: data is fed separately via a tf.Variable from a Keras Callback.
# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put() - the two can be called in
# separate Session.run() calls, so there is no need to hack the Keras inputs much.
# Instead, in one run() we assign a numpy array to a Variable (via feed_dict),
# and in another run() we perform StagingArea.put().
#
# We make a callback, PrefetchCallback, which performs the initial assign and
# put() in its on_epoch_begin() method. Then each on_batch_begin() just runs an
# assign, and get() and put() are run by Keras in the training function.
#
# It slices the input array into batches, and for the last batch it provides
# a dummy value which is discarded, so that get() + put() can stay uniform
# over all batches.
#
# Requires patched Keras: https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
import math
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback
num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Flatten each image into a single feature vector (784 values for MNIST) and
# divide by 255 (presumably uint8 pixels -> [0, 1]).
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')
batch_size = 64  # 60000 % 64 == 32, so the last batch holds only 32 samples
# Count the last (smaller) batch as a step too. float() keeps the division
# exact on Python 2, where int / int floors and would drop that batch.
steps_per_epoch = int(math.ceil(len(x_train) / float(batch_size)))
# Leading None: batch dimension varies (the last batch is smaller).
features_shape = (None, x_train.shape[1])
labels_shape = (None, num_classes)
# Feeding side of the StagingArea pipeline. Loading a numpy batch (one
# Session.run() with feed_dict) is decoupled from StagingArea.put(), which
# runs later as part of the training step.
# https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data


def _make_feed_variable(shape):
    """Create a float32 placeholder and a Variable initialized from it.

    Running the Variable's ``initializer`` with the placeholder fed copies the
    fed array into the Variable.

    - trainable=False, collections=[]: keep the Variable out of the graph's
      variable collections, so it is never treated as a model parameter.
    - validate_shape=False: allow a batch of a different size (the last one).
    """
    fed_value = tf.placeholder(dtype=tf.float32, shape=shape)
    holder = tf.Variable(fed_value, trainable=False, collections=[],
                         validate_shape=False)
    return fed_value, holder


features_batch_next_value, features_batch_next = _make_feed_variable(features_shape)
labels_batch_next_value, labels_batch_next = _make_feed_variable(labels_shape)
# Running this op with both placeholders fed loads one (features, labels) batch.
assign_next_batch = tf.group(features_batch_next.initializer,
                             labels_batch_next.initializer)
# StagingArea holding (features, labels) batches; used for prefetching to GPU.
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32] * 2,
    shapes=[features_shape, labels_shape])
# Stage whatever the feed variables currently hold.
area_put = area.put([features_batch_next.value(), labels_batch_next.value()])
# Fetch one staged (features, labels) pair; these tensors drive the model.
staged_batch = area.get()
area_get_features, area_get_labels = staged_batch
# Auxiliary ops: number of staged elements, and emptying the area.
area_size = area.size()
area_clear = area.clear()
# The model input is the StagingArea's output tensor rather than a fed
# placeholder.
image = Input(tensor=area_get_features)
hidden = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(hidden)
model = Model(inputs=image, outputs=digit)
class PrefetchCallback(Callback):
    """Keras callback that keeps one batch staged ahead of the training step.

    Each training step run by Keras get()s the current batch from the
    StagingArea and put()s whatever the feed variables currently hold. This
    callback makes sure the feed variables hold the *next* batch before every
    step:

    - on_epoch_begin: assigns batch 0 and stages it, so the first get() has data.
    - on_batch_begin: assigns batch ``batch + 1`` (or a dummy batch once no
      real batch is left) so that the step's put() stages it.
    - on_epoch_end: clears the area, discarding the staged dummy batch.

    Args:
        x: feature array, sliced along axis 0 into batches.
        y: label array aligned with ``x`` along axis 0.
        batch_size: samples per batch; the final batch may be smaller when
            ``len(x)`` is not a multiple of it.
    """

    def __init__(self, x, y, batch_size):
        super(PrefetchCallback, self).__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size
        # Include the trailing partial batch, matching the number of steps
        # Keras runs. float() keeps the division exact on Python 2.
        self.steps_per_epoch = int(math.ceil(len(x) / float(batch_size)))
        # 1 batch prefetched to the pipeline
        self.prefetch_count = 1

    def _slice_batch(self, i):
        """Return the ``i``-th (features, labels) batch; the last one may be short."""
        start = i * self.batch_size
        end = start + self.batch_size
        return (self.x[start:end], self.y[start:end])

    def _dummy_batch(self):
        """Return an all-zero full-size batch used as filler after the last real one."""
        # Use this callback's batch size and the data's own trailing shape and
        # dtype, so the filler matches what real batches look like.
        x_dummy = np.zeros((self.batch_size,) + self.x.shape[1:], dtype=self.x.dtype)
        y_dummy = np.zeros((self.batch_size,) + self.y.shape[1:], dtype=self.y.dtype)
        return (x_dummy, y_dummy)

    def _assign_batch(self, session, data):
        """Load ``data`` into the next-batch variables (does not stage it)."""
        x_batch, y_batch = data
        session.run(assign_next_batch, feed_dict={
            features_batch_next_value: x_batch,
            labels_batch_next_value: y_batch})

    def on_epoch_begin(self, epoch, logs=None):
        """Prime the pipeline: the first step's get() needs a staged batch."""
        sess = K.get_session()
        self._assign_batch(sess, self._slice_batch(0))
        sess.run(area_put)

    def on_batch_begin(self, batch, logs=None):
        """Load the batch that this step's put() will stage."""
        sess = K.get_session()
        next_batch = batch + self.prefetch_count
        if next_batch < self.steps_per_epoch:
            data = self._slice_batch(next_batch)
        else:
            # No real batch left: stage a dummy that is cleared at epoch end.
            data = self._dummy_batch()
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        """Empty the area, dropping the dummy batch staged by the last step."""
        sess = K.get_session()
        sess.run(area_clear)
# Targets also come straight from the StagingArea, and Keras runs area_put
# together with every training step (`fetches` is only accepted by the patched
# Keras build this script requires).
model.compile(
    optimizer='sgd',
    loss='categorical_crossentropy',
    target_tensors=[area_get_labels],
    fetches=[area_put],
)
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
model.fit(callbacks=[prefetch_callback], steps_per_epoch=steps_per_epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment