Proof of concept of using Keras with a TensorFlow StagingArea: data is fed separately via a tf.Variable and a Keras Callback.
# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put() - both can be called in
# a separate Session.run(). Thus there's no need to hack Keras inputs too much.
# Instead, in one run() we assign a numpy array to a Variable (via feed_dict)
# and in another run() we perform StagingArea.put().
#
# We make a callback, PrefetchCallback, which performs the initial assign and
# put() in its on_epoch_begin() method. Then in each on_batch_begin() it just
# runs an assign. The get() and put() are run by Keras in the training function.
#
# It slices the input array into batches, and for the last batch it provides
# a dummy value which is discarded, so that we can keep get() + put()
# uniform over all batches.
#
# Requires patched Keras: https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
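#
# Per-step schedule (illustrative), with a prefetch depth of 1:
#
#   on_epoch_begin:   run(assign)               - variables hold batch 0
#                     run(put)                  - StagingArea holds batch 0
#   on_batch_begin i: run(assign)               - variables hold batch i+1
#   train step i:     run([train_op, ..., put]) - get() consumes batch i,
#                                                 put() enqueues batch i+1
#   on_epoch_end:     run(clear)                - drops the final dummy batch
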
import math
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback

num_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

batch_size = 64  # makes the last batch of size 32
# The last batch might be smaller than a full batch, hence the ceiling division.
steps_per_epoch = int(math.ceil(len(x_train) / float(batch_size)))
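# For MNIST: 60000 / 64 = 937.5, so steps_per_epoch == 938 - that is, 937 full
# batches of 64 plus a final partial batch of 60000 - 937 * 64 = 32 samples.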
features_shape = (None, 784)
labels_shape = (None, num_classes)

# For feeding inputs to the StagingArea.
# Let's try to decouple feeding data to StagingArea.put()
# from the training batch session.run().
# https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
features_batch_next_value = tf.placeholder(dtype=tf.float32, shape=features_shape)
# - prevent the variable from being used as a model parameter: trainable=False, collections=[]
# - allow a dynamic variable shape (for the last batch): validate_shape=False
features_batch_next = tf.Variable(features_batch_next_value, trainable=False, collections=[], validate_shape=False)
labels_batch_next_value = tf.placeholder(dtype=tf.float32, shape=labels_shape)
labels_batch_next = tf.Variable(labels_batch_next_value, trainable=False, collections=[], validate_shape=False)

assign_next_batch = tf.group(features_batch_next.initializer, labels_batch_next.initializer)
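# Note: re-running the initializers is what (re)assigns the current
# placeholder values to the variables, so run(assign_next_batch, feed_dict=...)
# stages the next batch without touching the rest of the graph.
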
# Will be used for prefetching to the GPU.
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[features_shape, labels_shape])
area_put = area.put([features_batch_next.value(), labels_batch_next.value()])
area_get_features, area_get_labels = area.get()
area_size = area.size()
area_clear = area.clear()
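# put() enqueues the variables' current values and get() dequeues one batch
# for the model; size() and clear() allow inspecting and emptying the area
# at epoch boundaries.
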
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
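# The model reads its input tensor straight from the StagingArea via
# Input(tensor=...), so no input data is passed to fit() below.
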
class PrefetchCallback(Callback):
    def __init__(self, x, y, batch_size):
        self.x = x
        self.y = y
        self.batch_size = batch_size
        # Number of complete batches; the final partial batch gets a dummy put().
        self.steps_per_epoch = len(x) // batch_size
        # One batch is prefetched into the pipeline.
        self.prefetch_count = 1

    def _slice_batch(self, i):
        start = i * self.batch_size
        end = start + self.batch_size
        return (self.x[start:end], self.y[start:end])

    def _assign_batch(self, session, data):
        x_batch, y_batch = data
        session.run(assign_next_batch, feed_dict={
            features_batch_next_value: x_batch,
            labels_batch_next_value: y_batch})

    def on_epoch_begin(self, epoch, logs=None):
        sess = K.get_session()
        self._assign_batch(sess, self._slice_batch(0))
        sess.run(area_put)

    def on_batch_begin(self, batch, logs=None):
        sess = K.get_session()
        if batch <= self.steps_per_epoch - self.prefetch_count:
            data = self._slice_batch(batch + self.prefetch_count)
        else:
            # A dummy value for the last batch's put() - it is never consumed,
            # only cleared at the end of the epoch.
            data = (np.zeros((self.batch_size, self.x.shape[1])),
                    np.zeros((self.batch_size, self.y.shape[1])))
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        sess = K.get_session()
        sess.run(area_clear)
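
# fetches=[area_put] (supported by the patched Keras linked above) makes each
# training session.run() also enqueue the next batch, so the host-to-device
# copy overlaps with computation; target_tensors wires the labels straight
# from the StagingArea instead of a feed_dict.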
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[area_get_labels], fetches=[area_put])
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
model.fit(steps_per_epoch=steps_per_epoch, callbacks=[prefetch_callback])
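
# Optional sanity check (assumes the session is still alive): the StagingArea
# should be empty after fit(), since on_epoch_end() ran area_clear.
# print(K.get_session().run(area_size))  # expected: 0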