Test WebDataset with torch-xla in a multiprocessing distributed setting
from itertools import islice
import os

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import webdataset as wds
""" | |
This is a toy implementation of webdataset for torch-xla. | |
I was suspecting that the pl.MpDeviceLoader was somehow not splitting examples across cores and workers, but | |
it seems to be fine (test by setting `split_per_core` and `split_per_worker` to True/False). | |
""" | |
batch_size = 5
num_iters = 64
dataset_length = 8 * 100
tpuip = '10.44.70.138'  # set this to your TPU's IP
num_cores = 8
num_workers = 4
split_per_core = True
split_per_worker = True

# This is just the webdataset default OpenImages dataset:
urls = "http://storage.googleapis.com/nvdata-openimages/openimages-train-{000000..000554}.tar"
urls = f"pipe:curl -L -s {urls} || true"

# Setting up:
os.environ['OMP_NUM_THREADS'] = '1'  # good to have for webdataset
os.environ["XRT_TPU_CONFIG"] = f"tpu_worker;0;{tpuip}:8470"

def identity(x):
    return x

def main(device_idx):
    print(f'Starting process {device_idx}.')

    device = xm.xla_device()

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])

    preproc = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = (
        wds.Dataset(urls, length=dataset_length)
        .decode("pil")
        .to_tuple("jpg;png", "json")
        .map_tuple(preproc, identity)
        .batched(batch_size)
    )
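
    # The pipeline above: decode("pil") turns the stored jpg/png bytes into PIL images,
    # to_tuple("jpg;png", "json") selects the image and its json annotations, map_tuple
    # applies the torchvision preprocessing to the image only, and batched(batch_size)
    # collates samples inside the dataset, which is why the DataLoader below is created
    # with batch_size=None and collate_fn=None.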

    def shard_selection(urls_):
        """Split urls correctly per accelerator core.

        :param urls_: full list of shard urls
        :return: slice of urls_ for this core
        """
        urls_this = urls_[device_idx::num_cores]
        return urls_this

    def shard_shuffle(urls_):  # not really a *shuffle*...
        """Split urls correctly per dataloader worker.

        :param urls_: shard urls assigned to this core
        :return: slice of urls_ for this worker
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return urls_
        else:
            num_workers_ = worker_info.num_workers
            worker_id = worker_info.id
            urls_ = urls_[worker_id::num_workers_]
            return urls_

    if split_per_core:
        dataset.shard_selection = shard_selection
    if split_per_worker:
        dataset.shard_shuffle = shard_shuffle

    dataloader = DataLoader(dataset,
                            num_workers=num_workers,
                            batch_size=None,
                            collate_fn=None)  # batching and collation are done in the dataset, hence the Nones

    device_dataloader = pl.MpDeviceLoader(dataloader, device)

    xm.rendezvous('init')  # wait until all processes have started

    ids = list()
    for i, sample in enumerate(islice(device_dataloader, 0, num_iters)):
        print(f'device={device_idx} :: iter={i}')
        ids_this = list()
        for example in sample[1]:
            try:
                id = example[0]['ImageID']
            except KeyError:
                id = 'None'
            ids_this.append(id)
        ids += ids_this
        print(f'ids={ids_this}')
        xm.rendezvous('step')

    xm.rendezvous('all_collected')
    print(f'Process {device_idx} done.')

    # Count unique image ids:
    ids_int = list()
    for id in ids:
        id_int = np.frombuffer(bytes(id, encoding='utf-8'), np.uint8).astype(np.int32)  # to int because I want to gather from all devices
        ids_int.append(torch.tensor(id_int))
    ids_int = torch.stack(ids_int).to(device)  # shape [batch_size * num_iters, 16]
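
    # Note: the stacking above assumes every id encodes to the same number of bytes
    # (OpenImages ImageIDs are 16-character strings, hence the fixed width of 16);
    # a 'None' fallback id from the KeyError branch would break torch.stack.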

    # Check that all ids are unique:
    xm.rendezvous('before_gather')
    ids_int_all = xm.all_gather(ids_int, 0)
    xm.rendezvous('after_gather')
    ids_unique = torch.unique(ids_int_all, dim=0)
    if len(ids_unique) == len(ids_int_all):
        xm.master_print('All examples unique.')
    else:
        xm.master_print('Non-unique examples detected:')
        xm.master_print(f'Num examples={len(ids_int_all)}')
        xm.master_print(f'Num uniques={len(ids_unique)}')
        raise KeyboardInterrupt

if __name__ == '__main__':
    xmp.spawn(main, args=(), nprocs=num_cores, start_method='fork')