Investigating the behavior of PyTorch's DataLoader when using randomness to generate samples.
# This bug was fixed in PyTorch 1.9. See:
# https://github.com/pytorch/pytorch/commit/aec83ff45ebd2cb3d4890cc97bffb1f367386392.
# See: https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers
# and: https://pytorch.org/docs/stable/data.html#multi-process-data-loading
# and: https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading.

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # Each sample is a fresh draw from NumPy's global RNG.
        return torch.LongTensor([np.random.randint(20)])


def worker_init_fn(worker_id):
    # np.random.seed takes a 32-bit unsigned integer, so reduce the worker's
    # 64-bit seed modulo 2 ** 32.
    np.random.seed(int(torch.utils.data.get_worker_info().seed) % (2 ** 32))


workers = 5
for seed_fn in [None, worker_init_fn]:
    dataset = RandomDataset()
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=None,
        num_workers=workers,
        worker_init_fn=seed_fn,
    )
    epoch2labels = []
    for epoch in range(2):
        data = []
        for tensor in data_loader:
            data.append(tensor[0].item())

        epoch2labels.append(data)

    # When seed_fn is None, every group of five consecutive labels is identical
    # because all five workers generate the same sequence, and the sequences for
    # the two epochs are identical as well. When seed_fn is worker_init_fn, each
    # worker produces a different sequence, and the sequences also differ across
    # epochs.
    for labels in epoch2labels:
        print(labels)

    print()
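Since PyTorch 1.9 seeds each worker's NumPy global RNG automatically, the worker_init_fn above is only needed on older versions. An alternative that sidesteps NumPy's global state entirely is to draw samples with torch's own RNG, which PyTorch seeds differently in each worker (base_seed + worker_id) in all versions. A minimal sketch under the same DataLoader setup; TorchRandomDataset is a hypothetical stand-in for RandomDataset, not part of the original gist:

import torch
from torch.utils.data import DataLoader, Dataset


# Hypothetical variant of RandomDataset that uses torch's RNG instead of NumPy's.
class TorchRandomDataset(Dataset):
    def __len__(self):
        return 20

    def __getitem__(self, idx):
        # torch.randint draws from the per-worker torch generator, so every
        # worker produces a distinct sequence without any worker_init_fn.
        return torch.randint(20, (1,))


# No worker_init_fn needed: the five workers yield different samples.
data_loader = DataLoader(dataset=TorchRandomDataset(), batch_size=None, num_workers=5)
print([tensor[0].item() for tensor in data_loader])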
Indeed, see the comment at the top of the script.