Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Last active August 15, 2019 22:18
Show Gist options
  • Save YimianDai/6d0a1c123bebfa8fd0321b23f88645e1 to your computer and use it in GitHub Desktop.
Save YimianDai/6d0a1c123bebfa8fd0321b23f88645e1 to your computer and use it in GitHub Desktop.
DataLoader

DataLoader 的作用是 Loads data from a dataset and returns mini-batches of data

__init__

sampler 是什么?怎么理解 Either specify sampler or shuffle, not both.

samplerbatch_sampler 有什么区别?

如果指定 batch_sampler 的时候就不需要指定 batch_size, shuffle, sampler, and last_batch

batchify_fn 具体的功能是 merge samples into a batch,需要指定,因为有多种多样的 merge 方式

通常默认情况下,batch_samplersampler 都是 None

如果 shuffleTrue,就会运行 sampler = _sampler.RandomSampler(len(dataset)),如果是 False 就是 sampler = _sampler.SequentialSampler(len(dataset)),sampler 就是一个 Dataset index 的 Iterator,区别只是返回0,1,2,3...顺序的 index 还是被随机打散的

这里的 _sampler 是指 mxnet/gluon/data/sampler.py 这个文件,这个文件里面定义了 Sampler, SequentialSampler, RandomSampler, BatchSampler 这四个 class

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')

BatchSampler 就是返回 self._batch_size 长度的 index 集合,最简单的就是 iter 到 self._batch_size 长度的 index 集合,还有就是附带解决下 _last_batch 问题

batch_sampler 就是一些 Dataset 的样本的 index 集合,长度为 batch size

默认的 self._thread_pool 是 False,所以跑得是

                self._worker_pool = multiprocessing.Pool(
                    self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
  1. multiprocessing.Pool 是 Python 本身 multiprocessing/pool.py 中的 Pool class,因此 self._worker_poolmultiprocessing.Pool 类的一个实例。
  2. self._num_workers 这里默认的是 16
  3. 这里的 _worker_initializer 是在 mxnet/gluon/data/dataloader.py 中定义的函数,很奇怪,这个函数除了声明了一个全局变量 _worker_dataset 之外,也不返回什么呀?所以这里 call 这个函数只是为了创建这个全局变量吧
  4. self._dataset 就是传入的 VOCDetection 这样的 Dataset 类型
_worker_dataset = None
def _worker_initializer(dataset):
    """Initialier for processing pool."""
    # global dataset is per-process based and only available in worker processes
    # this is only necessary to handle MXIndexedRecordIO because otherwise dataset
    # can be passed as argument
    global _worker_dataset
    _worker_dataset = dataset

multiprocessing.Pool 类是什么?进程池。Pool 类可以提供指定数量的进程供用户调用,当有新的请求提交到 Pool 中时,如果池还没有满,就会创建一个新的进程来执行请求。如果池满,请求就会告知先等待,直到池中有进程结束,才会创建新的进程来执行这些请求。

Pool 类里面的 apply_async 方法,该函数用于传递不定参数,是非阻塞且支持结果返回进行回调

__iter__

if self._num_workers == 0

if self._num_workers != 0

        # multi-worker
        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
                                prefetch=self._prefetch,
                                dataset=self._dataset if self._thread_pool else None,
                                data_loader=self)

因为 self._thread_pool 是 False,所以其实调用的是

        # multi-worker
        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                worker_fn=_worker_fn, prefetch=self._prefetch, dataset=None,
                                data_loader=self)

_MultiWorkerIter

Internal multi-worker iterator for DataLoader.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment