Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Last active February 17, 2019 23:13
Show Gist options
  • Save bzamecnik/b9dbd50cdc195d54513cd2f9dfb7e21b to your computer and use it in GitHub Desktop.
Save bzamecnik/b9dbd50cdc195d54513cd2f9dfb7e21b to your computer and use it in GitHub Desktop.
# It works!
#
# GTX 980 Ti
# plain model: ~14370 images/sec
# prefetch model: ~14670 images/sec
#
# In nvprof we can see that the HtoD memcpy is really async!
# What remains is just the synchronous feed_dict that moves data from numpy to a CPU Variable.
import math
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input, Conv2D, MaxPooling2D, Dropout, Flatten
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback
# Seed NumPy's global RNG so the synthetic dataset is the same on every run.
np.random.seed(42)

# Create the TensorFlow session with soft placement enabled and hand it to
# Keras, so that Keras and the manual `session.run()` calls below share it.
session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=session_config)
K.set_session(sess)
# tf.reset_default_graph()
from keras.utils import to_categorical
def create_synth_dataset(image_size, class_count, dataset_size):
    """Generate a random image-classification dataset.

    Args:
        image_size: height and width of the square, 3-channel images.
        class_count: number of classes; labels are drawn from
            ``[0, class_count)``.
        dataset_size: number of samples to generate.

    Returns:
        A tuple ``(X, y)``. ``X`` is a float32 array of shape
        ``(dataset_size, image_size, image_size, 3)`` with uniform random
        values. ``y`` holds the random labels converted by `to_categorical`
        (one-hot over ``class_count`` classes) and cast to float32.
    """
    features = np.random.rand(
        dataset_size, image_size, image_size, 3).astype('float32')
    class_ids = np.random.randint(low=0, high=class_count, size=dataset_size)
    targets = to_categorical(class_ids, class_count).astype('float32')
    return features, targets
def create_synth_imagenet(image_size, dataset_size):
    """Generate a synthetic dataset with 1000 classes, as in ImageNet.

    Args:
        image_size: height and width of the images (typically 224 or 299).
        dataset_size: number of samples to generate.

    Returns:
        The ``(X, y)`` tuple produced by `create_synth_dataset`.
    """
    return create_synth_dataset(image_size, 1000, dataset_size)
def create_synth_cifar10(dataset_size):
    """Generate a synthetic dataset shaped like CIFAR-10.

    The images are 32x32 with 3 channels and there are 10 classes.

    Args:
        dataset_size: number of samples to generate.

    Returns:
        The ``(X, y)`` tuple produced by `create_synth_dataset`: float32
        features of shape ``(dataset_size, 32, 32, 3)`` and float32 labels
        encoded by `to_categorical` over 10 classes.
    """
    return create_synth_dataset(32, 10, dataset_size)
