ByLengthBatchSampler; bucketing data loader #PyTorch
import random

import torch


class ByLengthBatchSampler(torch.utils.data.Sampler):
    """Pseudo-bucketed batch sampler.

    Samples batches of indices such that items in a batch have identical
    or near-identical lengths, minimizing wasted padding. Order within
    each length bucket is shuffled, and the order of the batches
    themselves is shuffled, but the grouping by length is deterministic.
    Batches may still span bucket boundaries, hence "pseudo" bucketed.

    Args:
        lengths (list of int): length of each item in the dataset.
        batch_size (int): number of items per batch.
        drop_last (bool, optional): if True, drop the final batch when
            it has fewer than batch_size items. Defaults to False.
    """

    def __init__(self, lengths, batch_size, drop_last=False):
        self.lengths = lengths
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Bucket sort: group indices by length, maintaining a random
        # order inside each bucket.
        buckets = {}
        for i, length in enumerate(self.lengths):
            buckets.setdefault(length, []).append(i)

        indices = []
        for length in sorted(buckets.keys()):
            v = buckets[length]
            random.shuffle(v)
            indices += v

        # Chunk the length-sorted indices into batches, so each batch
        # contains items of equal or adjacent lengths.
        index_batches = []
        for i in range(0, len(indices), self.batch_size):
            j = min(i + self.batch_size, len(indices))
            index_batches.append(indices[i:j])
        if (self.drop_last and index_batches
                and len(index_batches[-1]) < self.batch_size):
            index_batches = index_batches[:-1]

        # Shuffle the batch order so epochs do not always run from
        # shortest to longest.
        random.shuffle(index_batches)
        for batch_indices in index_batches:
            yield batch_indices

    def __len__(self):
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        else:
            return (len(self.lengths) + self.batch_size - 1) // self.batch_size


# Usage
if __name__ == "__main__":
    lengths = [1, 1, 2, 3]
    dataset = torch.utils.data.TensorDataset(torch.tensor(lengths))

    batch_sampler = ByLengthBatchSampler(lengths, 2)
    for batch in batch_sampler:
        print(batch)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler)
    for batch in data_loader:
        print(batch)
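Since batches near bucket boundaries may still mix a couple of different lengths, a DataLoader using this sampler typically also needs a collate function that pads each batch to its longest item. Below is a minimal sketch, not part of the original gist: the SequenceDataset class and pad_collate function are hypothetical names introduced for illustration, while torch.nn.utils.rnn.pad_sequence is the standard PyTorch padding utility.

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical dataset of variable-length 1-D tensors.
sequences = [torch.arange(n) for n in [1, 1, 2, 3]]
lengths = [len(s) for s in sequences]

class SequenceDataset(torch.utils.data.Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, i):
        return self.sequences[i]

def pad_collate(batch):
    # Pad every sequence in the batch to the longest one in that batch.
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = torch.utils.data.DataLoader(
    SequenceDataset(sequences),
    batch_sampler=ByLengthBatchSampler(lengths, batch_size=2),
    collate_fn=pad_collate)

for batch in loader:
    print(batch.shape)  # little to no padding within each batch

Because the sampler groups near-equal lengths together, the padded width of each batch stays close to its items' true lengths, which is the point of bucketing.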