Quick way to multiprocess batch iteration for training deep learning models. Hacked together from Daniel Nouri's batch iterator https://github.com/dnouri/nolearn/blob/master/nolearn/lasagne/base.py#L70 and Jan Schülter's example https://github.com/Lasagne/Lasagne/issues/12#issuecomment-59494251
import multiprocessing as mp
import numpy as np


class threaded_batch_iter(object):
    '''
    Batch iterator to make transformations on the data.
    Uses multiprocessing so that batches can be created on the CPU while the GPU runs the previous batch.
    '''
    def __init__(self, batchsize):
        self.batchsize = batchsize

    def __call__(self, X, y):
        self.X, self.y = X, y
        return self

    def __iter__(self):
        '''
        Multiprocess the iteration so that the GPU does not have to wait for the CPU to prepare data.
        Runs the _gen_batches generator in a separate process so batches are built while the GPU works on the previous one.
        '''
        q = mp.Queue(maxsize=128)

        def _gen_batches():
            num_samples = len(self.X)
            idx = np.random.permutation(num_samples)
            batches = range(0, num_samples - self.batchsize + 1, self.batchsize)
            for batch in batches:
                X_batch = self.X[idx[batch:batch + self.batchsize]]
                y_batch = self.y[idx[batch:batch + self.batchsize]]
                # do some stuff to the batches like augment images or load from folders
                yield [X_batch, y_batch]

        def _producer(_gen_batches):
            # instantiate the batch generator as a plain Python generator
            batch_gen = _gen_batches()
            # loop over the generator and put each batch into the queue
            for data in batch_gen:
                q.put(data, block=True)
            # once the generator is exhausted, enqueue a None sentinel and close the queue
            q.put(None)
            q.close()

        # start the producer in a separate process; daemonize it so it quits cleanly on ctrl-c
        thread = mp.Process(target=_producer, args=[_gen_batches])
        thread.daemon = True
        thread.start()

        # grab each successive [X_batch, y_batch] list that the producer added to the queue
        for data in iter(q.get, None):
            yield data[0], data[1]

# ==================================================================================================================
# to use it, do the following when looping over epochs
batch_iter = threaded_batch_iter(batchsize=128)
for epoch in range(epochs):
    # ...
    for X_batch, y_batch in batch_iter(X_train, y_train):
        # ...
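The "do some stuff to the batches" comment inside _gen_batches marks where per-batch CPU work goes, which is exactly the work this iterator overlaps with GPU compute. A minimal sketch of what could live there, assuming X_batch holds NHWC image arrays (the augment helper and the flip choice are illustrative, not part of the original gist):

import numpy as np

def augment(X_batch):
    # hypothetical augmentation for the marked spot in _gen_batches:
    # mirror the whole batch horizontally half the time (assumes NHWC layout)
    if np.random.rand() < 0.5:
        X_batch = X_batch[:, :, ::-1, :]  # reverse the width axis
    return X_batch

Inside the loop this would be called as X_batch = augment(X_batch) just before the yield.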
Also, is it possible to modify this so it can be used with Keras? Passing it to fit fails with:

ValueError: Failed to find data adapter that can handle input: <class 'mp_batch_generator.threaded_batch_iter'>, <class 'NoneType'>

Is there a workaround, or would I have to go under the hood to get this to work?
def multiprocessing_batch_generator(directory, direction):
    batch_iter = mp_batch_generator.threaded_batch_iter(batchsize=128, directory=directory, direction=direction)
    for X_batch, y_batch in batch_iter:
        yield X_batch, y_batch
Tried it like this, but I get AttributeError: Can't pickle local object 'threaded_batch_iter.__iter__.<locals>._producer'
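That pickling error is expected: under multiprocessing's 'spawn' start method (the default on Windows, and on macOS since Python 3.8), the Process target and its arguments are pickled, and functions defined inside another function, like _producer and _gen_batches here, cannot be. A sketch of one workaround, assuming X and y are picklable arrays: make batch generation a method and move the producer to module level, so only picklable objects cross the process boundary.

import multiprocessing as mp
import numpy as np

def _producer(it, q):
    # module-level function, so it pickles cleanly under 'spawn'
    for data in it._gen_batches():
        q.put(data, block=True)
    q.put(None)
    q.close()

class threaded_batch_iter(object):
    def __init__(self, batchsize):
        self.batchsize = batchsize

    def __call__(self, X, y):
        self.X, self.y = X, y
        return self

    def _gen_batches(self):
        # a method rather than a closure, so pickling the instance carries the data along
        num_samples = len(self.X)
        idx = np.random.permutation(num_samples)
        for start in range(0, num_samples - self.batchsize + 1, self.batchsize):
            yield [self.X[idx[start:start + self.batchsize]],
                   self.y[idx[start:start + self.batchsize]]]

    def __iter__(self):
        q = mp.Queue(maxsize=128)
        proc = mp.Process(target=_producer, args=(self, q))
        proc.daemon = True
        proc.start()
        for data in iter(q.get, None):
            yield data[0], data[1]

With that change, wrapping the iterator in a plain generator function, as attempted above, should be enough for tf.keras, since Model.fit accepts Python generators.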
Please note that if the last batch is smaller than the batchsize, that batch won't be processed. This can be fixed by changing the loop header to for i, batch in enumerate(batches): (the original for batch in batches: never defines i) and adding the block below after:

yield [X_batch, y_batch]

if i == len(batches) - 1:
    X_batch = self.X[idx[batch + self.batchsize:]]
    y_batch = self.y[idx[batch + self.batchsize:]]
    ....
    yield [X_batch, y_batch]
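Alternatively, a drop-in sketch for the _gen_batches closure that avoids the special case altogether: stepping over every start offset lets the final slice come up short on its own, and it never emits the empty trailing batch the fix above would yield when num_samples divides evenly by the batchsize.

def _gen_batches():
    num_samples = len(self.X)
    idx = np.random.permutation(num_samples)
    # range covers every start offset; the final slice simply comes up short,
    # so the leftover samples form a smaller last batch with no special case
    for start in range(0, num_samples, self.batchsize):
        yield [self.X[idx[start:start + self.batchsize]],
               self.y[idx[start:start + self.batchsize]]]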
Line 45: args=[_gen_batches]. What gives?