from typing import Dict, Optional, Tuple

from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler, Sampler
class MultiTaskBatchSampler(Sampler):
    def __init__(self, datasets: Dict[str, Dataset],
                 batch_size: int,
                 epoch_size: Optional[int] = None,
                 batch_distribution: str = 'prop',
                 epoch_size_mode: Optional[str] = 'min'):
        assert batch_distribution in ('prop', 'inv_prop', 'equal'), \
            "batch_distribution should be either `prop`, `inv_prop` or `equal`."
        if epoch_size and epoch_size_mode:
            raise RuntimeError('`epoch_size` and `epoch_size_mode` arguments are mutually exclusive.')
        if not (epoch_size or epoch_size_mode):
            raise RuntimeError('Either `epoch_size` or `epoch_size_mode` should be provided.')
        if epoch_size_mode:
            assert epoch_size_mode in ('min', 'max'), \
                "Epoch size mode can be one of `min` or `max`."
        self.datasets = datasets
        self.epoch_size_mode = epoch_size_mode
        self.batch_distribution = batch_distribution
        self.batch_size = batch_size
        self.batch_sizes = {}

        # Get per-task dataset sizes
        self.sizes = {k: len(v) for (k, v) in datasets.items()}

        # Have one sampler per dataset
        self.samplers = {k: RandomSampler(v) for (k, v) in datasets.items()}

        # Get an iterator for each sampler
        self._iters = {k: iter(v) for (k, v) in self.samplers.items()}

        if epoch_size:
            self.epoch_size = epoch_size
        elif self.epoch_size_mode == 'min':
            self.epoch_size = min(self.sizes.values())
        elif self.epoch_size_mode == 'max':
            self.epoch_size = max(self.sizes.values())

        # Store some information
        self._total_size = sum(self.sizes.values())
        self.task_ratios = {k: len(v) / self._total_size for (k, v) in datasets.items()}
        if self.batch_distribution == 'inv_prop':
            # Invert the proportions (over the ratios, not the raw datasets)
            # and renormalize so that they sum to 1 again
            inv_ratios = {k: 1.0 - v for (k, v) in self.task_ratios.items()}
            inv_total = sum(inv_ratios.values())
            self.task_ratios = {k: v / inv_total for (k, v) in inv_ratios.items()}
        # Actual batch size can be <= requested one if not divisible by # tasks
        if self.batch_distribution == 'equal':
            equal_batch_size = self.batch_size // len(datasets)
            self.batch_sizes = {k: equal_batch_size for k in datasets}
            # Change effective batch size which can be <= requested size
            self.batch_size = equal_batch_size * len(datasets)
        else:
            self.batch_sizes = {k: int(self.batch_size * v) for (k, v) in self.task_ratios.items()}
        # Track how many full passes each task's dataset has completed
        self.task_epochs_done = {k: 0 for k in self.datasets}
    def _reset_sampler(self, name: str) -> None:
        self._iters[name] = iter(self.samplers[name])
        self.task_epochs_done[name] += 1
    def __iter__(self):
        n_processed = 0
        while n_processed < self.epoch_size:
            # Start a fresh batch for every yield
            batch = []
            # Iterate over datasets
            for name in self._iters:
                for _ in range(self.batch_sizes[name]):
                    # Fetch a new sample
                    idx = next(self._iters[name], None)
                    if idx is None:
                        # Iterator/dataset worked through its epoch: restart
                        # it and keep filling this task's share of the batch
                        self._reset_sampler(name)
                        idx = next(self._iters[name])
                    batch.append((name, idx))
            n_processed += len(batch)
            yield batch
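
# A minimal sketch of how the sampler could feed a DataLoader: since
# __iter__ yields batches of (task_name, index) pairs, the dataset handed
# to the DataLoader must accept such pairs in __getitem__. The
# `MultiTaskDataset` wrapper below is a hypothetical helper, not a class
# from torch.utils.data, and a custom collate_fn would likely be needed.
class MultiTaskDataset(Dataset):
    def __init__(self, datasets: Dict[str, Dataset]):
        self.datasets = datasets

    def __getitem__(self, name_idx: Tuple[str, int]):
        # Unpack the (task_name, index) pair produced by the sampler
        name, idx = name_idx
        return name, self.datasets[name][idx]

    def __len__(self):
        return sum(len(d) for d in self.datasets.values())

# Hypothetical wiring:
#   loader = DataLoader(MultiTaskDataset(datasets), batch_sampler=sampler)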
class TestDataset(Dataset):
    def __init__(self, name: str, size: int):
        self._data = [f'{name}_{idx}' for idx in range(size)]

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)
if __name__ == '__main__':
    import random

    import torch

    # RandomSampler draws from torch's default generator, so seed torch
    # as well; `random.seed` alone does not make the sampler reproducible
    random.seed(923)
    torch.manual_seed(923)

    datasets = {}
    batches = {}

    tasks = {
        'a': 100,
        'b': 200,
        'c': 300,
        'd': 400,
    }

    for name in ('a', 'b', 'c', 'd'):
        # _size = random.randint(10000, 100000)
        _size = tasks[name]
        datasets[name] = TestDataset(name, size=_size)
        batches[name] = []

    sampler = MultiTaskBatchSampler(datasets, batch_size=64,
                                    batch_distribution='equal',
                                    epoch_size=int(1e6), epoch_size_mode=None)
    print('Sizes: ', sampler.sizes)

    ep_count = 20
    for ep in range(ep_count):
        for batch in sampler:
            for name, elem in batch:
                batches[name].append(elem)

    print('Epochs done: ', ep_count)
    print('Epoch size: ', sampler.epoch_size)
    print('Task batch sizes: ', sampler.batch_sizes)
    print('Task ratios: ', sampler.task_ratios)
    print('Task epochs done: ', sampler.task_epochs_done)
    print('Task sample counts: ', {k: len(v) for (k, v) in batches.items()})
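
    # A quick check of the 'prop' batch-size arithmetic with the sizes
    # above: total size is 1000, so the ratios are a:0.1, b:0.2, c:0.3,
    # d:0.4, and int(64 * ratio) gives per-task batch sizes 6/12/19/25,
    # i.e. an effective batch size of 62 <= 64 due to flooring.
    prop_sampler = MultiTaskBatchSampler(datasets, batch_size=64,
                                         batch_distribution='prop',
                                         epoch_size=1000, epoch_size_mode=None)
    print('`prop` batch sizes: ', prop_sampler.batch_sizes)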