PyTorch IterableDataset implementation with multiprocessing and distributed training support

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import IterableDataset, DataLoader


class DistributedIterableDataset(IterableDataset):
"""
Example implementation of an IterableDataset that handles both multiprocessing (num_workers > 0)
and distributed training (nodes > 1).
For an indexable dataset the indices will typically be split upon all workers. However, for an iterable
dataset this is not trivially possible. This approach works by letting each worker iterate over the full
iterable but skip each element which is not at index `rank * num_workers + worker_id` where rank is the
node id in the MPI context and `worker_id` the worker id in the multiprocessing context of a single
DataLoader instance.
`rank` and `world_size` must be retrieved outside of the dataset's scope and passed to the constructor as they
won't be available reliably when the dataset is accessed from another process if `num_workers > 1`.
Note, that in this example the data, respectively the iterable is passed in the constructor. In practise,
this may lead to a large overhead as the iterable will typically perform expensive work such loading
or resizing images. As a worker will iterate over the whole iterable and thus invoke those expensive computations
but then discard most of the elements, the actual processing logic should either be moved to the `Dataset`-class
in which case the iterable would, for instance, only provide image paths instead of loading images directly
or the iterable could yield functions which then could be selectively called to perform the expensive work.
"""
    def __init__(self, it, rank, world_size):
        self.it = it
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        # The total number of shards is world_size * num_workers; this worker
        # is responsible for every `mod`-th element, offset by `shift`.
        mod = self.world_size
        shift = self.rank

        if worker_info:
            mod *= worker_info.num_workers
            shift = self.rank * worker_info.num_workers + worker_info.id

        # Keep only the elements belonging to this worker's shard. Enumerating
        # makes the sharding depend on the index, not on the element values.
        for idx, item in enumerate(self.it):
            if idx % mod == shift:
                yield item
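

# A minimal sketch of the lazy-work pattern mentioned in the docstring above:
# instead of doing the expensive work eagerly, the source iterable yields
# zero-argument callables which the consumer invokes only for the elements that
# survive sharding. `load_image` and `lazy_samples` are hypothetical
# illustrations and not part of the original gist.
def load_image(path):
    # Placeholder for expensive work such as decoding and resizing an image.
    return path


def lazy_samples(image_paths):
    for path in image_paths:
        # Capture `path` in a default argument; nothing is loaded until the
        # callable is actually invoked.
        yield lambda p=path: load_image(p)

# Usage sketch: only the elements that survive sharding are executed, e.g.
#
#   dataset = DistributedIterableDataset(lazy_samples(paths), rank, world_size)
#   for make_sample in dataset:
#       sample = make_sample()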


def distributed_worker(local_rank, world_size):
    torch.distributed.init_process_group(
        # The gloo backend is used, as NCCL does not support the gather() function.
        # This example would still work with NCCL, but the verification logic below
        # would need to be adjusted accordingly.
        backend='gloo',
        init_method='tcp://127.0.0.1:5678',
        world_size=world_size,
        rank=local_rank,
    )

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Sample data
    data = list(range(64))

    dataset = DistributedIterableDataset(data, rank, world_size)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    # The training loop would normally be placed here. Instead, we retrieve all
    # batches from the dataloader at once and concatenate them into a single
    # flat tensor so that it can be gathered below.
    batches = torch.cat(list(dataloader))

    dst = 0

    # Verify that everything worked correctly by gathering the batches each
    # worker received and comparing the result with the original data.
    # In the correct case, there are no duplicates or missing elements.
    if rank == dst:
        tensor_list = [torch.empty(len(batches), dtype=torch.long) for _ in range(world_size)]
        dist.gather(batches, gather_list=tensor_list, dst=dst)

        ret_data = sorted(torch.cat(tensor_list).tolist())

        if data == ret_data:
            print('Original and retrieved samples are equal.')
        else:
            print('Original and retrieved samples are not equal. Something went wrong.')
    else:
        dist.gather(batches, dst=dst)


if __name__ == '__main__':
    # Number of distributed workers, each spawned as a separate process.
    # This example does not require actual GPUs to be available, as the
    # CPU-based gloo backend is used.
    n_procs = 4

    # With join=True, mp.spawn blocks until all worker processes have finished.
    mp.spawn(
        distributed_worker,
        nprocs=n_procs,
        args=(n_procs,),
        daemon=False,
        join=True,
    )
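
# Alternatively, the workers could be launched externally, e.g. with
# `torchrun --nproc_per_node=4 script.py`, in which case the process group is
# typically initialized from the environment variables set by the launcher
# instead of a hard-coded TCP address. A minimal sketch, assuming RANK and
# WORLD_SIZE are provided by the launcher:
#
#   dist.init_process_group(backend='gloo', init_method='env://')
#   rank = dist.get_rank()
#   world_size = dist.get_world_size()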