# In addition to torch and torchvision, you need to pip3 install typer and webdataset.
import sys
import os
import warnings
import torch
import webdataset as wds
import typer
from itertools import islice
from torchvision import transforms

# We're not using actual torch.distributed, since we just want to simulate
# how data is split between different nodes. Other than that, though, this
# code works the same way as true distributed code.
dist_rank = -1
dist_size = -1
show_splits = False
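
# Assumption (not part of the original gist): in a real distributed job these
# globals would not be set by hand; they would come from torch.distributed
# after the process group has been initialized, roughly like this sketch:
#
#     import torch.distributed as dist
#     dist.init_process_group(backend="nccl")
#     dist_rank = dist.get_rank()
#     dist_size = dist.get_world_size()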


def split_by_node(urls):
    """Split urls for each node.
    This uses the rank and world size. Note that it is invoked in each worker,
    so the results need to be consistent between multiple invocations."""
    global dist_rank, dist_size
    if dist_rank >= 0 and dist_size > 0:
        result = urls[dist_rank::dist_size]
        if show_splits:
            print(
                f"split_by_node {dist_rank}/{dist_size} len={len(result)}",
                file=sys.stderr,
            )
        return result
    else:
        # `result` does not exist on this branch; report the full url list instead.
        print(f"single node len={len(urls)}")
        return urls
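
# Worked example (illustration only, not part of the original gist): with
# dist_size = 2 and shards ["a", "b", "c", "d"], the slice urls[rank::size]
# gives rank 0 the shards ["a", "c"] and rank 1 the shards ["b", "d"].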


def split_by_worker(urls):
    """Split urls for each worker."""
    urls = [url for url in urls]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        wid = worker_info.id
        num_workers = worker_info.num_workers
        if wid == 0 and len(urls) < num_workers:
            warnings.warn(f"num_workers {num_workers} > num_shards {len(urls)}")
        result = urls[wid::num_workers]
        if show_splits:
            print(
                f"split_by_worker {wid}/{num_workers} len={len(result)}",
                file=sys.stderr,
            )
        return result
    else:
        return urls
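
# Note (illustration, not from the original gist): the warning above fires when
# a node has fewer shards than loader workers, e.g. 4 workers and 3 shards
# leave worker 3 with an empty slice, so it yields no batches at all, which
# breaks the equal-batch-count requirement described in train() below.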


def make_loader(shards, batch_size=128, num_workers=6, partial=False, repeat=1):
    """Create a loader for Imagenet-like data.
    The `partial` argument is passed on to the `batched()` method.
    Note that if `partial` is True, each worker may return a partial batch."""
    augment = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.ToTensor(),
        ]
    )
    dataset = (
        wds.WebDataset(shards, nodesplitter=split_by_node, splitter=split_by_worker)
        .shuffle(1000)
        .decode("pil")
        .to_tuple("jpg", "cls")
        .map_tuple(augment)
        .batched(batch_size, partial=partial)
    )
    if repeat > 1:
        dataset = dataset.repeat(repeat)
    loader = wds.WebLoader(dataset, num_workers=num_workers, batch_size=None)
    return loader
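
# Usage sketch (the shard pattern below is hypothetical, not from the original
# gist; any brace-expandable list of .tar shards with jpg/cls pairs works):
#
#     loader = make_loader("train-{000000..000011}.tar", batch_size=64,
#                          num_workers=4)
#     for images, labels in loader:
#         ...  # images is a batched tensor, labels the decoded "cls" values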


def train(
    shards: str = "pipe:gsutil cat gs://lpr-simsplit/split-{000000..00009}.tar",
    size: int = 3,
    batch_size: int = 10,
    nworkers: int = 3,
    nepochs: int = 1,
    nbatches: int = 999999,
    partial: bool = False,
    showopen: bool = False,
    showsplits: bool = False,
    repeat: int = 1,
    dsrepeat: int = 1,
):
    """Simulate distributed training.
    This will perform dataset loading for each worker in a distributed training
    job of size `size` and report the number of batches and samples returned by
    each worker.
    For distributed SGD (DistributedDataParallel) to work, each worker needs to return
    exactly the same number of batches. To get exact epochs, you need to ensure that
    all the shards have exactly the same number of samples and that the number of shards
    is divisible by (#workers * #nodes).
    If your data isn't in that form (and it usually isn't), you have to do something different.
    """
    global dist_size, dist_rank, show_splits
    show_splits = showsplits
    if showopen:
        os.environ["GOPEN_VERBOSE"] = "1"
    dist_size = size
    loader = make_loader(
        shards,
        batch_size=batch_size,
        num_workers=nworkers,
        partial=partial,
        repeat=dsrepeat,
    )
    if repeat > 1:
        loader = loader.repeat(nepochs=repeat)
    for rank in range(size):
        dist_rank = rank
        batches = 0
        total = 0
        for inputs, targets in islice(loader, 0, nbatches):
            batches += 1
            total += len(inputs)
        err = " TOO FEW BATCHES" if nbatches < 999999 and batches < nbatches else ""
        print(f"=== rank {dist_rank} batches {batches} total {total}{err}")


if __name__ == "__main__":
    typer.run(train)
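
# Invocation sketch (assumes this file is saved as simsplit.py, a hypothetical
# name; typer turns the keyword arguments of train() into command-line options,
# with underscores mapped to dashes):
#
#     python3 simsplit.py --size 3 --nworkers 3 --batch-size 10 --showsplits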