_MultiWorkerIter

Internal multi-worker iterator for DataLoader.

When `num_workers` is greater than 0, `DataLoader` actually uses the following code to process data:

        # multi-worker
        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
                                prefetch=self._prefetch,
                                dataset=self._dataset if self._thread_pool else None,
                                data_loader=self)
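
For context, a minimal sketch of a `DataLoader` that actually takes this multi-worker path (the toy `ArrayDataset` is just for illustration):

    import mxnet as mx
    from mxnet.gluon.data import ArrayDataset, DataLoader

    if __name__ == '__main__':
        # A toy dataset; any gluon Dataset behaves the same way here.
        dataset = ArrayDataset(mx.nd.arange(100).reshape(100, 1))

        # num_workers > 0 is what makes __iter__ return the _MultiWorkerIter above.
        loader = DataLoader(dataset, batch_size=8, num_workers=2)

        for batch in loader:
            print(batch.shape)  # (8, 1) for full batches
            break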

And since `self._thread_pool` is `False`, what actually gets called is:

        # multi-worker
        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                worker_fn=_worker_fn, prefetch=self._prefetch, dataset=None,
                                data_loader=self)

The `_worker_fn` here is a function defined in `mxnet/gluon/data/dataloader.py`:

def _worker_fn(samples, batchify_fn, dataset=None):
    """Function for processing data in worker process."""
    # pylint: disable=unused-argument
    # it is required that each worker process has to fork a new MXIndexedRecordIO handle
    # preserving dataset as global variable can save tons of overhead and is safe in new process
    global _worker_dataset
    batch = batchify_fn([_worker_dataset[i] for i in samples])
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    return buf.getvalue()
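
`_worker_fn` serializes the batch with `ForkingPickler` into a `BytesIO` buffer and returns the raw bytes; the receiving side then recovers the batch by unpickling those bytes. A minimal sketch of the same round trip, using a plain list as a stand-in for the batchified NDArray batch:

    import io
    import pickle
    from multiprocessing.reduction import ForkingPickler

    batch = [1, 2, 3]  # stand-in for the real batch

    # Serialize the way _worker_fn does: dump into a BytesIO buffer.
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    payload = buf.getvalue()

    # The parent process recovers the batch from the raw bytes.
    restored = pickle.loads(payload)
    assert restored == batch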

What is the `multiprocessing.Pool` class? A process pool. `Pool` maintains a specified number of worker processes for the caller to use. When a new task is submitted to the pool, it is handed to a free worker process; if all workers are busy, the task waits in the pool's queue until a worker finishes its current task and becomes available.
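
A small, self-contained sketch of that queuing behavior (the `work` function is made up for illustration):

    import os
    import time
    from multiprocessing import Pool

    def work(x):
        time.sleep(0.1)  # simulate some work so tasks pile up
        return x * x, os.getpid()

    if __name__ == '__main__':
        # Only 2 worker processes: the 5 tasks run 2 at a time,
        # the rest wait in the pool's task queue.
        with Pool(processes=2) as pool:
            print(pool.map(work, range(5)))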

__init__

        self._batch_sampler = batch_sampler

`batch_sampler` is a sampler that yields, for each batch, a list of dataset sample indices of batch-size length.

        self._iter = iter(self._batch_sampler)

Therefore, `self._iter` is an iterator over those index lists.
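
A quick sketch of what iterating a batch sampler looks like, using gluon's own `SequentialSampler` and `BatchSampler` (the sizes are made up):

    from mxnet.gluon.data import BatchSampler, SequentialSampler

    batch_sampler = BatchSampler(SequentialSampler(10), batch_size=4, last_batch='keep')
    it = iter(batch_sampler)
    print(next(it, None))  # [0, 1, 2, 3]
    print(next(it, None))  # [4, 5, 6, 7]
    print(next(it, None))  # [8, 9]
    print(next(it, None))  # None -- the sampler is exhausted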

_push_next

        r = next(self._iter, None)

Therefore, `r` is one batch's list of dataset indices, or `None` once `self._iter` is exhausted (as in the sketch above).

        async_ret = self._worker_pool.apply_async(
            self._worker_fn, (r, self._batchify_fn, self._dataset))

`apply_async` is a method of the `Pool` class. It forwards arbitrary positional and keyword arguments, is non-blocking, and supports a callback on the returned result (see the usage sketch after the notes below). Its implementation:

    def apply_async(self, func, args=(), kwds={}, callback=None,
            error_callback=None):
        '''
        Asynchronous version of `apply()` method.
        '''
        if self._state != RUN:
            raise ValueError("Pool not running")
        result = ApplyResult(self._cache, callback, error_callback)
        self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
        return result
  1. Inside `Pool`, `self._taskqueue = queue.Queue()` is the queue that holds pending tasks until a worker picks them up.

  2. In `_MultiWorkerIter.__init__`, `self._worker_fn = worker_fn`, which here is the `_worker_fn` shown above, since `self._thread_pool` is `False`.
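
A small sketch of the non-blocking call and the callback hook that `apply_async` provides (the `work` and `on_done` names are illustrative):

    from multiprocessing import Pool

    def work(x):
        return x + 1

    def on_done(result):
        print('callback got', result)

    if __name__ == '__main__':
        with Pool(processes=2) as pool:
            # apply_async returns an ApplyResult immediately; the task is
            # placed on pool._taskqueue and picked up by a free worker.
            async_ret = pool.apply_async(work, (41,), callback=on_done)
            print(async_ret.get())  # blocks until the result (42) is ready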

__next__
