Skip to content

Instantly share code, notes, and snippets.

@youkaichao
Created January 3, 2025 03:05
Show Gist options
  • Save youkaichao/ee42782c91e444851965509236beea62 to your computer and use it in GitHub Desktop.
Compare vLLM's shared-memory (shm) broadcast with PyTorch's broadcast_object_list.
# Benchmark: sender-side latency of torch.distributed.broadcast_object_list.
# Rank 0 broadcasts a tiny Python object; every other rank receives it and
# verifies the payload. (Indentation restored — the original paste lost it,
# which made this script a SyntaxError.)
import torch.distributed as dist
import torch
import time

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
# NCCL needs one CUDA device per process; pin each rank to its own GPU.
torch.cuda.set_device(rank)

N_warmup = 10    # untimed iterations to absorb one-time setup cost
N_measure = 100  # timed iterations

if rank == 0:
    rpc_method = "a"
    for i in range(N_warmup):
        dist.broadcast_object_list([rpc_method], src=0)
    send_latencies = []
    for i in range(N_measure):
        t0 = time.perf_counter_ns()
        dist.broadcast_object_list([rpc_method], src=0)
        t1 = time.perf_counter_ns()
        send_latencies.append(t1 - t0)
    # Lower-middle element of the sorted samples (median for odd counts).
    median = sorted(send_latencies)[len(send_latencies) // 2]
    print(f"Median latency: {median / 1e3:.3f} us")
else:
    # Receivers drain every broadcast (warmup + measured) and sanity-check it.
    for i in range(N_warmup + N_measure):
        recv_data = [None]
        dist.broadcast_object_list(recv_data, src=0)
        rpc_method = recv_data[0]
        assert rpc_method == "a"

# torchrun --nproc-per-node=8 test_pytorch.py
# Median latency: 121.143 us
@youkaichao
Copy link
Author

# Benchmark: sender-side latency of vLLM's shared-memory MessageQueue
# broadcast. Rank 0 enqueues a tiny Python object; all other ranks dequeue
# and sanity-check it.
import torch.distributed as dist
import torch
import time

dist.init_process_group(backend="gloo")

rank = dist.get_rank()

from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
mq = MessageQueue.create_from_process_group(torch.distributed.distributed_c10d._get_default_group(), 1 << 20, 200, 0)

N_warmup = 10
N_measure = 100

if rank == 0:
    rpc_method = "a"
    # Untimed warmup to absorb one-time setup cost.
    for _ in range(N_warmup):
        mq.enqueue([rpc_method])
    timings = []
    for _ in range(N_measure):
        start = time.perf_counter_ns()
        mq.enqueue([rpc_method])
        timings.append(time.perf_counter_ns() - start)
    timings.sort()
    median = timings[len(timings) // 2]
    print(f"Median latency: {median / 1e3:.3f} us")
else:
    # Consumers drain every message (warmup + measured) and verify the payload.
    for _ in range(N_warmup + N_measure):
        payload = mq.dequeue()
        assert payload[0] == "a"

# save as test_shm_broadcast.py
# pip install -U vllm
# torchrun --nproc-per-node=8 test_shm_broadcast.py
# Median latency: 10.678 us

@youkaichao
Copy link
Author

Measure both PyTorch broadcast send and recv latency:

# Benchmark: send- AND receive-side latency of broadcast_object_list.
# Rank 0 embeds its send timestamp in the payload so receivers can compute a
# one-way latency (assumes all ranks share a host clock so perf_counter_ns is
# comparable across processes — holds for single-node torchrun; verify before
# running multi-node).
import torch.distributed as dist
import torch
import time

dist.init_process_group(backend="nccl")

rank = dist.get_rank()
torch.cuda.set_device(rank)

N_warmup = 10
N_measure = 100

if rank == 0:
    rpc_method = "a"
    for _ in range(N_warmup):
        dist.broadcast_object_list([rpc_method, 0], src=0)
    dist.barrier()
    timings = []
    for _ in range(N_measure):
        begin = time.perf_counter_ns()
        dist.broadcast_object_list([rpc_method, begin], src=0)
        timings.append(time.perf_counter_ns() - begin)
        # Barrier keeps iterations from overlapping across ranks.
        dist.barrier()
    timings.sort()
    median = timings[len(timings) // 2]
    print(f"Median send latency: {median / 1e3:.3f} us")
else:
    for _ in range(N_warmup):
        dist.broadcast_object_list([None, None], src=0)
    dist.barrier()
    one_way = []
    for _ in range(N_measure):
        box = [None, None]
        dist.broadcast_object_list(box, src=0)
        arrived = time.perf_counter_ns()
        assert box[0] == "a"
        # One-way latency: local arrival time minus sender's embedded timestamp.
        one_way.append(arrived - box[1])
        dist.barrier()
    one_way.sort()
    median = one_way[len(one_way) // 2]
    print(f"Median recv latency for rank {dist.get_rank()}: {median / 1e3:.3f} us")

# torchrun --nproc-per-node=8 test_pytorch.py
# Median send latency: 121.798 us
# Median recv latency for rank 6: 243.355 us
# Median recv latency for rank 3: 243.287 us
# Median recv latency for rank 7: 256.957 us
# Median recv latency for rank 5: 250.861 us
# Median recv latency for rank 4: 252.541 us
# Median recv latency for rank 1: 244.200 us
# Median recv latency for rank 2: 241.544 us

@youkaichao
Copy link
Author

Measure both shm broadcast send and recv latency:

# Benchmark: send- AND receive-side latency of vLLM's shared-memory
# MessageQueue broadcast. Rank 0 embeds its send timestamp in each message so
# receivers can compute a one-way latency (assumes all ranks share a host
# clock — holds for single-node torchrun).
import torch.distributed as dist
import torch
import time

dist.init_process_group(backend="gloo")

rank = dist.get_rank()
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
mq = MessageQueue.create_from_process_group(torch.distributed.distributed_c10d._get_default_group(), 1 << 20, 200, 0)

N_warmup = 10
N_measure = 100

if rank == 0:
    rpc_method = "a"
    for _ in range(N_warmup):
        mq.enqueue([rpc_method, 0])
    dist.barrier()
    timings = []
    for _ in range(N_measure):
        begin = time.perf_counter_ns()
        mq.enqueue([rpc_method, begin])
        timings.append(time.perf_counter_ns() - begin)
        # Barrier keeps iterations from overlapping across ranks.
        dist.barrier()
    timings.sort()
    median = timings[len(timings) // 2]
    print(f"Median send latency: {median / 1e3:.3f} us")
else:
    for _ in range(N_warmup):
        mq.dequeue()
    dist.barrier()
    one_way = []
    for _ in range(N_measure):
        msg = mq.dequeue()
        arrived = time.perf_counter_ns()
        assert msg[0] == "a"
        # One-way latency: local arrival time minus sender's embedded timestamp.
        one_way.append(arrived - msg[1])
        dist.barrier()
    one_way.sort()
    median = one_way[len(one_way) // 2]
    print(f"Median recv latency for rank {dist.get_rank()}: {median / 1e3:.3f} us")

# save as test_shm_broadcast.py
# pip install -U vllm
# torchrun --nproc-per-node=8 test_shm_broadcast.py
# Median send latency: 10.763 us
# Median recv latency for rank 7: 20.339 us
# Median recv latency for rank 5: 19.841 us
# Median recv latency for rank 3: 20.568 us
# Median recv latency for rank 4: 20.232 us
# Median recv latency for rank 1: 20.146 us
# Median recv latency for rank 6: 20.858 us
# Median recv latency for rank 2: 21.058 us

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment