- In WMT datasets, there is wide variation in the length of examples. Some are one sentence. Some are 10 sentences.
- The max batch size that can fit on a V100 is roughly `(4, 512)`.
- You end up with lots of batches of shape `(4, 12)` or `(4, small_int)` which don't fully utilize the GPU.
Dynamic Batch Size: try to organize batches to be 4*512 = 2048 tokens; one batch might be shaped `(4, 512)`, another `(32, 64)`.
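To make the budget concrete, a quick sanity check (the 2048-token figure comes from the `(4, 512)` shape above; the extra shapes are just illustrative):

```python
# Each of these shapes fills the same 2048-token budget:
# batch_size * padded_seq_len stays constant while the shape varies.
TOKEN_BUDGET = 4 * 512  # 2048

for batch_size, seq_len in [(4, 512), (32, 64), (128, 16)]:
    assert batch_size * seq_len == TOKEN_BUDGET
```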
Pass a `batch_sampler: List[List[int]] = [[id_0_batch0, id_1_batch0], [id_3764_batch_1], [id_3_batch_2, id_4_batch_2, id_5_batch_2]]` kwarg to `DataLoader`. Each entry in the list is the example indices that compose a batch. The entries don't need to be the same length. The procedure (sketched in code after these steps) is:
- (OPTIONAL) sort examples by length (to save padding)
- pack every entry in the list such that the included examples total at most `max_tokens=4000` tokens (this includes padding)
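A minimal sketch of that packing procedure, assuming a per-example token count is already available; `make_dynamic_batches` and the exact budget check are illustrative, not an existing helper in examples/seq2seq:

```python
from typing import List

def make_dynamic_batches(token_lens: List[int], max_tokens: int = 4000) -> List[List[int]]:
    """Group example indices into batches whose padded size stays under max_tokens.

    Padded size is approximated as len(batch) * (longest example in the batch),
    which is what actually lands on the GPU once short examples are padded.
    """
    # (OPTIONAL) sort indices by length so similarly sized examples share a batch,
    # which keeps padding to a minimum.
    order = sorted(range(len(token_lens)), key=lambda i: token_lens[i])

    batches: List[List[int]] = []
    current: List[int] = []
    longest = 0
    for idx in order:
        candidate_longest = max(longest, token_lens[idx])
        # Would adding this example push the padded batch over the budget?
        if current and candidate_longest * (len(current) + 1) > max_tokens:
            batches.append(current)
            current, longest = [], 0
            candidate_longest = token_lens[idx]
        current.append(idx)
        longest = candidate_longest
    if current:
        batches.append(current)
    return batches
```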
Then the batches are presented in a different order at training time (but in the above example, id_0 and id_1 would be in the same batch every time).
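One way to get that behavior (fixed batch composition, shuffled batch order each epoch) is a small batch sampler that reshuffles on every `__iter__`; `ShuffledBatchSampler`, `dataset`, and `collate_fn` below are stand-ins for whatever the training script already builds, not existing names:

```python
import random
from typing import List

from torch.utils.data import DataLoader

class ShuffledBatchSampler:
    """Yield precomputed batches of example indices in a fresh random order each epoch,
    while keeping the composition of every batch fixed."""

    def __init__(self, batches: List[List[int]]):
        self.batches = batches

    def __iter__(self):
        order = list(range(len(self.batches)))
        random.shuffle(order)  # reshuffled every time the DataLoader iterates (each epoch)
        for i in order:
            yield self.batches[i]

    def __len__(self):
        return len(self.batches)

# Hypothetical wiring; `dataset` and `collate_fn` come from the existing training setup:
# batches = make_dynamic_batches(token_lens, max_tokens=4000)
# loader = DataLoader(dataset, batch_sampler=ShuffledBatchSampler(batches), collate_fn=collate_fn)
```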
This made training 40% faster for mBART WMT finetuning without changing metrics.
- You have to tokenize the full dataset to know how many tokens each example is. Currently, examples/seq2seq does tokenization lazily. You can assume that each token is ~4 characters, which is conservative and roughly correct, but not clean.
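A sketch of that heuristic; `approx_num_tokens` and the `train.source` path are illustrative:

```python
def approx_num_tokens(text: str, chars_per_token: int = 4) -> int:
    """Estimate token count without running the tokenizer, assuming ~4 characters per token."""
    return max(1, len(text) // chars_per_token)

# Example: precompute rough lengths for the whole dataset without tokenizing it.
# token_lens = [approx_num_tokens(line) for line in open("train.source")]
```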
- fairseq has a function that builds these batches, but it's written in Cython: https://github.com/pytorch/fairseq/blob/da94e58c703866236b29242ae413146be69fe94f/fairseq/data/data_utils_fast.pyx#L27