Created
January 17, 2017 19:03
-
-
Save erfannoury/cc9f908cd4ae22ec7f9c74a035796edb to your computer and use it in GitHub Desktop.
A multi-threaded dataset iterator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import print_function, division | |
| try: | |
| import Queue | |
| except ImportError: | |
| import queue as Queue | |
| import threading | |
| import time | |
| import numpy as np | |
class IteratorThread(threading.Thread):
    """
    A worker thread that repeatedly takes an item off `in_queue`,
    runs `process_batch` on it, and puts the resulting batch onto
    `out_queue`.

    Parameters
    ----------
    in_queue: Queue
        the input queue that contains the items that should be
        processed
    out_queue: Queue
        the output queue that will contain processed items
    process_batch: function
        callable that turns one item from `in_queue` into a batch of
        data to be fed to the neural network. This is typically the
        time-consuming step, so it runs here on a separate CPU thread
        while the GPU is busy training on the previous batch.
    """

    def __init__(self, in_queue, out_queue, process_batch):
        super(IteratorThread, self).__init__()
        self.in_queue = in_queue
        self.out_queue = out_queue
        self.process_batch = process_batch

    def run(self):
        """
        Consume `in_queue` forever; this is the function executed in
        the separate thread.
        """
        while True:
            items = self.in_queue.get()
            try:
                batch = self.process_batch(items)
            except IOError:
                # Best-effort: skip items whose underlying data could not
                # be read instead of killing the worker thread. No batch
                # is emitted for such items.
                pass
            else:
                self.out_queue.put(batch)
            finally:
                # Always mark the item as handled, even on failure, so the
                # queue's unfinished-task count stays correct and
                # in_queue.join() can never dead-lock on a failed item.
                # (The original only called task_done() on success.)
                self.in_queue.task_done()
class BaseDatasetIterator(object):
    """
    Generic multi-threaded iterator class for a dataset.

    Spawns `nthreads` worker threads, each pulling raw items from a
    private input queue, running `process_batch` on them, and pushing
    finished batches onto a shared output queue. Iterating over this
    object pops batches from that output queue.

    Parameters
    ----------
    epoch_size: int
        number of batches in an epoch
    item_generator: function
        a zero-argument function that returns a fresh generator of items
    nthreads: int (default: 1)
        number of threads for preparing batches
    max_outq_size: int (default: 10)
        maximum size of the out_queue. A large number will consume more
        RAM, while a low number will keep the threads idle most of the
        time.
    inq_size: int (default: 10)
        number of items to put into a particular input queue at a time
    """

    def __init__(self, epoch_size, item_generator, nthreads=1,
                 max_outq_size=10, inq_size=10):
        if nthreads > max_outq_size:
            # Original message concatenated to "...than themaximum..."
            # (missing "be" and a space); fixed here.
            raise ValueError("Number of threads should not be larger than"
                             " the maximum size of the output queue")
        self.nthreads = nthreads
        self.max_outq_size = max_outq_size
        self.inq_size = inq_size
        self.n_consumed = 0
        self.epoch_size = epoch_size
        # One private input queue per worker; a single shared, bounded
        # output queue provides back-pressure on the workers.
        self.in_queues = [Queue.Queue() for _ in range(self.nthreads)]
        self.out_queue = Queue.Queue(maxsize=self.max_outq_size)
        self.item_generator = item_generator
        self.init_threads()
        self.reset()

    def init_threads(self):
        """
        Create and start the daemon worker threads.
        """
        self.threads = [IteratorThread(self.in_queues[i], self.out_queue,
                                       self.process_batch)
                        for i in range(self.nthreads)]
        for th in self.threads:
            # `daemon = True` replaces the deprecated setDaemon(True);
            # daemon workers never keep the interpreter alive on exit.
            th.daemon = True
            th.start()

    def process_batch(self, items):
        """
        Given a batch of items, process them and return a batch of data
        that will be put in the `out_queue` to be fed to the network.

        Subclasses must override this. Note it is invoked on the worker
        threads with one argument (the item taken from an input queue),
        so the override must accept it — the original base signature
        took no argument and would have raised TypeError if reached.
        """
        raise NotImplementedError('Subclasses must implement process_batch')

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2 name.
        return self.next()

    def next(self):
        """
        Yield the next item of the iterator (Python 2 protocol).
        """
        return self.step()

    def reset(self):
        """
        Reset the iterator: zero the epoch counter and restart the
        item generator.
        """
        self.n_consumed = 0
        self.item_gen = self.item_generator()

    def step(self):
        """
        Return one processed batch from the `out_queue`, first topping
        up any worker input queue that is running low.

        Raises
        ------
        StopIteration
            once `epoch_size` batches have been consumed in this epoch;
            the iterator is reset beforehand so a new epoch can start
            immediately.
        """
        if self.n_consumed >= self.epoch_size:
            self.reset()
            raise StopIteration("End of epoch")
        for inq in self.in_queues:
            if inq.qsize() <= (self.max_outq_size // self.nthreads):
                for _ in range(self.inq_size):
                    try:
                        # Built-in next() works on Python 2 and 3; the
                        # original `self.item_gen.next()` is Py2-only and
                        # crashes with AttributeError on Python 3.
                        inq.put(next(self.item_gen))
                    except StopIteration:
                        # Generator exhausted before the epoch counter:
                        # stop refilling and let the epoch drain.
                        break
        batch = self.out_queue.get()
        self.n_consumed += 1
        return batch
if __name__ == '__main__':
    class TestDataset(BaseDatasetIterator):
        """Toy dataset that fabricates random batches, for benchmarking."""

        def __init__(self, batch_size=32, nthreads=2):
            self.batch_size = batch_size
            self.epoch_size = 100
            super(TestDataset, self).__init__(
                self.epoch_size, self.generate_item, nthreads=nthreads)

        def process_batch(self, x):
            # Emulate an expensive preprocessing step with a random nap.
            time.sleep(np.random.random())
            return (x, 2 * x)

        def generate_item(self):
            # Yield one random (batch_size, 128) array per batch of the epoch.
            for _ in np.random.permutation(range(self.epoch_size)):
                yield np.random.random((self.batch_size, 128))

    td = TestDataset(nthreads=10)
    epochs = 2
    t_start = time.time()
    for _ in range(epochs):
        for x, y in td:
            print(x.shape, y.shape)
        print('----------------------')
    t_end = time.time()
    print('Executing with {} threads took {} seconds'.format(
        td.nthreads, t_end - t_start))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment