PyTorch Distributed playground
# Run it
# torchrun --nproc_per_node=4 example_1.py
import os
import time

import torch
import torch.distributed as dist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    backend = "gloo"
    device = "cpu"

    dist.init_process_group(backend)

    if backend == "nccl":
        device = "cuda"
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5], device=device)
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    # Let's broadcast data from proc 0 to everyone in the world
    # See: https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast
    if rank == 0:
        data_to_broadcast = torch.tensor([1.2345, 2.3456], device=device)
    else:
        data_to_broadcast = torch.tensor([0.0, 0.0], device=device)

    pprint(rank, f"Before data is {data_to_broadcast}")
    dist.broadcast(data_to_broadcast, src=0)
    pprint(rank, f"Now data is {data_to_broadcast}")

    # Let's synchronize here
    dist.barrier()

    # Let's use gather:
    data_to_gather = torch.tensor([1.0 * rank, 2.0 * rank], device=device)
    if rank == 0:
        # allocate receive buffers on the same device as the data (needed if backend is "nccl")
        gather_list = [torch.empty(2, device=device) for _ in range(ws)]
    else:
        gather_list = None

    pprint(rank, f"Before data is {gather_list}")
    dist.gather(data_to_gather, gather_list=gather_list, dst=0)
    pprint(rank, f"Now data is {gather_list}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()
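A natural variation on the gather step above is all_gather, where every process (not only rank 0) ends up with the full list of tensors. This is a minimal sketch under the same 4-process gloo/CPU setup as example_1.py; the file name example_1b.py is made up for illustration.

# Run it (hypothetical companion to example_1.py)
# torchrun --nproc_per_node=4 example_1b.py
import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    ws = dist.get_world_size()

    # Every process contributes one tensor ...
    data = torch.tensor([1.0 * rank, 2.0 * rank])
    # ... and every process receives all of them (unlike gather, which only fills rank 0)
    gathered = [torch.empty(2) for _ in range(ws)]
    dist.all_gather(gathered, data)
    print(rank, gathered)

    dist.destroy_process_group()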
# Run it
# torchrun --nproc_per_node=4 example_2.py --init_method="file:///tmp/shared"
# or
# torchrun --nproc_per_node=4 example_2.py
import argparse
import os
import time

import torch
import torch.distributed as dist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Dist Example 2")
    parser.add_argument("--init_method", type=str)
    args = parser.parse_args()

    ws = int(os.environ["WORLD_SIZE"]) if args.init_method is not None else -1
    r = int(os.environ["RANK"]) if args.init_method is not None else -1

    if args.init_method is None:
        args.init_method = "env://"

    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    dist.init_process_group("gloo", init_method=args.init_method, world_size=ws, rank=r)

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others - init_method={args.init_method}")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()
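Besides env:// and a shared file, a TCP rendezvous is the other standard init_method. The sketch below is an assumed variant of example_2.py (the file name, address and port are placeholders); it still relies on torchrun to export WORLD_SIZE and RANK.

# Run it (hypothetical variant of example_2.py using a TCP rendezvous)
# torchrun --nproc_per_node=4 example_2_tcp.py
import os
import torch.distributed as dist

if __name__ == "__main__":
    # With tcp:// we must pass world_size and rank explicitly;
    # torchrun has already exported them as environment variables.
    ws = int(os.environ["WORLD_SIZE"])
    r = int(os.environ["RANK"])
    dist.init_process_group(
        "gloo",
        init_method="tcp://127.0.0.1:23456",  # placeholder host:port
        world_size=ws,
        rank=r,
    )
    print(dist.get_rank(), f"initialized via tcp:// among {dist.get_world_size()} processes")

    dist.destroy_process_group()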
# Run it
# python -u example_3.py --init_method="file:///tmp/abc" --ws=4
# or
# python -u example_3.py --ws=4
import argparse
import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import start_processes, spawn


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


def func(local_rank, args):
    if args.init_method is None:
        args.init_method = "env://"
        os.environ["WORLD_SIZE"] = str(args.ws)
        os.environ["RANK"] = str(local_rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "2233"
        ws = -1
        r = -1
    else:
        ws = args.ws
        r = local_rank

    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    dist.init_process_group("gloo", init_method=args.init_method, world_size=ws, rank=r)

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others - init_method={args.init_method}")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Dist Example 3")
    parser.add_argument("--init_method", type=str)
    parser.add_argument("--ws", type=int, default=1)
    args = parser.parse_args()

    start_processes(
        func,
        args=(args, ),
        nprocs=args.ws,
        start_method="fork",
    )
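example_3.py launches its workers with start_processes; torch.multiprocessing.spawn (already imported above) is the more common shorthand and defaults to the "spawn" start method. A minimal sketch of the same launch, assuming the same func(local_rank, args) and parsed args as in example_3.py:

# Hypothetical alternative to the start_processes call in example_3.py
from torch.multiprocessing import spawn

# spawn(...) waits for all workers by default (join=True) and uses the
# "spawn" start method, so the target function and its args must be picklable.
spawn(
    func,            # same worker function as in example_3.py
    args=(args, ),   # assumes `args` has been parsed as in example_3.py
    nprocs=args.ws,
)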
# Run it
# torchrun --nproc_per_node=4 example_idist_1.py
import os
import time

import torch
import ignite.distributed as idist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    # https://pytorch.org/ignite/distributed.html#ignite.distributed.utils.initialize
    idist.initialize("gloo")

    rank = idist.get_rank()
    ws = idist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others")
    pprint(rank, f"Group type: {type(idist.backend())} : {idist.backend()}, {idist.utils._model._backend}")

    # Let's synchronize here
    idist.barrier()

    # Let's compute something:
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    # idist.all_reduce returns the reduced value; keep it so the code also
    # works with backends that do not reduce in-place
    data_per_proc = idist.all_reduce(data_per_proc)
pprint(rank, f"Now data_per_proc = {data_per_proc}")
# Let's synchronize here
idist.barrier()
# Let's broadcast data from proc 0 to everyone in the world
# See: https://pytorch.org/ignite/distributed.html#ignite.distributed.utils.broadcast
if rank == 0:
data_to_broadcast = torch.tensor([1.2345, 2.3456])
else:
data_to_broadcast = torch.tensor([0.0, 0.0])
pprint(rank, f"Before data is {data_to_broadcast}")
data_to_broadcast = idist.broadcast(data_to_broadcast, src=0)
pprint(rank, f"Now data is {data_to_broadcast}")
# Let's synchronize here
idist.barrier()
idist.finalize()
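ignite.distributed can also take over process spawning itself through the idist.Parallel context manager, which removes the need for torchrun. A minimal sketch under that assumption (run with plain python; the file name example_idist_2.py is made up):

# Run it (hypothetical example, no torchrun needed)
# python example_idist_2.py
import torch
import ignite.distributed as idist


def training(local_rank):
    rank = idist.get_rank()
    ws = idist.get_world_size()
    # Same all_reduce as above, just launched by idist.Parallel
    data = torch.tensor([2.0 * rank + 0.5])
    data = idist.all_reduce(data)
    print(rank, f"reduced data = {data} on {ws} processes")


if __name__ == "__main__":
    # Spawns 4 local worker processes and sets up the "gloo" process group
    with idist.Parallel(backend="gloo", nproc_per_node=4) as parallel:
        parallel.run(training)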