Internal multi-worker iterator for DataLoader.
在 num_workers
大于 0 的时候,DataLoader 实际用的是下面的代码来处理数据
# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None,
data_loader=self)
又因为 self._thread_pool
是 False
,所以其实调用的是
# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_worker_fn, prefetch=self._prefetch, dataset=None,
data_loader=self)
这里的 _worker_fn
是 mxnet/gluon/data/dataloader.py
中定义的函数,具体如下:
def _worker_fn(samples, batchify_fn, dataset=None):
    """Function for processing data in worker process.

    ``samples`` is one batch worth of dataset indices; each index is looked
    up in the module-global ``_worker_dataset``, the items are collated with
    ``batchify_fn``, and the resulting batch is returned as pickled bytes so
    it can be sent back across the process boundary.
    Note: the ``dataset`` argument is unused here (hence the pylint disable);
    the worker reads the global instead.
    """
    # pylint: disable=unused-argument
    # it is required that each worker process has to fork a new MXIndexedRecordIO handle
    # preserving dataset as global variable can save tons of overhead and is safe in new process
    global _worker_dataset
    # Fetch every sample of the batch from the global dataset, then collate.
    batch = batchify_fn([_worker_dataset[i] for i in samples])
    # Serialize with ForkingPickler (shared-memory aware) into an in-memory buffer.
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    return buf.getvalue()
multiprocessing.Pool
类是什么?进程池。Pool 类可以提供指定数量的进程供用户调用,当有新的请求提交到 Pool 中时,如果池还没有满,就会创建一个新的进程来执行请求。如果池满,请求就会告知先等待,直到池中有进程结束,才会创建新的进程来执行这些请求。
self._batch_sampler = batch_sampler
batch_sampler 是一个 batch size 长的 Dataset 样本的 index 的 list
self._iter = iter(self._batch_sampler)
因此,self._iter
是一个 Iterator,每次迭代产出一个 batch 的 index list
r = next(self._iter, None)
因此,r 是某一个 batch 的 Dataset 样本 index 组成的 list
async_ret = self._worker_pool.apply_async(
self._worker_fn, (r, self._batchify_fn, self._dataset))
Pool
类里面的 apply_async
方法,该函数用于传递不定参数,是非阻塞的,并支持在结果返回时执行回调函数,具体的代码如下
def apply_async(self, func, args=(), kwds={}, callback=None,
                error_callback=None):
    '''
    Asynchronous version of `apply()` method.
    '''
    # Reject new work once the pool has been closed/terminated.
    if self._state != RUN:
        raise ValueError("Pool not running")
    # ApplyResult is the future-like handle handed back to the caller;
    # `callback`/`error_callback` fire when the result (or error) arrives.
    result = ApplyResult(self._cache, callback, error_callback)
    # Enqueue a single task for the pool's workers and return immediately
    # (non-blocking) — the caller waits on `result` if it needs the value.
    self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
    return result
-
self._taskqueue = queue.Queue()
是一个队列 -
self._worker_fn = worker_fn
,如果