PyTorch Distributed playground
# Run it
# torchrun --nproc_per_node=4 example_1.py
import os
import time

import torch
import torch.distributed as dist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    backend = "gloo"
    device = "cpu"

    dist.init_process_group(backend)
    if backend == "nccl":
        device = "cuda"
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5], device=device)
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    # Let's broadcast data from proc 0 to everyone in the world
    # See: https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast
    if rank == 0:
        data_to_broadcast = torch.tensor([1.2345, 2.3456], device=device)
    else:
        data_to_broadcast = torch.tensor([0.0, 0.0], device=device)
    pprint(rank, f"Before data is {data_to_broadcast}")
    dist.broadcast(data_to_broadcast, src=0)
    pprint(rank, f"Now data is {data_to_broadcast}")

    # Let's synchronize here
    dist.barrier()

    # Let's use gather: only the destination rank provides the output list
    # See: https://pytorch.org/docs/stable/distributed.html#torch.distributed.gather
    data_to_gather = torch.tensor([1.0 * rank, 2.0 * rank], device=device)
    if rank == 0:
        gather_list = [torch.empty(2, device=device) for _ in range(ws)]
    else:
        gather_list = None
    pprint(rank, f"Before data is {gather_list}")
    dist.gather(data_to_gather, gather_list=gather_list, dst=0)
    pprint(rank, f"Now data is {gather_list}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()
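A note on the gather step in example_1.py: with dist.gather only the destination rank receives the data, which is why the other ranks pass gather_list=None. When every process needs the result, dist.all_gather is the usual alternative. Below is a minimal sketch of that variant (not part of the original gist), assuming the same 4-process gloo/CPU launch; the file name is made up.

# all_gather_sketch.py -- hypothetical companion to example_1.py
# Run it
# torchrun --nproc_per_node=4 all_gather_sketch.py
import torch
import torch.distributed as dist

if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    ws = dist.get_world_size()

    # Every rank contributes one tensor and every rank gets the full list back,
    # so (unlike dist.gather) the output list is allocated on all ranks.
    data = torch.tensor([1.0 * rank, 2.0 * rank])
    output = [torch.empty(2) for _ in range(ws)]
    dist.all_gather(output, data)
    print(rank, output)

    dist.destroy_process_group()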
# Run it
# torchrun --nproc_per_node=4 example_2.py --init_method="file:///tmp/shared"
# or
# torchrun --nproc_per_node=4 example_2.py
import argparse
import os
import time

import torch
import torch.distributed as dist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Dist Example 2")
    parser.add_argument("--init_method", type=str)
    args = parser.parse_args()

    # With an explicit init_method (e.g. a shared file), world_size and rank must be
    # passed to init_process_group, so we read them from the environment set by torchrun.
    # With "env://" (the default below), init_process_group reads them itself.
    ws = int(os.environ["WORLD_SIZE"]) if args.init_method is not None else -1
    r = int(os.environ["RANK"]) if args.init_method is not None else -1
    if args.init_method is None:
        args.init_method = "env://"

    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    dist.init_process_group("gloo", init_method=args.init_method, world_size=ws, rank=r)

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others - init_method={args.init_method}")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()
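The file:// rendezvous in example_2.py can also be driven without torchrun, as long as rank and world size are supplied some other way. A minimal sketch under that assumption (hypothetical --rank/--ws flags, one manually launched process per rank, made-up file path); stale rendezvous files should be removed between runs.

# file_init_sketch.py -- hypothetical standalone variant of example_2.py
# Run it (one process per rank)
# python -u file_init_sketch.py --rank=0 --ws=2 &
# python -u file_init_sketch.py --rank=1 --ws=2 &
import argparse

import torch
import torch.distributed as dist

if __name__ == "__main__":
    parser = argparse.ArgumentParser("File-store init sketch")
    parser.add_argument("--rank", type=int, required=True)
    parser.add_argument("--ws", type=int, required=True)
    args = parser.parse_args()

    # The shared file is the rendezvous point; rank/world_size must be given
    # explicitly because nothing sets WORLD_SIZE/RANK in the environment here.
    dist.init_process_group(
        "gloo",
        init_method="file:///tmp/shared_init_sketch",
        world_size=args.ws,
        rank=args.rank,
    )

    data = torch.tensor([1.0 * args.rank])
    dist.all_reduce(data)
    print(args.rank, data)

    dist.destroy_process_group()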
# Run it
# python -u example_3.py --init_method="file:///tmp/abc" --ws=4
# or
# python -u example_3.py --ws=4
import argparse
import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import start_processes, spawn


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


def func(local_rank, args):
    if args.init_method is None:
        args.init_method = "env://"
        os.environ["WORLD_SIZE"] = str(args.ws)
        os.environ["RANK"] = str(local_rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "2233"
        ws = -1
        r = -1
    else:
        ws = args.ws
        r = local_rank

    # See https://pytorch.org/docs/stable/distributed.html#which-backend-to-use
    dist.init_process_group("gloo", init_method=args.init_method, world_size=ws, rank=r)

    rank = dist.get_rank()
    ws = dist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others - init_method={args.init_method}")
    pprint(rank, f"Group type: {type(dist.get_backend())} : {dist.get_backend()}")

    # Let's synchronize here
    dist.barrier()

    # Let's compute something:
    # all_reduce: see https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_reduce
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    dist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    dist.barrier()

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Dist Example 3")
    parser.add_argument("--init_method", type=str)
    parser.add_argument("--ws", type=int, default=1)
    args = parser.parse_args()

    start_processes(
        func,
        args=(args, ),
        nprocs=args.ws,
        start_method="fork"
    )
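example_3.py imports spawn but launches with start_processes and the fork start method. torch.multiprocessing.spawn is the equivalent launcher with the spawn start method built in. A minimal sketch of just the launcher part, reusing func from example_3.py (assumes example_3.py is importable from the current directory; the file name below is made up):

# spawn_launch_sketch.py -- hypothetical variant of the example_3.py launcher
# Run it
# python -u spawn_launch_sketch.py --ws=4
import argparse

from torch.multiprocessing import spawn

from example_3 import func  # safe to import: example_3.py guards its launcher with __main__

if __name__ == "__main__":
    parser = argparse.ArgumentParser("Torch Dist Example 3 - spawn variant")
    parser.add_argument("--init_method", type=str)
    parser.add_argument("--ws", type=int, default=1)
    args = parser.parse_args()

    # spawn() starts nprocs fresh interpreter processes, calls func(i, *args)
    # in each, and joins them before returning.
    spawn(func, args=(args, ), nprocs=args.ws)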
# Run it
# torchrun --nproc_per_node=4 example_idist_1.py
import os
import time

import torch

import ignite.distributed as idist


def pprint(rank, msg):
    # We add sleep to avoid printing clutter
    time.sleep(0.5 * rank)
    print(rank, msg)


if __name__ == "__main__":
    # https://pytorch.org/ignite/distributed.html#ignite.distributed.utils.initialize
    idist.initialize("gloo")

    rank = idist.get_rank()
    ws = idist.get_world_size()

    pprint(rank, f"Hello from process {rank} among {ws} others")
    pprint(rank, f"Group type: {type(idist.backend())} : {idist.backend()}, {idist.utils._model._backend}")

    # Let's synchronize here
    idist.barrier()

    # Let's compute something:
    # sum all data per proc
    data_per_proc = torch.tensor([2.0 * rank + 0.5])
    pprint(rank, f"data_per_proc = {data_per_proc}")
    idist.all_reduce(data_per_proc)
    pprint(rank, f"Now data_per_proc = {data_per_proc}")

    # Let's synchronize here
    idist.barrier()

    # Let's broadcast data from proc 0 to everyone in the world
    # See: https://pytorch.org/ignite/distributed.html#ignite.distributed.utils.broadcast
    if rank == 0:
        data_to_broadcast = torch.tensor([1.2345, 2.3456])
    else:
        data_to_broadcast = torch.tensor([0.0, 0.0])
    pprint(rank, f"Before data is {data_to_broadcast}")
    data_to_broadcast = idist.broadcast(data_to_broadcast, src=0)
    pprint(rank, f"Now data is {data_to_broadcast}")

    # Let's synchronize here
    idist.barrier()

    idist.finalize()
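example_idist_1.py covers all_reduce and broadcast; the gather step from example_1.py has an ignite counterpart as well. A minimal sketch (not part of the original gist, file name made up), assuming the same 4-process torchrun launch; note that idist.all_gather returns the gathered result on every rank rather than filling a list on a single destination rank.

# idist_all_gather_sketch.py -- hypothetical companion to example_idist_1.py
# Run it
# torchrun --nproc_per_node=4 idist_all_gather_sketch.py
import torch

import ignite.distributed as idist

if __name__ == "__main__":
    idist.initialize("gloo")
    rank = idist.get_rank()

    # Each rank contributes a small tensor; every rank receives the
    # concatenation of all contributions.
    data = torch.tensor([1.0 * rank, 2.0 * rank])
    gathered = idist.all_gather(data)
    print(rank, gathered)

    idist.finalize()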