Demonstrate how a resource can be shared for exclusive access among the workers of a PyTorch dataloader, by distributing a corresponding lock.
""" | |
Demonstrate how a resource can be shared for exclusive access among the workers of a PyTorch dataloader, by distributing | |
a corresponding lock. | |
At first glance, it might seem that it defies the purpose of using multiple workers by sharing a lock among them. There | |
are some use cases though; for example, (1) the access to the shared resource (and thus the locking) is short and | |
followed by further processing within the worker or (2) there are multiple shared resources and one simply has to ensure | |
that no concurrent access to the same resource happens (in the latter case, multiple locks would need to be shared in | |
the demonstrated way). | |
The code below should print something like:

GPU 2, worker 1: 948.90s – 949.90s
GPU 2, worker 0: 955.91s – 956.91s
GPU 2, worker 2: 952.91s – 953.91s
GPU 0, worker 0: 949.91s – 950.91s
GPU 0, worker 2: 954.91s – 955.91s
GPU 0, worker 1: 956.91s – 957.91s
GPU 3, worker 2: 947.90s – 948.90s
GPU 3, worker 0: 950.91s – 951.91s
GPU 3, worker 1: 957.91s – 958.91s
GPU 1, worker 1: 958.91s – 959.91s
GPU 1, worker 0: 951.91s – 952.91s
GPU 1, worker 2: 953.91s – 954.91s

Took 17.87s. Check above:
- Exactly `num_gpus * num_workers_per_gpu` outputs altogether?
- Exactly one output for each combination of GPU and worker?
- No overlapping time intervals?
"""
import multiprocessing as mp
import time

from torch.utils.data import Dataset, DataLoader, get_worker_info


class LockedDataset(Dataset):

    def __init__(self, gpu_id, lock, length):
        self._gpu_id = gpu_id
        self._lock = lock
        self._length = length
        self._worker_id = -1

    def __getitem__(self, idx):
        # Acquire the lock for getting the sample (the content of the sample doesn't really matter here)
        with self._lock:
            item = idx
            from_t = time.time()
            time.sleep(1.)
            to_t = time.time()
            print(f"GPU {self._gpu_id}, worker {self._worker_id}: {from_t % 1000:06.2f}s – {to_t % 1000:06.2f}s")
        return item

    def __len__(self):
        return self._length


# This is just for printing the worker's ID and *not* necessary for the actual synchronization setup to work
def worker_init_fn(*args):
    info = get_worker_info()
    info.dataset._worker_id = info.id


def main(gpu_id, shared_lock, num_workers):
    # Provide the lock to the dataset. The DataLoader will then distribute it among its worker processes.
    # Provide one sample per worker (`length=num_workers`). Set `batch_size=1` to encourage each sample being loaded
    # by a different worker (not sure if guaranteed though). Finally, actually load the samples.
    dataset = LockedDataset(gpu_id, shared_lock, length=num_workers)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=num_workers, worker_init_fn=worker_init_fn)
    for batch in dataloader:
        pass  # We don't have any use for the batches other than loading them


if __name__ == "__main__":

    num_gpus = 4
    num_workers_per_gpu = 3

    try:
        mp.set_start_method("spawn")
    except RuntimeError as e:
        # For convenience: "spawn" may have been set already with an earlier run (happens e.g. in Spyder IDE)
        if "already been set" not in str(e):
            raise

    start = time.time()
    shared_lock = mp.Lock()
    jobs = []
    for gpu_id in range(num_gpus):
        # Launch one process per GPU
        jobs.append(job := mp.Process(target=main, args=(gpu_id, shared_lock, num_workers_per_gpu)))
        job.start()
    for job in jobs:
        job.join()

    print(f"Took {time.time() - start:.2f}s. Check above:")
    print(" - Exactly `num_gpus * num_workers_per_gpu` outputs altogether?")
    print(" - Exactly one output for each combination of GPU and worker?")
    print(" - No overlapping time intervals?")