Skip to content

Instantly share code, notes, and snippets.

@johnmeade
Created July 13, 2019 19:50
Show Gist options
  • Save johnmeade/5dfbc50cc825a10f046219db071cdcac to your computer and use it in GitHub Desktop.
Save johnmeade/5dfbc50cc825a10f046219db071cdcac to your computer and use it in GitHub Desktop.
A Python multiprocessing pool drop-in replacement for the PyTorch DataLoader class
'''
A multiprocessing Pool drop-in replacement for the pytorch
DataLoader class. Built to work around an apparent bug in
the default pytorch DataLoader, in which it hangs indefinitely.
It is possible to reach a sustained 95-100% GPU usage (as
reported by `nvidia-smi`) using this implementation.
Requirements:
pip install filelock
John Meade 2019
MIT license
'''
import random
import multiprocessing as mp
# from multiprocessing.pool import ThreadPool
import traceback
from filelock import FileLock
ERR_FN = 'errors_data_loader.txt'
ERR_LOCK = FileLock('/tmp/gaff_dataloader_worker.lock')
def batches(xs, size):
    'Chunk a list into batches of a given size'
    # Slice at each stride offset; the final slice naturally absorbs any
    # short remainder, and an empty input yields an empty list of batches.
    return [xs[start:start + size] for start in range(0, len(xs), size)]
def _worker(args):
try:
idxs, dataset, collate_fn, res_q = args
x = collate_fn([ dataset[i] for i in idxs ])
res_q.put(x)
except KeyboardInterrupt:
pass
except:
print('[ERROR] exception caught in data loader worker, logged in:', ERR_FN)
with ERR_LOCK:
with open(ERR_FN, 'a') as f:
f.write('-'*50 + '\n\n' + traceback.format_exc() + '\n')
class PoolDataLoader:
    """Simple multiprocessing Pool implementation of a dataloader.

    Iterate it to receive collated batches produced by a pool of worker
    processes.

    Args:
        dataset: indexable collection of samples (supports len() and [i]).
        batch_size: number of samples per emitted batch.
        collate_fn: merges a list of samples into one batch object.
        shuffle: shuffle sample order at construction time.
            NOTE(review): unlike the stock DataLoader, the order is shuffled
            once here, not per epoch — every __iter__ replays the same
            batch order.
        num_workers: pool size; cpu count is used if this is None.
    """

    def __init__(self, dataset, batch_size, collate_fn,
                 shuffle=True, num_workers=None):
        NW = num_workers or mp.cpu_count()
        ND = len(dataset)
        self.collate_fn = collate_fn
        self.batch_size = batch_size
        self.dataset = dataset
        # Bounded managed queue: workers block once NW results are pending,
        # throttling producers to roughly match consumption speed.
        self.res_q = mp.Manager().Queue(NW)
        self.pool = mp.Pool(NW)
        idxs = list(range(ND))
        if shuffle:
            random.shuffle(idxs)
        self.idx_batches = batches(idxs, batch_size)

    def __del__(self):
        # Guard with getattr: if __init__ raised before `pool` was assigned
        # (e.g. mp.Manager() failed), __del__ still runs and must not raise
        # a secondary AttributeError.
        pool = getattr(self, 'pool', None)
        if pool is not None:
            pool.terminate()
            pool.join()

    def __len__(self):
        # Number of batches in one epoch.
        return len(self.idx_batches)

    def __iter__(self):
        return _PoolDataLoaderIter(self)
class _PoolDataLoaderIter:
    """One pass over a PoolDataLoader: submits every batch job up front and
    yields collated batches in the order workers finish them."""

    def __init__(self, pdl):
        # Submit all batch jobs immediately; the loader's bounded result
        # queue provides the backpressure.
        job_args = [
            (batch_idxs, pdl.dataset, pdl.collate_fn, pdl.res_q)
            for batch_idxs in pdl.idx_batches
        ]
        self.res = pdl.pool.map_async(_worker, job_args)
        self.pdl = pdl
        self._n = len(pdl)

    def __del__(self):
        # Drop the back-reference to the loader when the iterator dies.
        del self.pdl

    def __iter__(self):
        return self

    def __next__(self):
        # Guard clause: stop once every submitted batch has been consumed.
        if self._n == 0:
            raise StopIteration()
        self._n -= 1
        return self.pdl.res_q.get()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment