@bzamecnik
Last active October 18, 2017 11:54
Keras + TensorFlow StagingArea double-buffer
# This is a minimal example of using TensorFlow's StagingArea with Keras
# with the goal to implement double-buffering of input batches at GPU.
#
# Basically we want to have an input batch ready in GPU memory when batch
# computation starts and copy another batch in parallel. It should avoid
# waiting for host-device memcpy and allow better saturation of the GPU
# compute. StagingArea is a queue implementation that can have its buffer
# stored in GPU memory.
#
# https://www.tensorflow.org/api_docs/python/tf/contrib/staging/StagingArea
#
# In this basic example we just use a StagingArea of size 2 - one batch being
# copied, another being computed. Note that this example is not complete for
# GPU usage in the sense that the batch is just a single constant.
#
# Mechanism (see the plain-TensorFlow sketch below this header):
# - before an epoch we put one batch into the queue
# - at each batch we get a batch from the queue and compute on it, and in
#   parallel we put the next batch into the queue
# - in a real situation we should put a dummy batch at the end, since both
#   ops are baked together but the dataset has already been iterated
#
# As of version 2.0.8, Keras doesn't support passing additional operations
# to the fetches argument of tf.Session.run(). There exists a monkey patch
# by @avolkov1, and I'm working on a PR to pass other arguments of
# tf.Session.run(), including fetches and feed_dict.
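#
# Added illustration (not part of the original gist): a minimal plain-TensorFlow
# sketch of the put/get overlap described above, assuming TF 1.x where
# tf.contrib.staging is available. The "batch" here is just constants and the
# computation is a stand-in reduce_sum instead of a training step.
import tensorflow as tf

_sketch_area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32], shapes=[[2, 2], [2]])
_sketch_put = _sketch_area.put([tf.ones([2, 2]), tf.zeros([2])])  # host -> staging buffer
_sketch_x, _sketch_y = _sketch_area.get()                         # staging buffer -> compute
_sketch_compute = tf.reduce_sum(_sketch_x) + tf.reduce_sum(_sketch_y)

with tf.Session() as _sketch_sess:
    _sketch_sess.run(_sketch_put)  # pre-fill: stay one batch ahead
    for _ in range(4):
        # compute on the staged batch while the next one is being put
        _sketch_sess.run([_sketch_compute, _sketch_put])
    _sketch_sess.run(_sketch_compute)  # final step drains the leftover batch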
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import keras.backend as K
num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')
features_batch = tf.constant(x_train[:32])
labels_batch = tf.constant(y_train[:32])
area = tf.contrib.staging.StagingArea(dtypes=[tf.float32, tf.float32], shapes=[[32, 784], [32, 10]])
area_put = area.put([features_batch, labels_batch])
area_get_features, area_get_labels = area.get()
area_size = area.size()
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
# fetches=[area_put] makes each training step also enqueue the next batch
# (this requires the patched Keras mentioned in the header comment)
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[area_get_labels], fetches=[area_put])
print('initial size:', K.get_session().run(area_size))
# pre-fill the staging area with one batch before the first epoch
K.get_session().run(area_put)
print('size after first put():', K.get_session().run(area_size))
for i in range(5):
    print('i:', i)
    model.fit(steps_per_epoch=1)
    print('size after epoch %d:' % i, K.get_session().run(area_size))
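# Added illustration (not part of the original gist): after training, one batch
# is still staged, because every training step paired its get() with a put().
# Running the get() ops once more drains the buffer.
K.get_session().run([area_get_features, area_get_labels])
print('size after draining:', K.get_session().run(area_size))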
# An additional example where a feed_dict is passed via K.Function to
# tf.Session.run(). In practice it seems useless, since the value can only be
# defined at Model.compile() time :(( (see the note after this example).
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import keras.backend as K
import numpy as np
num_classes = 10
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')
features_batch = tf.constant(x_train[:32])
labels_batch = tf.constant(y_train[:32])
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[[32, 784], [32, 10]])
area_put = area.put([features_batch, labels_batch])
area_get_features, area_get_labels = area.get()
area_size = area.size()
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=digit)
# an unrelated placeholder/op pair, used only to demonstrate passing a
# feed_dict through to tf.Session.run()
a = tf.placeholder(tf.float32, shape=(1024, 1024))
b = tf.matmul(a, a)
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              target_tensors=[area_get_labels], fetches=[area_put, b],
              feed_dict={a: np.random.rand(1024, 1024)})
print('initial size:', K.get_session().run(area_size))
K.get_session().run(area_put)
print('size after first put():', K.get_session().run(area_size))
for i in range(5):
    print('i:', i)
    model.fit(steps_per_epoch=1)
    print('size after epoch %d:' % i, K.get_session().run(area_size))
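# Added note (not part of the original gist): because the feed value above is
# fixed when compile() builds the train function, a different value for `a` can
# only be supplied by running the op directly against the session:
print('direct matmul fetch shape:',
      K.get_session().run(b, feed_dict={a: np.random.rand(1024, 1024).astype('float32')}).shape)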