Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created October 28, 2017 21:06
Show Gist options
  • Save bzamecnik/12022c8dd50ec8eda0c2661830bbc8a4 to your computer and use it in GitHub Desktop.
Save bzamecnik/12022c8dd50ec8eda0c2661830bbc8a4 to your computer and use it in GitHub Desktop.
# GTX 980 Ti
# plain: 68.50 images/sec
# pipeline: 68.71 images/sec
import math
import tensorflow as tf
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input, Conv2D, MaxPooling2D, Dropout, Flatten
from keras.utils import to_categorical
import numpy as np
import keras.backend as K
from keras.callbacks import Callback
from keras.applications import ResNet50
# Seed NumPy's global RNG so the synthetic dataset is identical on every run.
np.random.seed(42)

# allow_soft_placement lets ops pinned to an unavailable device (e.g. '/gpu:0'
# on a CPU-only machine) run on an available one instead of failing.
_session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=_session_config)
# Register this session with Keras so K.get_session() returns it.
K.set_session(sess)
from keras.utils import to_categorical
def create_synth_dataset(image_size, class_count, dataset_size):
    """Generate a random image-classification dataset.

    Args:
        image_size: height and width of each square, 3-channel image.
        class_count: number of distinct classes.
        dataset_size: number of samples to generate.

    Returns:
        Tuple ``(X, y)``:
        ``X`` is float32 with shape ``(dataset_size, image_size, image_size, 3)``,
        uniform in [0, 1).
        ``y`` is a float32 one-hot matrix with shape ``(dataset_size, class_count)``.
    """
    images = np.random.rand(dataset_size, image_size, image_size, 3).astype('float32')
    # Labels are drawn after the images, keeping the global RNG sequence stable.
    label_ids = np.random.randint(low=0, high=class_count, size=dataset_size)
    # Row k of the identity matrix is the one-hot encoding of class k.
    one_hot = np.eye(class_count, dtype='float32')[label_ids]
    return images, one_hot
def create_synth_imagenet(image_size, dataset_size):
    """Generate a random dataset shaped like ImageNet classification (1000 classes).

    Args:
        image_size: side length of the square images; ImageNet models
            typically use 224 or 299.
        dataset_size: number of samples to generate.

    Returns:
        ``(X, y)`` as returned by ``create_synth_dataset``.
    """
    imagenet_class_count = 1000
    return create_synth_dataset(
        image_size=image_size,
        class_count=imagenet_class_count,
        dataset_size=dataset_size,
    )
def make_tensor_model(features_tensor, targets_tensor, extra_ops, num_classes):
    """Build and compile a ResNet50 whose inputs and targets are graph tensors.

    The model uses ``features_tensor`` directly as its input and
    ``targets_tensor`` as its training target, instead of being fed NumPy
    arrays.

    Args:
        features_tensor: tensor used as the model input.
        targets_tensor: tensor used as the training target.
        extra_ops: additional ops forwarded to ``compile`` as ``fetches``.
            They are intended to run together with each training step;
            this depends on the installed Keras supporting ``fetches``.
        num_classes: number of output classes.

    Returns:
        The compiled Keras model, created with ``weights=None``.
    """
    network = ResNet50(weights=None, input_tensor=features_tensor, classes=num_classes)
    network.compile(
        loss='categorical_crossentropy',
        optimizer='sgd',
        target_tensors=[targets_tensor],
        fetches=extra_ops,
    )
    return network
def make_plain_model(num_classes):
    """Build and compile a ResNet50 that is trained on in-memory arrays.

    Args:
        num_classes: number of output classes.

    Returns:
        The compiled Keras model (created with ``weights=None``; SGD optimizer,
        categorical cross-entropy loss).
    """
    network = ResNet50(weights=None, classes=num_classes)
    network.compile(loss='categorical_crossentropy', optimizer='sgd')
    return network
# Benchmark configuration.
num_classes = 1000
dataset_size = 1024
batch_size = 32
epochs = 5

x_train, y_train = create_synth_imagenet(224, dataset_size)