def make_convnet(input, num_classes=None):
    """Stack the benchmark convnet on `input` and return its output tensor.

    The network is two conv/conv/max-pool/dropout stages (32 and 64 filters)
    followed by a 512-unit dense layer and a softmax classifier.

    Args:
        input: Keras input tensor the network is built on.
        num_classes: number of units of the final softmax layer. When
            omitted, the module-level `num_classes` is used, which keeps
            existing one-argument calls working.

    Returns:
        The output tensor of the final softmax layer.
    """
    if num_classes is None:
        # Backward compatibility: this function used to read the global.
        num_classes = globals()['num_classes']
    x = Conv2D(32, (3, 3), padding='same', activation='relu')(input)
    x = Conv2D(32, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Conv2D(64, (3, 3), padding='same', activation='relu')(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(num_classes, activation='softmax')(x)
    return output


def _as_fetch_tensor(fetch):
    """Return `fetch` in a form that can be wrapped by `tf.identity`.

    Keras 2.2.4 passes every entry of `fetches` through `tf.identity`, which
    raises ``TypeError: Can't convert Operation ... to Tensor`` for a bare
    `tf.Operation` such as the one returned by `StagingArea.put()`. An
    operation is therefore replaced by a dummy constant that carries a
    control dependency on it, so fetching the constant still runs the
    operation. Anything else is returned unchanged.
    """
    if isinstance(fetch, tf.Operation):
        with tf.control_dependencies([fetch]):
            return tf.constant(0)
    return fetch


def make_tensor_model(features_tensor, targets_tensor, extra_ops, num_classes):
    """Build and compile the convnet fed directly from TensorFlow tensors.

    Args:
        features_tensor: tensor providing the input features of each step.
        targets_tensor: tensor providing the matching targets.
        extra_ops: operation/tensor, or a list of them, to run together with
            each training step (passed to Keras as `fetches`).
        num_classes: number of output classes.

    Returns:
        The compiled Keras `Model`.
    """
    if not isinstance(extra_ops, (list, tuple)):
        extra_ops = [extra_ops]
    fetches = [_as_fetch_tensor(op) for op in extra_ops]
    model_input = Input(tensor=features_tensor)
    model = Model(inputs=model_input,
                  outputs=make_convnet(model_input, num_classes))
    model.compile(optimizer='sgd', loss='categorical_crossentropy',
                  target_tensors=[targets_tensor], fetches=fetches)
    return model


def make_plain_model(input_shape, num_classes):
    """Build and compile the convnet fed the usual way, via placeholders.

    Args:
        input_shape: shape of a single input sample (without the batch axis).
        num_classes: number of output classes.

    Returns:
        The compiled Keras `Model`.
    """
    model_input = Input(shape=input_shape)
    model = Model(inputs=model_input,
                  outputs=make_convnet(model_input, num_classes))
    model.compile(optimizer='sgd', loss='categorical_crossentropy')
    return model
# Benchmark configuration.
num_classes = 10
dataset_size = 50000
batch_size = 2048
epochs = 5

x_train, y_train = create_synth_cifar10(dataset_size)

# Number of batches per epoch, rounded up because the last batch might be
# smaller. Pure integer ceiling division is used on purpose: under Python 2
# `len(x_train) / batch_size` already truncates, so wrapping it in
# `math.ceil` would silently drop the last partial batch.
steps_per_epoch = (len(x_train) + batch_size - 1) // batch_size

# `None` leaves the batch dimension dynamic.
features_shape = (None, 32, 32, 3)
labels_shape = (None, num_classes)
def _make_batch_slot(shape):
    """Create a float32 placeholder and a variable initialized from it.

    The variable is kept out of the model parameters (`trainable=False`,
    `collections=[]`) and may change shape between assignments
    (`validate_shape=False`), which is needed for the smaller last batch.

    Returns:
        A tuple ``(placeholder, variable)``.
    """
    value = tf.placeholder(dtype=tf.float32, shape=shape)
    slot = tf.Variable(value, trainable=False, collections=[],
                       validate_shape=False)
    return value, slot


# Variables holding the next batch to be put into the StagingArea. They
# decouple feeding the data (feed_dict) from the training `session.run()`:
# https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
# NOTE(review): the original indentation was lost; the device scope is assumed
# to end after `assign_next_batch` -- confirm.
with tf.device('/cpu:0'):
    features_batch_next_value, features_batch_next = _make_batch_slot(
        features_shape)
    labels_batch_next_value, labels_batch_next = _make_batch_slot(
        labels_shape)
    # Running the initializers (re)assigns the variables from the placeholders.
    assign_next_batch = tf.group(
        features_batch_next.initializer,
        labels_batch_next.initializer)
# Staging area used for prefetching batches to the GPU: it holds
# (features, labels) pairs between `put` and `get`.
area = tf.contrib.staging.StagingArea(
    dtypes=[tf.float32, tf.float32],
    shapes=[features_shape, labels_shape],
)

# put: stage the current contents of the batch variables.
_next_batch_values = [features_batch_next.value(), labels_batch_next.value()]
area_put = area.put(_next_batch_values)

# get: tensors yielding a previously staged batch; used as the model inputs.
area_get_features, area_get_labels = area.get()

# Ops for querying the number of staged batches and for clearing the area.
area_size = area.size()
area_clear = area.clear()
class PrefetchCallback(Callback):
    """Keras callback that loads upcoming batches for the StagingArea.

    Batches are sliced from `x` / `y` and written into the module-level
    batch variables by running `assign_next_batch`. At the beginning of each
    epoch the first `prefetch_count` batches are also staged via `area_put`;
    at the end of each epoch the staging area is emptied via `area_clear`.

    During training `on_batch_begin` only loads the variables; it does not run
    `area_put` itself. The pipelined model in this file is compiled with
    `area_put` among its fetches, which is presumably what stages the loaded
    batch on each training step.
    """

    def __init__(self, x, y, batch_size, prefetch_count=1):
        """Store the dataset and batching parameters.

        Args:
            x: features, sliceable along the first axis.
            y: targets, sliceable along the first axis.
            batch_size: number of samples per batch.
            prefetch_count: how many batches ahead of the training step the
                data is loaded.
        """
        # Let the Keras base class initialize its own attributes first.
        super(PrefetchCallback, self).__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.prefetch_count = prefetch_count

    def set_params(self, params):
        """Store the training params and remember the steps per epoch."""
        super(PrefetchCallback, self).set_params(params)
        self.steps_per_epoch = self.params['steps']

    def _slice_batch(self, i):
        """Return the ``(x, y)`` slices belonging to the `i`-th batch."""
        start = i * self.batch_size
        end = start + self.batch_size
        return (self.x[start:end], self.y[start:end])

    def _assign_batch(self, session, data):
        """Load the ``(x, y)`` batch `data` into the batch variables."""
        x_batch, y_batch = data
        session.run(assign_next_batch, feed_dict={
            features_batch_next_value: x_batch,
            labels_batch_next_value: y_batch})

    def on_epoch_begin(self, epoch, logs=None):
        """Stage the first `prefetch_count` batches of the epoch."""
        sess = K.get_session()
        for i in range(self.prefetch_count):
            self._assign_batch(sess, self._slice_batch(i))
            sess.run(area_put)

    def on_batch_begin(self, batch, logs=None):
        """Load the batch that lies `prefetch_count` steps ahead."""
        sess = K.get_session()
        # Slice for `prefetch_count` last batches is empty.
        # It serves as a dummy value which is put into StagingArea
        # but never read.
        data = self._slice_batch(batch + self.prefetch_count)
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        """Drop whatever is left in the staging area."""
        sess = K.get_session()
        sess.run(area_clear)
import time
from keras.callbacks import Callback
class SamplesPerSec(Callback):
    """Keras callback measuring training throughput in samples per second.

    Each batch is timed with wall-clock time between `on_batch_begin` and
    `on_batch_end`. The median over all batches seen since the start of the
    current training run is printed after every epoch and at the end of
    training.

    Every batch is assumed to contain `batch_size` samples, so a smaller last
    batch is over-estimated.
    """

    def __init__(self, batch_size):
        """Remember the nominal number of samples per batch."""
        # Let the Keras base class initialize its own attributes first.
        super(SamplesPerSec, self).__init__()
        self.batch_size = batch_size
        # Defined here as well so the attribute exists before training starts.
        self.all_samples_per_sec = []

    def on_train_begin(self, logs=None):
        """Reset the measurements, so the instance can be reused across fits."""
        self.all_samples_per_sec = []

    def on_batch_begin(self, batch, logs=None):
        """Record the start time of the batch."""
        self.start_time = time.time()
        # self.batch_size = logs['size']

    def on_batch_end(self, batch, logs=None):
        """Record the throughput of the batch that has just finished."""
        end_time = time.time()
        elapsed_time = end_time - self.start_time
        # A coarse clock may report no elapsed time at all; skip such a batch
        # instead of dividing by zero.
        if elapsed_time > 0:
            samples_per_sec = self.batch_size / elapsed_time
            self.all_samples_per_sec.append(samples_per_sec)

    def on_epoch_end(self, epoch, logs=None):
        """Print the median throughput measured so far."""
        self.print_results()

    def on_train_end(self, logs=None):
        """Print the final median throughput."""
        self.print_results()

    def print_results(self):
        """Print the median of all recorded samples/sec values."""
        print('Samples/sec: %0.2f' % np.median(self.all_samples_per_sec))
# Baseline: an ordinary Keras model fed from NumPy arrays.
print('training plain model:')
gauge = SamplesPerSec(batch_size)
plain_model = make_plain_model(x_train.shape[1:], num_classes)
plain_model.fit(
    x_train, y_train,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[gauge])

# Pipelined: the model reads its inputs from the StagingArea and runs
# `area_put` along with every training step; the callback loads the batches.
print('training pipelined model:')
pipelined_model = make_tensor_model(
    area_get_features, area_get_labels, [area_put], num_classes)
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
pipelined_callbacks = [prefetch_callback, gauge]
pipelined_model.fit(
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=pipelined_callbacks)
@bzamecnik
Copy link
Author

@bzamecnik
Copy link
Author

In TF 1.12 / Keras 2.2.4 it fails with:

training pipelined model:
Traceback (most recent call last):
  File "keras_staging_area_cifar10_convnet.py", line 188, in <module>
    pipelined_model.fit(steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[prefetch_callback, gauge])
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/keras/engine/training.py", line 1010, in fit
    self._make_train_function()
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/keras/engine/training.py", line 519, in _make_train_function
    **self._function_kwargs)
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2744, in function
    return Function(inputs, outputs, updates=updates, **kwargs)
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2567, in __init__
    self.fetches = [tf.identity(x) for x in self.fetches]
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 81, in identity
    return gen_array_ops.identity(input, name=name)
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3454, in identity
    "Identity", input=input, name=name)
  File "/home/bzamecnik/.virtualenvs/gpu_prefetch/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 513, in _apply_op_helper
    raise err
TypeError: Can't convert Operation 'StagingArea_put' to Tensor (target dtype=None, name=u'input', as_ref=False)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment