Last active
January 20, 2025 09:03
-
-
Save scturtle/c98a7b7a5fa4499a08f1d4bff0cb9daf to your computer and use it in GitHub Desktop.
nccl in 500 LOCs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# https://github.com/FateScript/experiments/blob/main/se/mpi/mpi.py | |
# https://github.com/facebookincubator/gloo/tree/main/gloo | |
import math | |
import multiprocessing | |
import os | |
import numpy as np | |
# Per-process table of peer connections, installed by set_pipes() and read
# back through get_pipes(). It stays None until set_pipes() is called.
_P2P_PIPES = None
def make_p2p_pipe(num_processes: int):
    """Create one duplex pipe for every unordered pair of ranks.

    Returns a dict keyed by ``(owner, peer)``. The value is the connection
    end that ``owner`` uses to exchange messages with ``peer``; the entry
    under ``(peer, owner)`` is the opposite end of the same pipe.
    """
    connections = {}
    for low in range(num_processes):
        for high in range(low + 1, num_processes):
            # Pipe() is duplex by default, so either end can send and recv.
            low_end, high_end = multiprocessing.Pipe()
            connections[(low, high)] = low_end
            connections[(high, low)] = high_end
    return connections
def set_pipes(pipes):
    """Install ``pipes`` as the connection table of the current process."""
    global _P2P_PIPES
    _P2P_PIPES = pipes
def get_pipes():
    """Return the connection table last stored with ``set_pipes``."""
    # Reading a module global needs no ``global`` declaration.
    return _P2P_PIPES
def set_rank(rank):
    """Record this process's rank in the ``RANK`` environment variable."""
    rank_text = str(rank)  # environment values must be strings
    os.environ["RANK"] = rank_text
def get_rank() -> int:
    """Return the rank stored in ``RANK``.

    Raises ``KeyError`` if the variable is not set.
    """
    raw_rank = os.environ["RANK"]
    return int(raw_rank)
def set_world_size(world_size):
    """Record the number of ranks in the ``WORLD_SIZE`` environment variable."""
    size_text = str(world_size)  # environment values must be strings
    os.environ["WORLD_SIZE"] = size_text
def get_world_size() -> int:
    """Return the number of ranks stored in ``WORLD_SIZE``.

    Raises ``KeyError`` if the variable is not set.
    """
    raw_size = os.environ["WORLD_SIZE"]
    return int(raw_size)
def log_rank(*args):
    """Print ``args`` on rank 0 only; every other rank stays silent."""
    if get_rank() != 0:
        return
    print(*args)
def init_env(rank, world_size, shared_mem, pipe_pairs):
    """Set up the per-process communication context (similar to gloo's context).

    Stores the rank and world size in the environment and installs the pipe
    table. ``shared_mem`` is accepted for signature compatibility but is not
    used here.
    """
    set_pipes(pipe_pairs)
    set_world_size(world_size)
    set_rank(rank)
def barrier():
    """Block until every rank has entered the barrier.

    Dissemination barrier: in each round a rank signals the rank ``step``
    places ahead of it and waits for the rank ``step`` places behind it, with
    ``step`` doubling every round.
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/barrier.cc#L18
    me = get_rank()
    n_ranks = get_world_size()
    links = get_pipes()

    step = 1
    while step < n_ranks:
        ahead = (me + step) % n_ranks
        behind = (me - step) % n_ranks
        links[(me, ahead)].send(True)
        links[(me, behind)].recv()
        step *= 2
def gather(data, target_rank: int = 0):
    """Collect ``data`` from every rank on ``target_rank``.

    Non-target ranks send their data and return ``None``. The target rank
    receives from the others in rank order and returns the rank-ordered
    concatenation (its own contribution included).
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/gather.cc#L18
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()

    if rank != target_rank:
        pipes[(rank, target_rank)].send(data)
        return None

    # The comprehension runs in rank order, so receives happen in that order.
    pieces = [
        data if src_rank == rank else pipes[(rank, src_rank)].recv()
        for src_rank in range(world_size)
    ]
    return concat_data(pieces, axis=-1)
def scatter(data, source_rank: int = 0):
    """Split ``data`` on ``source_rank`` and hand one chunk to every rank.

    The data is cut into ``world_size`` contiguous, equally sized chunks and
    chunk ``r`` goes to rank ``r``. This is the same layout ``all_to_all``
    uses and the inverse of ``gather``, so scatter followed by gather returns
    the original ordering.

    On ranks other than ``source_rank`` the ``data`` argument is ignored and
    the received chunk is returned. On the source rank ``len(data)`` must be
    divisible by the world size, otherwise an ``AssertionError`` is raised.
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/scatter.cc#L19
    rank = get_rank()
    pipes = get_pipes()

    if rank != source_rank:
        return pipes[(rank, source_rank)].recv()

    world_size = get_world_size()
    assert len(data) % world_size == 0, "Data size must be divisible by world_size."
    chunk_size = len(data) // world_size

    own_chunk = None
    for dst_rank in range(world_size):
        chunk = data[dst_rank * chunk_size: (dst_rank + 1) * chunk_size]
        if dst_rank == source_rank:
            own_chunk = chunk  # the source keeps its own share locally
        else:
            pipes[(source_rank, dst_rank)].send(chunk)
    return own_chunk
def elementwise_sum(data1, data2):
    """Add two operands element by element.

    NumPy arrays are added with ``+``. Any other iterable is summed pairwise
    (stopping at the shorter operand) and rebuilt as ``type(data1)``.
    """
    if not isinstance(data1, np.ndarray):
        container = type(data1)
        return container([left + right for left, right in zip(data1, data2)])
    return data1 + data2
def elementwise_max(data1, data2):
    """Take the element-by-element maximum of two operands.

    NumPy arrays go through ``np.maximum``. Any other iterable is compared
    pairwise (stopping at the shorter operand) and rebuilt as ``type(data1)``.
    """
    if not isinstance(data1, np.ndarray):
        container = type(data1)
        return container([max(left, right) for left, right in zip(data1, data2)])
    return np.maximum(data1, data2)
def elementwise_div(data, size: int):
    """Divide every element of ``data`` by ``size`` using true division.

    NumPy arrays are divided with ``/``. Any other iterable is divided item
    by item and rebuilt as ``type(data)``.
    """
    if not isinstance(data, np.ndarray):
        container = type(data)
        return container([item / size for item in data])
    return data / size
def concat_data(data_list, axis: int = -1):
    """Join the pieces in ``data_list`` into a single container.

    The type of the first piece selects the strategy: NumPy arrays are joined
    with ``np.concatenate`` along ``axis``. Anything else is flattened one
    level and rebuilt as the first piece's type (``axis`` is ignored).
    """
    first = data_list[0]
    if isinstance(first, np.ndarray):
        return np.concatenate(data_list, axis=axis)

    flattened = []
    for piece in data_list:
        flattened.extend(piece)
    return type(first)(flattened)
# Reduction name -> binary element-wise function. "mean" is deliberately not
# listed here: reduce_scatter implements it as a sum followed by a division.
OP_FUNCS = dict(
    sum=elementwise_sum,
    max=elementwise_max,
)
def op_to_func(op: str):
    """Look up the element-wise function for ``op`` (case-insensitive).

    Raises ``KeyError`` when the lower-cased name is not in ``OP_FUNCS``.
    """
    key = op.lower()
    return OP_FUNCS[key]
def broadcast(data, source_rank: int = 0):
    """Send ``data`` from ``source_rank`` to every rank and return it.

    Doubling schedule with O(log(n)) rounds. Ranks are renumbered so that
    the source is virtual rank 0. In the round where ``min_recv_rank`` ranks
    already hold the data, virtual rank ``v`` (``v < min_recv_rank``) sends
    to virtual rank ``v + min_recv_rank``:

        1st round: v0 -> v1
        2nd round: v0 -> v2, v1 -> v3, and so on.

    When the world size is not a power of two, some senders have no partner
    in the last round. They must skip the send: wrapping the partner index
    modulo ``world_size`` would deliver a message to a rank that never
    receives it and leave stale data in that pipe.

    On ranks other than the source the ``data`` argument is ignored.
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/broadcast.cc#L20
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    v_rank = (rank - source_rank) % world_size

    # Integer doubling gives ceil(log2(world_size)) rounds without float math.
    involved_rank = 1
    while involved_rank < world_size:
        min_recv_rank = involved_rank  # ranks below this already hold the data
        involved_rank <<= 1

        if v_rank >= involved_rank:
            continue  # not involved in this round

        if v_rank < min_recv_rank:
            v_peer_rank = v_rank + min_recv_rank
            if v_peer_rank >= world_size:
                continue  # partner does not exist (non power-of-two world size)
            peer_rank = (v_peer_rank + source_rank) % world_size
            pipes[(rank, peer_rank)].send(data)
        else:
            v_peer_rank = v_rank - min_recv_rank
            peer_rank = (v_peer_rank + source_rank) % world_size
            data = pipes[(rank, peer_rank)].recv()
    return data
def all_to_all(data, axis: int = -1):
    """Exchange one chunk with every other rank.

    ``data`` is viewed as ``world_size`` equal contiguous chunks; chunk ``r``
    is delivered to rank ``r``. The result is the rank-ordered concatenation
    (along ``axis`` for NumPy arrays) of the chunks addressed to this rank.
    Raises ``AssertionError`` if ``len(data)`` is not divisible by the world
    size.
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/alltoall.cc#L18
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    total = len(data)
    assert total % world_size == 0, "Data size must be divisible by world_size."
    width = total // world_size

    def chunk_for(peer):
        return data[peer * width: (peer + 1) * width]

    received = [None] * world_size
    received[rank] = chunk_for(rank)

    # Schedule for 4 ranks (sender -> receiver):
    #   shift 1: 0 -> 1; 1 -> 2; 2 -> 3; 3 -> 0
    #   shift 2: 0 -> 2; 1 -> 3; 2 -> 0; 3 -> 1
    #   shift 3: 0 -> 3; 1 -> 0; 2 -> 1; 3 -> 2
    for shift in range(1, world_size):
        dst = (rank + shift) % world_size
        src = (rank - shift) % world_size
        pipes[(rank, dst)].send(chunk_for(dst))
        received[src] = pipes[(rank, src)].recv()
    return concat_data(received, axis=axis)
def all_gather(data, axis: int = -1):
    """Make every rank's ``data`` available on every rank.

    Ring algorithm: in each of ``world_size - 1`` rounds a rank forwards the
    piece it obtained most recently to its successor and receives a new piece
    from its predecessor. Returns the rank-ordered concatenation of all
    pieces (along ``axis`` for NumPy arrays).
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/allgather.cc#L19
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    successor = (rank + 1) % world_size
    predecessor = (rank - 1) % world_size

    slots = [None] * world_size
    slots[rank] = data
    cursor = rank  # index of the piece to forward next

    # Ring 0 -> 1 -> 2 -> 3 -> 0; *x* marks the piece received in that round.
    #   round 0: 0=[0,*3*]      1=[*0*,1]     2=[*1*,2]     3=[*2*,3]
    #   round 1: 0=[0,*2*,3]    1=[0,1,*3*]   2=[*0*,1,2]   3=[*1*,2,3]
    #   round 2: 0=[0,*1*,2,3]  1=[0,1,*2*,3] 2=[0,1,2,*3*] 3=[*0*,1,2,3]
    for _ in range(1, world_size):
        pipes[(rank, successor)].send(slots[cursor])
        incoming = pipes[(rank, predecessor)].recv()
        cursor = (cursor - 1) % world_size
        slots[cursor] = incoming
    return concat_data(slots, axis=axis)
def reduce_scatter(data, op: str = "sum", slice: bool = False):
    """Reduce ``data`` across ranks so that each rank owns one reduced chunk.

    Ring algorithm over ``world_size`` contiguous chunks of
    ``ceil(len(data) / world_size)`` elements. After ``world_size - 1``
    rounds rank ``r`` holds the fully reduced chunk ``(r + 1) % world_size``;
    the other chunks in the buffer are only partially reduced.

    Args:
        data: buffer supporting slice assignment (list or NumPy array). It
            is updated in place during the ring rounds.
        op: ``"sum"``, ``"max"`` or ``"mean"``, matched case-insensitively.
            ``"mean"`` is a sum followed by a division by the world size.
        slice: when true, one extra ring hop is made and only the reduced
            chunk with index ``rank`` is returned; otherwise the whole
            buffer is returned. (The name shadows the builtin but is kept
            for keyword-argument compatibility.)

    Raises:
        KeyError: if ``op`` is not a known reduction (world size > 1 only).
    """
    # TODO: Halving Doubling https://github.com/facebookincubator/gloo/blob/main/gloo/reduce_scatter.h
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    if world_size == 1:
        # The single rank's chunk is the entire buffer, so both the sliced
        # and the unsliced result are ``data`` itself (not ``[data]``).
        return data

    recv_rank, to_rank = (rank - 1) % world_size, (rank + 1) % world_size
    chunk_size = math.ceil(len(data) / world_size)

    # Normalise once so "Mean"/"MEAN" are handled like "SUM"/"Max".
    op_name = op.lower()
    is_mean = op_name == "mean"
    func = elementwise_sum if is_mean else op_to_func(op_name)

    # Ring 0 -> 1 -> 2 -> 3 -> 0; a_b = chunk b contributed by rank a.
    # Each column shows what that rank's received chunk holds after accumulation.
    #   rank 0               rank 1               rank 2               rank 3
    #   [3_3 + 0_3]          [0_0 + 1_0]          [1_1 + 2_1]          [2_2 + 3_2]
    #   [2_2 + 3_2 + 0_2]    [3_3 + 0_3 + 1_3]    [0_0 + 1_0 + 2_0]    [1_1 + 2_1 + 3_1]
    #   [chunk 1 complete]   [chunk 2 complete]   [chunk 3 complete]   [chunk 0 complete]
    for step in range(world_size - 1):
        send_chunk = (rank - step) % world_size
        send_data = data[send_chunk * chunk_size: (send_chunk + 1) * chunk_size]
        pipes[(rank, to_rank)].send(send_data)

        recv_chunk = (send_chunk - 1) % world_size
        recv_data = pipes[(rank, recv_rank)].recv()
        origin_data = data[recv_chunk * chunk_size: (recv_chunk + 1) * chunk_size]
        data[recv_chunk * chunk_size: (recv_chunk + 1) * chunk_size] = func(recv_data, origin_data)

    if is_mean:
        # Divides the whole buffer; only the fully reduced chunk is meaningful.
        data = elementwise_div(data, world_size)

    if slice:
        # Rank r owns chunk r + 1; pass it on so that rank r returns chunk r.
        send_data = data[to_rank * chunk_size: (to_rank + 1) * chunk_size]
        pipes[(rank, to_rank)].send(send_data)
        return pipes[(rank, recv_rank)].recv()
    return data
def all_gather_(data):
    """In-place ring all-gather over the chunks of ``data``.

    Expects rank ``r`` to hold a valid chunk ``(r + 1) % world_size`` on
    entry. Each round forwards a valid chunk to the successor and overwrites
    another chunk with the one received from the predecessor. ``data`` is
    modified through slice assignment and returned.
    """
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    predecessor = (rank - 1) % world_size
    successor = (rank + 1) % world_size
    width = math.ceil(len(data) / world_size)

    # Chunk index sent by ranks 0..3 in each round:
    #   round 0: [1] [2] [3] [0]
    #   round 1: [0] [1] [2] [3]
    #   round 2: [3] [0] [1] [2]
    for step in range(world_size - 1):
        out_idx = (rank + 1 - step) % world_size
        in_idx = (out_idx - 1) % world_size
        pipes[(rank, successor)].send(data[out_idx * width: (out_idx + 1) * width])
        data[in_idx * width: (in_idx + 1) * width] = pipes[(rank, predecessor)].recv()
    return data
def all_reduce(data, op: str = "sum"):
    """Reduce ``data`` across all ranks and return the result on every rank.

    Ring all-reduce = ring reduce-scatter followed by ring all-gather.
    """
    # code: https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/allreduce.cc#L147
    # There is also a 2D-ring all-reduce algorithm (intra ring + inter ring),
    # similar to the bcube algorithm. Papers:
    # https://arxiv.org/pdf/1807.11205
    # https://arxiv.org/pdf/1811.05233
    partially_reduced = reduce_scatter(data, op=op)
    # Fully reduced chunk per rank after reduce_scatter (a_b = chunk b of rank a):
    #   rank 0: [_, 0_1+1_1+2_1+3_1, _, _]
    #   rank 1: [_, _, 0_2+1_2+2_2+3_2, _]
    #   rank 2: [_, _, _, 0_3+1_3+2_3+3_3]
    #   rank 3: [0_0+1_0+2_0+3_0, _, _, _]
    return all_gather_(partially_reduced)
def gather_(data, target_rank: int = 0):
    """Collect each rank's chunk ``(rank + 1) % world_size`` on ``target_rank``.

    Non-target ranks send that chunk and return ``None``. The target rank
    writes every received chunk into its own ``data`` (in place, via slice
    assignment) and returns the buffer.
    """
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    width = math.ceil(len(data) / world_size)

    def bounds(owner):
        # Slice bounds of the chunk that ``owner`` is responsible for.
        idx = (owner + 1) % world_size
        return idx * width, (idx + 1) * width

    if rank != target_rank:
        start, stop = bounds(rank)
        pipes[(rank, target_rank)].send(data[start:stop])
        return None

    for peer in range(world_size):
        if peer == rank:
            continue
        start, stop = bounds(peer)
        data[start:stop] = pipes[(rank, peer)].recv()
    return data
def reduce(data, target_rank: int = 0, op: str = "sum"):
    """Reduce ``data`` across all ranks onto ``target_rank``.

    reduce = ring reduce-scatter followed by a gather on the root. Returns
    the reduced buffer on ``target_rank`` and ``None`` on every other rank.
    """
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/reduce.cc#L21
    partially_reduced = reduce_scatter(data, op=op)
    # Fully reduced chunk per rank after reduce_scatter (a_b = chunk b of rank a):
    #   rank 0: [_, 0_1+1_1+2_1+3_1, _, _]
    #   rank 1: [_, _, 0_2+1_2+2_2+3_2, _]
    #   rank 2: [_, _, _, 0_3+1_3+2_3+3_3]
    #   rank 3: [0_0+1_0+2_0+3_0, _, _, _]
    return gather_(partially_reduced, target_rank)
def mpi_frame(f, world_size: int = 4):
    """Run ``f`` once per rank in ``world_size`` child processes and wait.

    Each child is started with ``f(rank, world_size, shared_mem, pipe_pairs)``
    where ``pipe_pairs`` is the full connection table from ``make_p2p_pipe``.
    After all children have been joined, the parent's copies of the pipe
    connections are closed so that repeated calls do not leak file
    descriptors.
    """
    shared_mem = multiprocessing.Value("i", 0)
    pipe_pairs = make_p2p_pipe(world_size)
    processes = []
    try:
        for rank in range(world_size):
            p = multiprocessing.Process(
                target=f,
                args=(rank, world_size, shared_mem, pipe_pairs),
            )
            p.start()
            # Track only processes that actually started, so join() is valid.
            processes.append(p)
        for p in processes:
            p.join()
    finally:
        # The parent never communicates over these pipes itself.
        for conn in pipe_pairs.values():
            conn.close()
def test_gather(rank, world_size, signal_queue, pipe_pairs):
    """Demo: each rank contributes four values and rank 0 gathers them."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest gather")
    barrier()

    base = rank * 10
    data = list(range(base, base + 4))
    print(f"Rank {rank} previous data: {data}")
    data = gather(data)
    print(f"Rank {rank} data: {data}")
def test_scatter(rank, world_size, signal_queue, pipe_pairs):
    """Demo: rank 0 scatters one multiple of ten to every rank."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest scatter")
    barrier()

    source_rank = 0
    if rank == source_rank:
        data = list(range(0, 10 * world_size, 10))
    else:
        data = None
    print(f"Rank {rank} previous data: {data}")
    data = scatter(data, source_rank=source_rank)
    print(f"Rank {rank} data: {data}")
def test_broadcast(rank, world_size, signal_queue, pipe_pairs):
    """Demo: rank 1 broadcasts a list of multiples of ten to all ranks."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest broadcast")
    barrier()

    source_rank = 1
    if rank == source_rank:
        data = list(range(0, 10 * world_size, 10))
    else:
        data = None
    print(f"Rank {rank} previous data: {data}")
    data = broadcast(data, source_rank=source_rank)
    print(f"Rank {rank} data: {data}")
def test_all_to_all(rank, world_size, signal_queue, pipe_pairs):
    """Demo: every rank exchanges two-element chunks with every other rank."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest all to all")
    barrier()

    base = rank * 10
    data = list(range(base, base + world_size * 2))
    print(f"Rank {rank} previous data: {data}")
    data = all_to_all(data)
    print(f"Rank {rank} data: {data}")
def test_all_gather(rank, world_size, signal_queue, pipe_pairs):
    """Demo: each rank contributes four values and every rank collects all."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest all gather")
    barrier()

    base = rank * 10
    data = list(range(base, base + 4))
    print(f"Rank {rank} previous data: {data}")
    data = all_gather(data)
    print(f"Rank {rank} data: {data}")
def test_reduce_scatter(rank, world_size, signal_queue, pipe_pairs):
    """Demo: sum-reduce across ranks; each rank prints its own reduced chunk."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest reduce scatter")
    barrier()

    base = rank * 10
    data = list(range(base, base + world_size * 2))
    print(f"Rank {rank} previous data: {data}")
    data = reduce_scatter(data, op="sum", slice=True)
    print(f"Rank {rank} data: {data}")
def test_all_reduce(rank, world_size, signal_queue, pipe_pairs):
    """Demo: mean all-reduce; every rank prints the averaged buffer."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest all reduce")
    barrier()

    base = rank * 10
    data = list(range(base, base + world_size * 2))
    print(f"Rank {rank} previous data: {data}")
    data = all_reduce(data, op="mean")
    print(f"Rank {rank} data: {data}")
def test_reduce(rank, world_size, signal_queue, pipe_pairs):
    """Demo: sum-reduce onto rank 0; the other ranks print ``None``."""
    init_env(rank, world_size, signal_queue, pipe_pairs)
    log_rank("\nTest reduce")
    barrier()

    source_rank = 0
    base = rank * 10
    data = list(range(base, base + world_size * 2))
    print(f"Rank {rank} previous data: {data}")
    data = reduce(data, target_rank=source_rank, op="sum")
    print(f"Rank {rank} data: {data}")
if __name__ == "__main__":
    # Run every demo in turn, each in a fresh group of worker processes.
    world_size = 4
    for demo in (
        test_gather,
        test_scatter,
        test_broadcast,
        test_all_to_all,
        test_all_gather,
        test_reduce_scatter,
        test_all_reduce,
        test_reduce,
    ):
        mpi_frame(demo, world_size=world_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment