Skip to content

Instantly share code, notes, and snippets.

@razhangwei
Last active December 22, 2019 01:12
Show Gist options
  • Save razhangwei/a832f43a9d71c9fa8a03600436eed356 to your computer and use it in GitHub Desktop.
ByLengthBatchSampler; bucketing data loader #PyTorch
class ByLengthBatchSampler(torch.utils.data.Sampler):
    """Pseudo-bucketed batch sampler that groups samples of equal length.

    Indices are bucket-sorted by their sample length so every batch holds
    samples of (nearly) identical length, minimizing padding. Randomness is
    applied at two levels each epoch: indices are shuffled within each
    same-length bucket, and the final batch order is shuffled so training
    does not always see the shortest sequences first.

    Args:
        lengths (list of int): length of each sample; ``lengths[i]`` is the
            length of dataset sample ``i``.
        batch_size (int): maximum number of samples per batch.
        drop_last (bool, optional): if True, drop the final batch when it
            contains fewer than ``batch_size`` samples. Defaults to False.
    """

    def __init__(self, lengths, batch_size, drop_last=False):
        self.lengths = lengths
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        # Bucket sort: group indices by length; shuffle inside each bucket
        # so the composition of batches changes between epochs.
        buckets = {}
        for i, length in enumerate(self.lengths):
            buckets.setdefault(length, []).append(i)

        indices = []
        for length in sorted(buckets):
            bucket = buckets[length]
            random.shuffle(bucket)
            indices.extend(bucket)
        del buckets

        # Slice the length-sorted indices into consecutive batches, so each
        # batch holds samples of similar length.
        index_batches = [
            indices[i:i + self.batch_size]
            for i in range(0, len(indices), self.batch_size)
        ]
        # Guard against empty input: the original `index_batches[-1]` would
        # raise IndexError when `lengths` is empty and drop_last is True.
        if (self.drop_last and index_batches
                and len(index_batches[-1]) < self.batch_size):
            index_batches = index_batches[:-1]

        # Randomize batch order across the epoch.
        random.shuffle(index_batches)
        yield from index_batches

    def __len__(self):
        """Return the number of batches produced per epoch."""
        if self.drop_last:
            return len(self.lengths) // self.batch_size
        # Ceiling division when the final partial batch is kept.
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size
# Usage example: iterate the sampler directly, then feed it to a DataLoader.
if __name__ == "__main__":
    sample_lengths = [1, 1, 2, 3]
    dataset = torch.utils.data.TensorDataset(torch.tensor(sample_lengths))

    # The sampler yields lists of dataset indices, one list per batch.
    sampler = ByLengthBatchSampler(sample_lengths, 2)
    for index_batch in sampler:
        print(index_batch)

    # Plug the sampler into a DataLoader via the batch_sampler argument.
    loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
    for batch in loader:
        print(batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment