# Is it possible to utilize Keras callbacks to encapsulate the logic? Yes.
#
# We decouple feeding inputs from StagingArea.put() - both can be called in
# separate Session.run() calls, so there is no need to hack Keras inputs too
# much. Instead, in one run() we assign a numpy array to a Variable (via
# feed_dict) and in another run() we perform StagingArea.put().
#
# We make a callback, StagingAreaCallback, which performs the initial assign
# and put() in its on_epoch_begin() method. Then in each on_batch_begin() it
# just runs an assign. get() and put() are then run by Keras within the
# training function.
#
# It slices the input array into batches, and for the last batch it provides
# a dummy value which is put into the StagingArea but discarded, so that the
# get() + put() pattern stays uniform over all batches.
#
# It supports a variable batch size, so there is no worry if the dataset is
# not evenly divisible by the batch size.
#
# This is a cleaned-up version of
# https://gist.github.com/bzamecnik/43cc3c682bf60b8d13bef7aa518b3dd4
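#
# A minimal sketch of the two-run() idea (illustrative only, not executed by
# this script; the names `value_ph`, `staged` and `batch` are made up here):
#
#   value_ph = tf.placeholder(tf.float32, shape=(None,))
#   staged = tf.Variable(value_ph, trainable=False, collections=[],
#                        validate_shape=False)
#   area = tf.contrib.staging.StagingArea(dtypes=[tf.float32],
#                                         shapes=[(None,)])
#   put_op = area.put([staged.value()])
#   get_op = area.get()
#
#   sess.run(staged.initializer, feed_dict={value_ph: batch})  # run 1: feed
#   sess.run(put_op)   # run 2: stage the value
#   sess.run(get_op)   # run 3: the training step consumes it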
import math

import numpy as np
import tensorflow as tf

import keras.backend as K
from keras.callbacks import Callback
from keras.datasets import mnist
from keras.layers import Dense, Input
from keras.models import Model
from keras.utils import to_categorical
class StagingAreaCallback(Callback):
    def __init__(self, x, y, batch_size, prefetch_count=1):
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.prefetch_count = prefetch_count

        features_shape = (None,) + x.shape[1:]
        labels_shape = (None,) + y.shape[1:]

        with tf.device('/cpu:0'):
            # For feeding inputs to the StagingArea we decouple feeding data
            # to StagingArea.put() from the training batch session.run().
            # https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
            self.features_batch_next_value = tf.placeholder(
                dtype=x.dtype, shape=features_shape)
            # - prevent the variable from being used as a model parameter:
            #   trainable=False, collections=[]
            # - allow a dynamic variable shape (for the last batch):
            #   validate_shape=False
            features_batch_next = tf.Variable(self.features_batch_next_value,
                trainable=False, collections=[], validate_shape=False)
            self.labels_batch_next_value = tf.placeholder(
                dtype=y.dtype, shape=labels_shape)
            labels_batch_next = tf.Variable(self.labels_batch_next_value,
                trainable=False, collections=[], validate_shape=False)
            self.assign_next_batch = tf.group(
                features_batch_next.initializer, labels_batch_next.initializer)

        # will be used for prefetching to the GPU
        area = tf.contrib.staging.StagingArea(
            dtypes=[x.dtype, y.dtype],
            shapes=[features_shape, labels_shape])
        self.area_put = area.put(
            [features_batch_next.value(), labels_batch_next.value()])
        area_get_features, area_get_labels = area.get()
        self.area_size = area.size()
        self.area_clear = area.clear()

        self.input_tensor = area_get_features
        self.target_tensor = area_get_labels
        self.extra_ops = [self.area_put]
    def set_params(self, params):
        super().set_params(params)
        self.steps_per_epoch = self.params['steps']

    def _slice_batch(self, i):
        start = i * self.batch_size
        end = start + self.batch_size
        return (self.x[start:end], self.y[start:end])
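
    # Note: NumPy slicing past the end of an array yields an empty array
    # (e.g. np.arange(5)[6:8] has length 0), so _slice_batch() with an
    # out-of-range index naturally provides the empty dummy value used
    # for the final `prefetch_count` batches of each epoch.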

    def _assign_batch(self, session, data):
        x_batch, y_batch = data
        session.run(self.assign_next_batch, feed_dict={
            self.features_batch_next_value: x_batch,
            self.labels_batch_next_value: y_batch})

    def on_epoch_begin(self, epoch, logs=None):
        sess = K.get_session()
        for i in range(self.prefetch_count):
            self._assign_batch(sess, self._slice_batch(i))
            sess.run(self.area_put)

    def on_batch_begin(self, batch, logs=None):
        sess = K.get_session()
        # The slice for the last `prefetch_count` batches is empty.
        # It serves as a dummy value which is put into the StagingArea
        # but never read.
        data = self._slice_batch(batch + self.prefetch_count)
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        sess = K.get_session()
        sess.run(self.area_clear)
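

# Per-epoch accounting (for reference): on_epoch_begin() does
# `prefetch_count` put()s, and each of the `steps_per_epoch` training steps
# does one get() plus one put() (the latter via the extra `fetches`), so the
# area ends the epoch holding `prefetch_count` leftover dummy items, which
# on_epoch_end() discards via clear().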

np.random.seed(42)

K.set_session(None)
tf.reset_default_graph()

num_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

batch_size = 64  # does not divide 60000 evenly - the last batch has size 32
# ceil() accounts for the last, possibly smaller, batch
steps_per_epoch = int(math.ceil(len(x_train) / batch_size))
staging_area_callback = StagingAreaCallback(x_train, y_train, batch_size)

image = Input(tensor=staging_area_callback.input_tensor)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)

# Note: passing extra `fetches` to compile() requires a Keras version whose
# TensorFlow backend function supports them (compile() forwards such kwargs
# to K.function()).
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[staging_area_callback.target_tensor],
              fetches=staging_area_callback.extra_ops)

model.fit(steps_per_epoch=steps_per_epoch, epochs=2,
          callbacks=[staging_area_callback])
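
# Sanity check: after fit() the StagingArea should be empty, since
# on_epoch_end() discarded the leftover dummy items via clear().
print('staging area size after training:',
      K.get_session().run(staging_area_callback.area_size))  # expected: 0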