@abcdabcd987
Created February 23, 2025 22:50
# PyTorch 2.6 + CUDA 12.6 Segmentation Fault
#
# Good combinations:
#   --no-early-bind --sync=no
#   --no-early-bind --sync=sleep
#   --no-early-bind --sync=barrier
# Bad combinations:
#   --early-bind --sync=no
# Good on NCCL 2.21.5 (official PyTorch wheel) but segfaults on NCCL 2.25.1 (custom built):
#   --early-bind --sync=sleep
#   --early-bind --sync=barrier
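# A possible way to exercise the combinations above (a sketch, not from the
# original gist: the file name repro.py is assumed, and 4 CUDA GPUs must be
# visible):
#   python repro.py --no-early-bind --sync=no    # expected to pass
#   python repro.py --early-bind --sync=no       # expected to segfault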
import argparse
import time
import torch
from torch.distributed import GroupMember, ProcessGroup
def all_gather(pg: ProcessGroup, x: torch.Tensor) -> torch.Tensor:
    output_tensor = torch.empty((pg.size(),) + x.size(), dtype=x.dtype, device=x.device)
    torch.distributed.all_gather_into_tensor(output_tensor, x, group=pg)
    return output_tensor
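# Illustrative note (not in the original gist): with pg.size() == 4 and
# x.shape == (1,), the gathered output has shape (4, 1), one row per rank,
# matching the expected tensors asserted below.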
def worker_main(rank: int, world_size: int, early_bind: bool, sync: str) -> None:
    assert not torch.distributed.is_initialized()
    device = torch.device(rank)
    torch.cuda.set_device(device)
    torch.distributed.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        world_size=world_size,
        rank=rank,
        device_id=device if early_bind else None,
    )
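    # Note (added for clarity, not in the original gist): passing device_id to
    # init_process_group is the "early bind" case. Per the torch.distributed
    # documentation, binding the process group to a specific device lets the
    # NCCL backend eagerly initialize its communicator and, where possible,
    # create subgroups via ncclCommSplit; omitting it (--no-early-bind) keeps
    # the usual lazy initialization.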
print(f"[rank={rank}] Initialized global process group.")
global_group = torch.distributed.new_group(ranks=[0, 1, 2, 3], backend="nccl")
assert isinstance(global_group, ProcessGroup)
print(f"[rank={rank}] Initialized global group.")
x = torch.tensor([rank * 100], dtype=torch.float32, device=device)
global_result = torch.tensor(
[[0], [100], [200], [300]], dtype=torch.float32, device=device
)
torch.testing.assert_close(all_gather(global_group, x), global_result)
print(f"[rank={rank}] all_gather done")
    sub_group0 = torch.distributed.new_group(ranks=[0, 1], backend="nccl")
    sub_group1 = torch.distributed.new_group(ranks=[2, 3], backend="nccl")
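    # Note (added for clarity, not in the original gist): new_group is a
    # collective over the default group, so every rank creates both subgroups
    # here; on ranks that are not members it returns
    # GroupMember.NON_GROUP_MEMBER, which the assertions below check.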
    if rank in [0, 1]:
        assert isinstance(sub_group0, ProcessGroup)
        assert sub_group1 == GroupMember.NON_GROUP_MEMBER
        subgroup = sub_group0
        local_result = torch.tensor([[0], [100]], dtype=torch.float32, device=device)
    else:
        assert sub_group0 == GroupMember.NON_GROUP_MEMBER
        assert isinstance(sub_group1, ProcessGroup)
        subgroup = sub_group1
        local_result = torch.tensor([[200], [300]], dtype=torch.float32, device=device)
    print(f"[rank={rank}] Initialized subgroup.")
    match sync:
        case "no":
            print(f"[rank={rank}] no sync")
        case "sleep":
            time.sleep(10)
            print(f"[rank={rank}] sleep done")
        case "barrier":
            torch.distributed.barrier(group=subgroup, device_ids=[rank])
            print(f"[rank={rank}] subgroup barrier done")
    torch.testing.assert_close(all_gather(global_group, x), global_result)
    print(f"[rank={rank}] global all_gather done")
    torch.testing.assert_close(all_gather(subgroup, x), local_result)
    print(f"[rank={rank}] local all_gather done")
    torch.distributed.destroy_process_group(subgroup)
    torch.distributed.destroy_process_group(global_group)
    torch.distributed.destroy_process_group()
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--early-bind",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--sync",
        choices=["no", "sleep", "barrier"],
        default="no",
    )
    args = parser.parse_args()
    world_size = 4
    torch.multiprocessing.spawn(
        worker_main,
        args=(world_size, args.early_bind, args.sync),
        nprocs=world_size,
        join=True,
    )


if __name__ == "__main__":
    main()