A rudimentary proof-of-concept of pipelining with TF StagingArea in Keras.
# Works! In this snippet we're able to get a batch from a StagingArea
# and, in parallel, put another batch there, which is loaded via
# feed_dict (provided as a tf.placeholder wrapped as a Keras Input).
#
# So far it doesn't handle splitting batches, doesn't handle the borders
# of the pipeline correctly, and can't switch between the training and
# validation set (see the batch-shifting sketch after the output).
#
# But at least it's a proof of concept that it's possible to use
# StagingArea with Keras.
#
# Needs TF 1.3 and Keras 2.0.8 patched to pass extra fetches to
# tf.Session.run() (a standalone sketch of the underlying get/put
# primitive follows the script):
# https://github.com/bzamecnik/keras/commit/8593309c371ce716fd039e33ed5ae4079096ee0f
#
# More info about this topic: https://github.com/avolkov1/keras_experiments/issues/2

import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import keras.backend as K

# Drop any cached session so that a fresh one is created for the new graph.
K.set_session(None)
tf.reset_default_graph()

num_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
y_train = to_categorical(y_train, num_classes).astype('float32')

batch_size = 32
features_shape = (batch_size, 784)
labels_shape = (batch_size, num_classes)

# placeholders for feeding inputs to the StagingArea
features_batch_next = Input(shape=(784,))
labels_batch_next = Input(shape=(num_classes,))

# staging area for prefetching batches to the GPU
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[features_shape, labels_shape])
area_put = area.put([features_batch_next, labels_batch_next])
area_get_features, area_get_labels = area.get()
area_size = area.size()
area_clear = area.clear()

# the model reads its inputs directly from the StagingArea output tensors
image = Input(tensor=area_get_features)
x = Dense(512, activation='relu')(image)
digit = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=[image, features_batch_next, labels_batch_next], outputs=digit)
# fetches=[area_put] makes each training step also stage the next batch
# (this kwarg only exists in the patched Keras linked above)
model.compile(optimizer='sgd', loss='categorical_crossentropy',
    target_tensors=[area_get_labels], fetches=[area_put])

print('initial size:', K.get_session().run(area_size))

# prime the pipeline: stage the first batch before any training step runs
K.get_session().run(area_put, feed_dict={
    features_batch_next: x_train[:batch_size],
    labels_batch_next: y_train[:batch_size]})

print('size after first put():', K.get_session().run(area_size))

# model._make_train_function()

for i in range(5):
    print('batch:', i)
    # TODO: shift batches in x_train due to pipelining (skip the first,
    # add a dummy last one) -- see the sketch after the output
    # NOTE: when providing tensor inputs, Keras' automatic batching is not used
    # model.train_function([x_train[:batch_size], y_train[:batch_size]])
    model.fit([x_train[:batch_size], y_train[:batch_size]], steps_per_epoch=1)
    print('size after epoch %d:' % i, K.get_session().run(area_size))

K.get_session().run(area_clear)
print('size at the end:', K.get_session().run(area_size))
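
For reference, the primitive the snippet relies on can be demonstrated without Keras: a single session.run() both consumes the currently staged pair of tensors and stages the next pair fed via feed_dict. A minimal sketch, assuming plain TF 1.x; the shapes and values are made up for illustration:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(2,))
y = tf.placeholder(tf.float32, shape=(2,))
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32], shapes=[(2,), (2,)])
put = area.put([x, y])
got_x, got_y = area.get()
step = got_x + got_y  # stand-in for a training step

with tf.Session() as sess:
    # prime: stage the first pair before any step runs
    sess.run(put, feed_dict={x: [1., 2.], y: [3., 4.]})
    # one step: compute on the staged pair while staging the next one
    out, _ = sess.run([step, put], feed_dict={x: [5., 6.], y: [7., 8.]})
    print(out)  # -> [4. 6.], computed from the previously staged pair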
Output of the script above:
initial size: 0
size after first put(): 1
batch: 0
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.3538
size after epoch 0: 1
batch: 1
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.3102
size after epoch 1: 1
batch: 2
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.2686
size after epoch 2: 1
batch: 3
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.2286
size after epoch 3: 1
batch: 4
Epoch 1/1
1/1 [==============================] - ETA: 0s - loss: 2.1901
size after epoch 4: 1
size at the end: 0
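
The batch-shifting TODO could look roughly like the following; a minimal sketch, assuming the variables from the script are still in scope and that fetches=[area_put] keeps staging whatever batch is fed to each fit() call:

def batch(a, j):
    return a[j * batch_size:(j + 1) * batch_size]

num_batches = 5
# prime the pipeline: stage batch 0 before the first training step
K.get_session().run(area_put, feed_dict={
    features_batch_next: batch(x_train, 0),
    labels_batch_next: batch(y_train, 0)})
for i in range(num_batches):
    # the model trains on staged batch i while batch i+1 is being staged;
    # the last iteration stages a throwaway batch that is never consumed
    j = min(i + 1, num_batches - 1)
    model.fit([batch(x_train, j), batch(y_train, j)], steps_per_epoch=1)
# drop the extra staged batch left over at the pipeline border
K.get_session().run(area_clear)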