Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Last active October 25, 2017 23:31
Show Gist options
  • Save bzamecnik/54104644baddc614230102e65f02b593 to your computer and use it in GitHub Desktop.
Save bzamecnik/54104644baddc614230102e65f02b593 to your computer and use it in GitHub Desktop.
A rudimentary proof-of-concept of pipelining with TF StagingArea in Keras.
# Works! In this snippet we're able to get batch from StagingArea
# and in parallel put another batch there which is loaded via
# feed_dict (provided as a tf.Placeholder wrapped as Keras Input).
#
# So far it doesn't handle splitting batches, handling borders of
# pipelining correctly or switching between training/validation set.
#
# But at least it's a proof of concept that it's possible to use
# StagingArea with Keras.
#
# Needs TF 1.3 and Keras 2.0.8 patched to provide fetches to
# tf.Session.run()
# https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
#
# More info about this topic: https://github.com/avolkov1/keras_experiments/issues/2
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K

# Fresh start: drop the cached Keras session and reset the default TF graph
# so re-running the script (e.g. in a notebook) doesn't accumulate graph ops.
K.set_session(None)
tf.reset_default_graph()

num_classes = 10

# MNIST: flatten the 28x28 images to 784-dim float vectors scaled to [0, 1]
# and one-hot encode the labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

batch_size = 32
# StagingArea is constructed with fully-defined shapes below, which fixes
# the batch size at graph-construction time.
features_shape = (batch_size, 784)
labels_shape = (batch_size, num_classes)

# Placeholders (wrapped as Keras Inputs) through which the *next* batch is
# fed into the StagingArea via feed_dict.
features_batch_next = Input(shape=(784,))
labels_batch_next = Input(shape=(num_classes,))

# StagingArea used as a one-slot staging buffer: put() stages the next batch
# while the training step consumes the currently staged one via get()
# (prefetching to GPU per the header notes — effective overlap depends on
# device placement, which this snippet does not pin explicitly).
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[features_shape, labels_shape])
area_put = area.put([features_batch_next, labels_batch_next])
area_get_features, area_get_labels = area.get()
area_size = area.size()    # number of batches currently staged (printed below)
area_clear = area.clear()  # drains the area at the end of the script

# The model reads its features straight from the StagingArea's get() tensor.
# The two feed placeholders are still listed as model inputs so that Keras
# passes their values in feed_dict on every train step, which is what the
# fetched area_put op consumes.
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=[image, features_batch_next, labels_batch_next], outputs=digit)

# `fetches=[area_put]` requires the patched Keras mentioned in the header:
# the put() op is run inside the same Session.run() call as the train op,
# so staging the next batch happens alongside the training step.
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[area_get_labels], fetches=[area_put])

print('initial size:', K.get_session().run(area_size))

# Prime the pipeline: stage the first batch before training starts —
# otherwise the first get() inside the train step would block.
K.get_session().run(area_put, feed_dict={
    features_batch_next: x_train[:batch_size],
    labels_batch_next: y_train[:batch_size]})
print('size after first put():', K.get_session().run(area_size))

# model._make_train_function()
for i in range(5):
    print('batch:', i)
    # TODO: shift batches in x_train due to pipelining (skip first, add dummy last)
    # NOTE: when providing tensor input automatic batching in Keras is not used
    # model.train_function([x_train[:batch_size], y_train[:batch_size]])
    # NOTE(review): the same first batch is re-fed every iteration — fine for
    # a proof of concept, not for real training (see the TODO above).
    model.fit([x_train[:batch_size], y_train[:batch_size]], steps_per_epoch=1)
    print('size after epoch %d:' % i, K.get_session().run(area_size))

# Drain the one batch still staged so size reads 0 at the end.
K.get_session().run(area_clear)
print('size at the end:', K.get_session().run(area_size))
initial size: 0
size after first put(): 1
batch: 0
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.3538
size after epoch 0: 1
batch: 1
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.3102
size after epoch 1: 1
batch: 2
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.2686
size after epoch 2: 1
batch: 3
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.2286
size after epoch 3: 1
batch: 4
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.1901
size after epoch 4: 1
size at the end: 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment