# PyTorch 2.6 + CUDA 12.6 Segmentation Fault
# Good combinations:
#   --no-early-bind --sync=no
#   --no-early-bind --sync=sleep
#   --no-early-bind --sync=barrier
# Bad combinations:
#   --early-bind --sync=no
# Good on NCCL 2.21.5 (official PyTorch wheel) but segfault on NCCL 2.25.1 (custom built):
#   --early-bind --sync=sleep
#   --early-bind --sync=barrier
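#
# Example usage (the file name below is a placeholder for wherever this gist is saved):
#   python repro_pg_segfault.py --no-early-bind --sync=no      # good
#   python repro_pg_segfault.py --early-bind --sync=no         # segfaults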
import argparse
import time

import torch
from torch.distributed import GroupMember, ProcessGroup


def all_gather(pg: ProcessGroup, x: torch.Tensor) -> torch.Tensor:
    output_tensor = torch.empty((pg.size(),) + x.size(), dtype=x.dtype, device=x.device)
    torch.distributed.all_gather_into_tensor(output_tensor, x, group=pg)
    return output_tensor


def worker_main(rank: int, world_size: int, early_bind: bool, sync: str) -> None:
    assert not torch.distributed.is_initialized()
    device = torch.device(rank)
    torch.cuda.set_device(device)
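    # device_id is only passed when --early-bind is set; per the PyTorch docs, binding
    # a device here allows backend-specific optimizations such as eager communicator
    # initialization, whereas device_id=None defers that work to the first collective.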
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        world_size=world_size,
        rank=rank,
        device_id=device if early_bind else None,
    )
    print(f"[rank={rank}] Initialized global process group.")

    global_group = torch.distributed.new_group(ranks=[0, 1, 2, 3], backend="nccl")
    assert isinstance(global_group, ProcessGroup)
    print(f"[rank={rank}] Initialized global group.")

    x = torch.tensor([rank * 100], dtype=torch.float32, device=device)
    global_result = torch.tensor(
        [[0], [100], [200], [300]], dtype=torch.float32, device=device
    )
    torch.testing.assert_close(all_gather(global_group, x), global_result)
    print(f"[rank={rank}] all_gather done")
    sub_group0 = torch.distributed.new_group(ranks=[0, 1], backend="nccl")
    sub_group1 = torch.distributed.new_group(ranks=[2, 3], backend="nccl")
    if rank in [0, 1]:
        assert isinstance(sub_group0, ProcessGroup)
        assert sub_group1 == GroupMember.NON_GROUP_MEMBER
        subgroup = sub_group0
        local_result = torch.tensor([[0], [100]], dtype=torch.float32, device=device)
    else:
        assert sub_group0 == GroupMember.NON_GROUP_MEMBER
        assert isinstance(sub_group1, ProcessGroup)
        subgroup = sub_group1
        local_result = torch.tensor([[200], [300]], dtype=torch.float32, device=device)
    print(f"[rank={rank}] Initialized subgroup.")
    match sync:
        case "no":
            print(f"[rank={rank}] no sync")
        case "sleep":
            time.sleep(10)
            print(f"[rank={rank}] sleep done")
        case "barrier":
            torch.distributed.barrier(group=subgroup, device_ids=[rank])
            print(f"[rank={rank}] subgroup barrier done")

    torch.testing.assert_close(all_gather(global_group, x), global_result)
    print(f"[rank={rank}] global all_gather done")
    torch.testing.assert_close(all_gather(subgroup, x), local_result)
    print(f"[rank={rank}] local all_gather done")

    torch.distributed.destroy_process_group(subgroup)
    torch.distributed.destroy_process_group(global_group)
    torch.distributed.destroy_process_group()


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--early-bind",
        action=argparse.BooleanOptionalAction,
        type=bool,
        default=True,
    )
    parser.add_argument(
        "--sync",
        choices=["no", "sleep", "barrier"],
        default="no",
    )
    args = parser.parse_args()

    world_size = 4
    torch.multiprocessing.spawn(
        worker_main,
        args=(world_size, args.early_bind, args.sync),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    main()