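"""A batch sampler for multi-task training in PyTorch.

Draws each batch from several task datasets at once and yields lists of
(task_name, sample_index) pairs. The per-task share of every batch is
controlled by `batch_distribution`: proportional to dataset size ('prop'),
inversely proportional ('inv_prop'), or equal ('equal').
"""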
from typing import Dict, Optional

from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler, Sampler


class MultiTaskBatchSampler(Sampler):
    """Yields batches of (task_name, sample_index) pairs drawn from all tasks."""

    def __init__(self, datasets: Dict[str, Dataset],
                 batch_size: int,
                 epoch_size: Optional[int] = None,
                 batch_distribution: str = 'prop',
                 epoch_size_mode: Optional[str] = 'min'):
        assert batch_distribution in ('prop', 'inv_prop', 'equal'), \
            "batch_distribution should be either `prop`, `inv_prop` or `equal`."
        if epoch_size and epoch_size_mode:
            raise RuntimeError('`epoch_size` and `epoch_size_mode` arguments are mutually exclusive.')
        if not (epoch_size or epoch_size_mode):
            raise RuntimeError('Either `epoch_size` or `epoch_size_mode` should be provided.')
        if epoch_size_mode:
            assert epoch_size_mode in ('min', 'max'), \
                "Epoch size mode can be one of `min` or `max`."

        self.datasets = datasets
        self.epoch_size_mode = epoch_size_mode
        self.batch_distribution = batch_distribution
        self.batch_size = batch_size
        self.batch_sizes = {}

        # Get per-task dataset sizes
        self.sizes = {k: len(v) for (k, v) in datasets.items()}

        # Have one sampler per dataset
        self.samplers = {k: RandomSampler(v) for (k, v) in datasets.items()}

        # Get an iterator for each sampler
        self._iters = {k: iter(v) for (k, v) in self.samplers.items()}

        if epoch_size:
            self.epoch_size = epoch_size
        elif self.epoch_size_mode == 'min':
            self.epoch_size = min(self.sizes.values())
        elif self.epoch_size_mode == 'max':
            self.epoch_size = max(self.sizes.values())

        # Store some information
        self._total_size = sum(self.sizes.values())
        self.task_ratios = {k: len(v) / self._total_size for (k, v) in datasets.items()}
        if self.batch_distribution == 'inv_prop':
            # Invert the proportions and renormalize so that they still sum to 1
            inv_ratios = {k: 1.0 - v for (k, v) in self.task_ratios.items()}
            denom = sum(inv_ratios.values())
            self.task_ratios = {k: v / denom for (k, v) in inv_ratios.items()}
        # Actual batch size can be <= requested one if not divisible by # tasks
        if self.batch_distribution == 'equal':
            equal_batch_size = self.batch_size // len(datasets)
            self.batch_sizes = {k: equal_batch_size for k in datasets}
            # Change effective batch size which can be <= requested size
            self.batch_size = equal_batch_size * len(datasets)
        else:
            self.batch_sizes = {k: int(self.batch_size * v) for (k, v) in self.task_ratios.items()}

        # Register per-task epoch counters
        self.task_epochs_done = {k: 0 for k in self.datasets}

    def _reset_sampler(self, name: str) -> None:
        """Re-create the iterator for `name` and count one task epoch."""
        self._iters[name] = iter(self.samplers[name])
        self.task_epochs_done[name] += 1

    def __iter__(self):
        n_processed = 0
        while n_processed < self.epoch_size:
            batch = []
            # Draw each task's share of the batch
            for name in self._iters:
                for _ in range(self.batch_sizes[name]):
                    # Fetch a new sample
                    idx = next(self._iters[name], None)
                    if idx is None:
                        # Iterator/dataset worked through its epoch:
                        # restart it and draw from the fresh iterator
                        self._reset_sampler(name)
                        idx = next(self._iters[name])
                    batch.append((name, idx))
            n_processed += len(batch)
            yield batch


class TestDataset(Dataset):
    def __init__(self, name: str, size: int):
        self._data = [f'{name}_{idx}' for idx in range(size)]

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)
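

# A minimal sketch (not part of the original gist) of how the sampler's
# (task_name, sample_index) pairs can be consumed: a wrapper dataset whose
# __getitem__ accepts such pairs, so the sampler can be plugged into a
# DataLoader via its `batch_sampler` argument. `MultiTaskDataset` is a
# hypothetical name introduced here for illustration.
class MultiTaskDataset(Dataset):
    def __init__(self, datasets: Dict[str, Dataset]):
        self.datasets = datasets

    def __getitem__(self, key):
        # `key` is a (task_name, sample_index) pair yielded by the sampler
        name, idx = key
        return name, self.datasets[name][idx]

    def __len__(self):
        return sum(len(d) for d in self.datasets.values())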


if __name__ == '__main__':
    import random

    random.seed(923)

    datasets = {}
    batches = {}
    tasks = {
        'a': 100,
        'b': 200,
        'c': 300,
        'd': 400,
    }

    for name in ('a', 'b', 'c', 'd'):
        # _size = random.randint(10000, 100000)
        _size = tasks[name]
        datasets[name] = TestDataset(name, size=_size)
        batches[name] = []

    sampler = MultiTaskBatchSampler(datasets, batch_size=64,
                                    batch_distribution='equal',
                                    epoch_size=int(1e6), epoch_size_mode=None)
    print('Sizes: ', sampler.sizes)

    ep_count = 20
    for ep in range(ep_count):
        for batch in sampler:
            for name, elem in batch:
                batches[name].append(elem)

    print('Epochs done: ', ep_count)
    print('Epoch size: ', sampler.epoch_size)
    print('Task batch sizes: ', sampler.batch_sizes)
    print('Task ratios: ', sampler.task_ratios)
    print('Task epochs done: ', sampler.task_epochs_done)
    print('Task sample counts: ', {k: len(v) for (k, v) in batches.items()})
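
    # Example wiring with a DataLoader (a sketch under the assumptions above):
    # the `batch_sampler` yields the (task, index) pairs, the hypothetical
    # MultiTaskDataset wrapper resolves them, and an identity collate keeps
    # the raw (name, sample) tuples as-is.
    from torch.utils.data import DataLoader

    loader = DataLoader(MultiTaskDataset(datasets), batch_sampler=sampler,
                        collate_fn=lambda samples: samples)
    first_batch = next(iter(loader))
    print('First DataLoader batch (head): ', first_batch[:4])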