A Python multiprocessing pool drop-in replacement for the PyTorch DataLoader class
'''
A multiprocessing Pool drop-in replacement for the PyTorch
DataLoader class. Built to work around an apparent bug in
the default PyTorch DataLoader, in which it hangs indefinitely.
It is possible to reach a sustained 95-100% GPU usage (as
reported by `nvidia-smi`) using this implementation.

Requirements:
    pip install filelock

John Meade 2019
MIT license
'''
import random
import multiprocessing as mp
# from multiprocessing.pool import ThreadPool
import traceback
from filelock import FileLock

# Worker errors are appended to this file; the file lock serializes
# writes from concurrent worker processes.
ERR_FN = 'errors_data_loader.txt'
ERR_LOCK = FileLock('/tmp/gaff_dataloader_worker.lock')
def batches(xs, size):
    'Chunk a list into batches of a given size'
    stop = len(xs)
    rng = list(range(0, stop, size)) + [None]
    return [ xs[a:b] for a, b in zip(rng[:-1], rng[1:]) ]
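# A quick worked example (hypothetical values, just to illustrate how the
# trailing None slice picks up the final short batch):
#   batches([0, 1, 2, 3, 4], 2)  ->  [[0, 1], [2, 3], [4]]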
def _worker(args):
    try:
        idxs, dataset, collate_fn, res_q = args
        x = collate_fn([ dataset[i] for i in idxs ])
        res_q.put(x)
    except KeyboardInterrupt:
        pass
    except Exception:
        # Workers run in separate processes, so log the traceback to a
        # shared file rather than letting the exception vanish silently.
        print('[ERROR] exception caught in data loader worker, logged in:', ERR_FN)
        with ERR_LOCK:
            with open(ERR_FN, 'a') as f:
                f.write('-' * 50 + '\n\n' + traceback.format_exc() + '\n')
class PoolDataLoader:
    def __init__(self, dataset, batch_size, collate_fn,
                 shuffle=True, num_workers=None):
        '''
        Simple multiprocessing Pool implementation of a dataloader.

        Args:
            num_workers: CPU count is used if this is None
        '''
        NW = num_workers or mp.cpu_count()
        ND = len(dataset)
        self.collate_fn = collate_fn
        self.batch_size = batch_size
        self.dataset = dataset
        # Bounding the queue at NW batches applies backpressure: workers
        # block on put() until the consumer takes a batch off the queue.
        self.res_q = mp.Manager().Queue(NW)
        self.pool = mp.Pool(NW)
        # self.pool = ThreadPool(NW)
        idxs = list(range(ND))
        if shuffle:
            random.shuffle(idxs)
        self.idx_batches = batches(idxs, batch_size)

    def __del__(self):
        self.pool.terminate()
        self.pool.join()

    def __len__(self):
        return len(self.idx_batches)

    def __iter__(self):
        return _PoolDataLoaderIter(self)
class _PoolDataLoaderIter:
    def __init__(self, pdl):
        args = [
            (ib, pdl.dataset, pdl.collate_fn, pdl.res_q)
            for ib in pdl.idx_batches
        ]
        # Dispatch all batches up front; results arrive on the shared
        # queue in completion order, not submission order.
        self.res = pdl.pool.map_async(_worker, args)
        self.pdl = pdl
        self._n = len(pdl)

    def __del__(self):
        del self.pdl

    def __iter__(self):
        return self

    def __next__(self):
        if self._n == 0:
            raise StopIteration()
        else:
            self._n -= 1
            # Note: if a worker died without enqueueing its batch (see the
            # error handling in _worker), this get() will block.
            return self.pdl.res_q.get()
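A minimal usage sketch (the dataset and collate function here are hypothetical stand-ins, not part of the gist). Any indexable object with a __len__ works, just as with the stock DataLoader; the __main__ guard matters on platforms that spawn rather than fork worker processes.

import torch

class RandomDataset:
    'Hypothetical dataset: 1000 random 16-dim feature vectors.'
    def __len__(self):
        return 1000
    def __getitem__(self, i):
        return torch.randn(16)

def stack_collate(samples):
    # Hypothetical collate_fn: stack per-sample tensors into one batch.
    return torch.stack(samples)

if __name__ == '__main__':
    loader = PoolDataLoader(RandomDataset(), batch_size=32,
                            collate_fn=stack_collate, shuffle=True)
    for batch in loader:
        # batch is a (32, 16) tensor (the final batch may be smaller);
        # feed it to the model / training step here.
        pass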