# Integer ceiling division: a trailing partial batch still counts as a step.
steps_per_epoch = -(-len(x_train) // batch_size)

# The leading None keeps the batch dimension dynamic, so a smaller last batch fits.
features_shape = (None, 224, 224, 3)
labels_shape = (None, num_classes)
with tf.device('/cpu:0'):
    # Host-side holders for the *next* batch. Loading data into them is a
    # separate session.run() (of assign_next_batch, with a feed_dict) from the
    # training step's session.run(). The training step only moves the
    # already-loaded values into the StagingArea.
    # https://www.tensorflow.org/api_guides/python/reading_data#Preloaded_data
    features_batch_next_value = tf.placeholder(dtype=tf.float32, shape=features_shape)
    # - trainable=False, collections=[]: keep the variable out of the graph's
    #   variable collections, so it is never treated as a model parameter.
    # - validate_shape=False: allow a different first dimension per batch
    #   (e.g. a smaller last batch).
    features_batch_next = tf.Variable(
        features_batch_next_value, trainable=False, collections=[], validate_shape=False)
    labels_batch_next_value = tf.placeholder(dtype=tf.float32, shape=labels_shape)
    labels_batch_next = tf.Variable(
        labels_batch_next_value, trainable=False, collections=[], validate_shape=False)
    # Re-running both initializers (with the placeholders fed) loads a new batch.
    assign_next_batch = tf.group(features_batch_next.initializer, labels_batch_next.initializer)

# The StagingArea is what prefetches batches to the GPU. Its put/get ops are
# colocated with the device active when it is constructed. Building it inside
# the '/cpu:0' scope would keep staged batches in host memory, and no prefetch
# to the GPU would happen. The session is created with allow_soft_placement=True,
# so on a machine without a GPU this still falls back to the CPU.
with tf.device('/gpu:0'):
    area = tf.contrib.staging.StagingArea(
        dtypes=[tf.float32, tf.float32],
        shapes=[features_shape, labels_shape])
    # Copies the batch currently held in the host-side variables into the area.
    area_put = area.put([features_batch_next.value(), labels_batch_next.value()])
    # Oldest staged (features, labels) pair; used as the model's input and target.
    area_get_features, area_get_labels = area.get()
    area_size = area.size()
    area_clear = area.clear()
class PrefetchCallback(Callback):
    """Keras callback that keeps the StagingArea ``prefetch_count`` batches ahead.

    - At the start of each epoch it loads and stages the first
      ``prefetch_count`` batches.
    - Before each training step it loads batch ``batch + prefetch_count`` into
      the host-side next-batch variables.
    - At the end of each epoch it clears the area.

    It assumes the model runs ``area_put`` as an extra fetch in every training
    step (as passed to ``make_tensor_model``). That fetch moves the loaded
    batch into the area while the step consumes the oldest staged one.

    Relies on the module-level graph objects ``assign_next_batch``,
    ``area_put``, ``area_clear``, ``features_batch_next_value`` and
    ``labels_batch_next_value``.
    """

    def __init__(self, x, y, batch_size, prefetch_count=1):
        """
        Args:
            x: training inputs, sliceable along the first axis.
            y: training targets aligned with ``x``.
            batch_size: number of samples per batch.
            prefetch_count: how many batches to keep staged ahead of training.
        """
        # Run the Keras base initializer so any base-class attributes it sets
        # (presumably e.g. validation_data / model; verify against the
        # installed Keras) exist on this instance.
        super().__init__()
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.prefetch_count = prefetch_count
        # Filled in by set_params() once fit() starts.
        self.steps_per_epoch = None

    def set_params(self, params):
        super().set_params(params)
        # .get(): a params dict without 'steps' no longer raises KeyError.
        self.steps_per_epoch = self.params.get('steps')

    def _slice_batch(self, i):
        """Return the ``i``-th ``(x, y)`` batch; the slices are empty past the data."""
        start = i * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]

    def _assign_batch(self, session, data):
        """Feed one ``(x, y)`` batch into the host-side next-batch variables."""
        x_batch, y_batch = data
        session.run(assign_next_batch, feed_dict={
            features_batch_next_value: x_batch,
            labels_batch_next_value: y_batch,
        })

    def on_epoch_begin(self, epoch, logs=None):
        # Prime the pipeline: stage the first ``prefetch_count`` batches.
        sess = K.get_session()
        for i in range(self.prefetch_count):
            self._assign_batch(sess, self._slice_batch(i))
            sess.run(area_put)

    def on_batch_begin(self, batch, logs=None):
        # Load the batch that this step's area_put will stage.
        # For the last ``prefetch_count`` steps the slice is empty. It acts as
        # a dummy entry that is staged but never read; on_epoch_end clears it.
        sess = K.get_session()
        data = self._slice_batch(batch + self.prefetch_count)
        self._assign_batch(sess, data)

    def on_epoch_end(self, epoch, logs=None):
        # Drop leftover (dummy) entries so the next epoch starts from an empty area.
        sess = K.get_session()
        sess.run(area_clear)
import time
from keras.callbacks import Callback
class SamplesPerSec(Callback):
    """Keras callback that reports training throughput in samples per second.

    Each batch is timed from ``on_batch_begin`` to ``on_batch_end`` and turned
    into ``batch_size / elapsed``. The median over all batches timed since
    training began is printed at the end of every epoch and of training.
    """

    def __init__(self, batch_size):
        """
        Args:
            batch_size: samples per batch, treated as constant. A smaller
                final batch is counted as a full one.
        """
        super().__init__()
        self.batch_size = batch_size
        self.all_samples_per_sec = []
        self.start_time = None

    def on_train_begin(self, logs=None):
        # Reset, so an instance reused across fit() calls reports each run separately.
        self.all_samples_per_sec = []

    def on_batch_begin(self, batch, logs=None):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # with clock adjustments and be too coarse for short intervals.
        self.start_time = time.perf_counter()

    def on_batch_end(self, batch, logs=None):
        elapsed_time = time.perf_counter() - self.start_time
        # Guard against a zero-length interval (would divide by zero).
        if elapsed_time > 0:
            self.all_samples_per_sec.append(self.batch_size / elapsed_time)

    def on_epoch_end(self, epoch, logs=None):
        self.print_results()

    def on_train_end(self, logs=None):
        self.print_results()

    def print_results(self):
        """Print the median throughput, or a notice if nothing was timed yet."""
        if not self.all_samples_per_sec:
            # np.median([]) would print 'nan' and emit a RuntimeWarning.
            print('Samples/sec: n/a (no batches timed)')
            return
        print('Samples/sec: %0.2f' % np.median(self.all_samples_per_sec))
print('training plain model:')
plain_model = make_plain_model(num_classes)
gauge = SamplesPerSec(batch_size)
plain_model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[gauge])

print('training pipelined model:')
# Inputs and targets come from the StagingArea. area_put is passed as an extra
# fetch so the next batch is intended to be staged during each training step.
pipelined_model = make_tensor_model(area_get_features, area_get_labels, [area_put], num_classes)
prefetch_callback = PrefetchCallback(x_train, y_train, batch_size)
# `gauge` goes first so its per-batch timer starts BEFORE
# PrefetchCallback.on_batch_begin feeds the next batch from host memory.
# That keeps the host-side feed inside the measured interval, as it is for the
# plain model, whose feed happens inside the timed training step.
pipelined_model.fit(
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=[gauge, prefetch_callback],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment