Skip to content

Instantly share code, notes, and snippets.

@erfannoury
Created January 17, 2017 19:03
Show Gist options
  • Save erfannoury/cc9f908cd4ae22ec7f9c74a035796edb to your computer and use it in GitHub Desktop.
A multi-threaded dataset iterator
from __future__ import print_function, division

# Py2/Py3 compatibility: the queue module was renamed in Python 3.
try:
    import Queue
except ImportError:
    import queue as Queue

import threading
import time

import numpy as np
class IteratorThread(threading.Thread):
    """
    A worker thread that repeatedly takes a group of items from
    `in_queue`, runs `process_batch` on them, and puts the resulting
    batch on `out_queue`.

    Parameters
    ----------
    in_queue : Queue.Queue
        Input queue containing the items that should be processed.
    out_queue : Queue.Queue
        Output queue that receives the processed batches.
    process_batch : callable
        Function mapping items taken from `in_queue` to a batch of data
        to be fed to the neural network. It typically performs the
        time-consuming preprocessing on a CPU thread while the GPU is
        busy training on the previous batch.
    """
    def __init__(self, in_queue, out_queue, process_batch):
        super(IteratorThread, self).__init__()
        self.in_queue = in_queue
        self.out_queue = out_queue
        self.process_batch = process_batch

    def run(self):
        """Worker loop: consume items, process them, publish batches."""
        while True:
            items = self.in_queue.get()
            try:
                batch = self.process_batch(items)
                self.out_queue.put(batch)
            except IOError:
                # Best-effort: skip items whose underlying data could not
                # be read; the iterator keeps serving the other batches.
                pass
            finally:
                # Always mark the work unit as done -- even when
                # process_batch failed -- so in_queue.join() can never
                # deadlock waiting on an item that was consumed but never
                # acknowledged (the original only acknowledged successes).
                self.in_queue.task_done()
class BaseDatasetIterator(object):
    """
    Generic multi-threaded iterator over a dataset.

    Spawns `nthreads` worker threads that pull raw items from per-thread
    input queues, run `process_batch` on them, and push the resulting
    batches into a shared output queue from which `next()` serves the
    training loop.

    Parameters
    ----------
    epoch_size : int
        Number of batches in an epoch; `StopIteration` is raised after
        this many batches have been consumed.
    item_generator : callable
        A zero-argument function that returns a fresh generator of items.
    nthreads : int (default: 1)
        Number of worker threads preparing batches.
    max_outq_size : int (default: 10)
        Maximum size of the output queue. A large number will consume
        more RAM, while a low number will keep the threads idle most of
        the time.
    inq_size : int (default: 10)
        Number of items to put into a particular input queue at a time.
    """
    def __init__(self, epoch_size, item_generator, nthreads=1,
                 max_outq_size=10, inq_size=10):
        if nthreads > max_outq_size:
            # Fixed message: original concatenated "the"+"maximum" with no
            # space and dropped the verb "be".
            raise ValueError("Number of threads should not be larger than "
                             "the maximum size of the output queue")
        self.nthreads = nthreads
        self.max_outq_size = max_outq_size
        self.inq_size = inq_size
        self.n_consumed = 0
        self.epoch_size = epoch_size
        self.in_queues = [Queue.Queue() for _ in range(self.nthreads)]
        self.out_queue = Queue.Queue(maxsize=self.max_outq_size)
        self.item_generator = item_generator
        self.init_threads()
        self.reset()

    def init_threads(self):
        """Create and start the daemon worker threads."""
        self.threads = [IteratorThread(self.in_queues[i], self.out_queue,
                                       self.process_batch)
                        for i in range(self.nthreads)]
        for th in self.threads:
            # Daemon threads: the interpreter may exit even though the
            # workers loop forever. (`th.daemon = True` replaces the
            # deprecated `setDaemon(True)`.)
            th.daemon = True
            th.start()

    def process_batch(self, items):
        """
        Turn a group of raw items into a batch of data that will be put
        in the output queue and fed to the network.

        Subclasses must override this. Note the base signature now takes
        `items`, matching how the worker threads invoke it; the original
        zero-argument placeholder would raise TypeError if ever called.
        """
        pass

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol delegates to the Py2-style next().
        return self.next()

    def next(self):
        """Yield the next batch of the iterator."""
        return self.step()

    def reset(self):
        """Reset the epoch counter and restart the item generator."""
        self.n_consumed = 0
        self.item_gen = self.item_generator()

    def step(self):
        """
        Return one batch from the output queue, first refilling any
        input queue that is running low on items.
        """
        if self.n_consumed >= self.epoch_size:
            self.reset()
            raise StopIteration("End of epoch")
        for inq in self.in_queues:
            if inq.qsize() <= (self.max_outq_size // self.nthreads):
                for _ in range(self.inq_size):
                    try:
                        # Use the builtin next(): works on Python 2 and 3
                        # (the original `.next()` method is Py2-only).
                        inq.put(next(self.item_gen))
                    except StopIteration:
                        # Generator exhausted for this epoch; stop trying
                        # to refill this queue.
                        break
        batch = self.out_queue.get()
        self.n_consumed += 1
        return batch
if __name__ == '__main__':
    class TestDataset(BaseDatasetIterator):
        """Toy dataset: random batches, doubled by a slow process_batch."""
        def __init__(self, batch_size=32, nthreads=2):
            self.batch_size = batch_size
            self.epoch_size = 100
            super(TestDataset, self).__init__(
                self.epoch_size, self.generate_item, nthreads=nthreads)

        def process_batch(self, x):
            # Simulate a costly CPU-side preprocessing step.
            time.sleep(np.random.random())
            return (x, 2*x)

        def generate_item(self):
            for _ in np.random.permutation(range(self.epoch_size)):
                yield np.random.random((self.batch_size, 128))

    td = TestDataset(nthreads=10)
    epochs = 2
    start_time = time.time()
    for epoch in range(epochs):
        for x, y in td:
            print(x.shape, y.shape)
        print('----------------------')
    end_time = time.time()
    print('Executing with {} threads took {} seconds'.format(
        td.nthreads, end_time - start_time))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment