nccl in 500 LOCs
#!/usr/bin/env python3
# References:
# https://github.com/FateScript/experiments/blob/main/se/mpi/mpi.py
# https://github.com/facebookincubator/gloo/tree/main/gloo
import math
import multiprocessing
import os
import numpy as np
_P2P_PIPES = None
def make_p2p_pipe(num_processes: int):
    # Build a full mesh of duplex pipes: pipe_pairs[(i, j)] is the endpoint
    # rank i uses to talk to rank j (and (j, i) is the other end).
    pipe_pairs = {}
    for i in range(num_processes):
        for j in range(i + 1, num_processes):
            (src, dst) = multiprocessing.Pipe()  # duplex connection
            pipe_pairs[(i, j)] = src
            pipe_pairs[(j, i)] = dst
    return pipe_pairs
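# For illustration: with num_processes = 3 the mesh holds one duplex pipe per
# unordered pair, exposed under both ordered keys:
#   (0, 1)/(1, 0), (0, 2)/(2, 0), (1, 2)/(2, 1)
# so rank r always addresses peer p as pipes[(r, p)], no matter which end of
# the underlying Pipe() it happens to hold.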
def set_pipes(pipes):
    global _P2P_PIPES
    _P2P_PIPES = pipes
def get_pipes():
    return _P2P_PIPES
def set_rank(rank):
    os.environ["RANK"] = str(rank)
def get_rank() -> int:
    return int(os.environ["RANK"])
def set_world_size(world_size):
    os.environ["WORLD_SIZE"] = str(world_size)
def get_world_size() -> int:
    return int(os.environ["WORLD_SIZE"])
def log_rank(*args):
    # Print only on rank 0 to avoid interleaved output from every rank.
    if get_rank() == 0:
        print(*args)
def init_env(rank, world_size, shared_mem, pipe_pairs):
    # Similar to the context object in gloo. `shared_mem` is currently unused.
    set_rank(rank)
    set_world_size(world_size)
    set_pipes(pipe_pairs)
def barrier():
    # Dissemination barrier: O(log(n)) rounds; in round k each rank signals
    # the rank 2^k ahead of it and waits on the rank 2^k behind it.
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/barrier.cc#L18
    rank = get_rank()
    world_size = get_world_size()
    pipes = get_pipes()
    dist = 1
    while dist < world_size:
        to_rank = (rank + dist) % world_size
        pipes[(rank, to_rank)].send(True)
        from_rank = (rank - dist) % world_size
        pipes[(rank, from_rank)].recv()
        dist <<= 1
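# For illustration, with world_size = 4 the barrier takes two rounds:
#   dist = 1: 0->1, 1->2, 2->3, 3->0
#   dist = 2: 0->2, 1->3, 2->0, 3->1
# After ceil(log2(4)) = 2 rounds every rank has (transitively) heard from
# every other rank, so no rank can exit before all have arrived.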
def gather(data, target_rank: int = 0):
    # Collects data from all ranks onto a single rank.
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/gather.cc#L18
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    if rank != target_rank:
        pipes[(rank, target_rank)].send(data)
        return None
    gathered_data = []
    for src_rank in range(world_size):
        if src_rank == rank:
            gathered_data.append(data)
            continue
        recv_data = pipes[(rank, src_rank)].recv()
        gathered_data.append(recv_data)
    return concat_data(gathered_data, axis=-1)
def scatter(data, source_rank: int = 0):
    # Distributes slices of data from one rank to all ranks.
    # Note: the split is strided (round-robin), not contiguous chunks:
    # rank i receives data[i::world_size].
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/scatter.cc#L19
    rank = get_rank()
    pipe_pairs = get_pipes()
    if rank == source_rank:
        world_size = get_world_size()
        assert len(data) % world_size == 0
        scatter_data = [data[i::world_size] for i in range(world_size)]
        for dst_rank, dst_data in enumerate(scatter_data):
            if dst_rank == source_rank:
                data = dst_data
                continue
            pipe_pairs[(source_rank, dst_rank)].send(dst_data)
        return data
    else:
        return pipe_pairs[(rank, source_rank)].recv()
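# For illustration: with world_size = 4 and data = [0, 10, 20, ..., 70] on the
# source rank, the strided split yields
#   rank 0: [0, 40]   rank 1: [10, 50]   rank 2: [20, 60]   rank 3: [30, 70]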
def elementwise_sum(data1, data2):
    if isinstance(data1, np.ndarray):
        return data1 + data2
    else:  # iterable
        return type(data1)([x + y for x, y in zip(data1, data2)])
def elementwise_max(data1, data2):
    if isinstance(data1, np.ndarray):
        return np.maximum(data1, data2)
    else:  # iterable
        return type(data1)([max(x, y) for x, y in zip(data1, data2)])
def elementwise_div(data, size: int):
    if isinstance(data, np.ndarray):
        return data / size
    else:  # iterable
        return type(data)([x / size for x in data])
def concat_data(data_list, axis: int = -1):
    if isinstance(data_list[0], np.ndarray):
        return np.concatenate(data_list, axis=axis)
    else:  # iterable
        return type(data_list[0])([x for data in data_list for x in data])
OP_FUNCS = {
    "sum": elementwise_sum,
    "max": elementwise_max,
}
def op_to_func(op: str):
    return OP_FUNCS[op.lower()]
def broadcast(data, source_rank: int = 0):
    # Sends data from one rank to all other ranks.
    # Iterative doubling: the set of ranks holding the data doubles each
    # turn, so broadcast finishes in O(log(n)) turns.
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/broadcast.cc#L20
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    v_rank = (rank - source_rank) % world_size  # virtual rank: source becomes 0
    involved_rank = 1
    for turn in range(math.ceil(math.log2(world_size))):
        # 1st turn: rank0 -> rank1
        # 2nd turn: rank0 -> rank2, rank1 -> rank3, and so on.
        min_recv_rank = involved_rank  # 2 ** turn
        involved_rank <<= 1
        if v_rank >= involved_rank:
            continue  # not involved in this turn
        is_send_rank = v_rank < min_recv_rank
        v_peer_rank = (v_rank + min_recv_rank) % involved_rank
        peer_rank = (v_peer_rank + source_rank) % world_size
        if is_send_rank:
            pipes[(rank, peer_rank)].send(data)
        else:
            data = pipes[(rank, peer_rank)].recv()
    return data
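# For illustration, with world_size = 4 and source_rank = 0:
#   turn 0: 0 -> 1            (ranks {0, 1} now hold the data)
#   turn 1: 0 -> 2, 1 -> 3    (all four ranks hold the data)
# With source_rank = s, the same pattern runs on the virtual ranks
# (rank - s) % world_size.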
def all_to_all(data, axis: int = -1):
    # Every rank sends one chunk to and receives one chunk from every other rank.
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/alltoall.cc#L18
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    data_size = len(data)
    assert data_size % world_size == 0, "Data size must be divisible by world_size."
    chunk_size = data_size // world_size
    data_list = [None for _ in range(world_size)]
    data_list[rank] = data[rank * chunk_size: (rank + 1) * chunk_size]
    # turn 1: 0 -> 1; 1 -> 2; 2 -> 3; 3 -> 0
    # turn 2: 0 -> 2; 1 -> 3; 2 -> 0; 3 -> 1
    # turn 3: 0 -> 3; 1 -> 0; 2 -> 1; 3 -> 2
    for turn in range(1, world_size):
        send_rank = (rank + turn) % world_size
        recv_rank = (rank - turn) % world_size
        send_data = data[send_rank * chunk_size: (send_rank + 1) * chunk_size]
        pipes[(rank, send_rank)].send(send_data)
        recv_data = pipes[(rank, recv_rank)].recv()
        data_list[recv_rank] = recv_data
    return concat_data(data_list, axis=axis)
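# For illustration, with world_size = 4 and chunk_size = 1, writing r_c for
# rank r's chunk c: rank r starts with [r_0, r_1, r_2, r_3] and ends with
# [0_r, 1_r, 2_r, 3_r], i.e. the result is the transpose of the per-rank
# chunk matrix.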
def all_gather(data, axis: int = -1):
    # Combines data from all ranks and makes the full result available to
    # every rank. Semantically all_gather = gather + broadcast; implemented
    # here as a ring: each turn, forward the block received last turn.
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/allgather.cc#L19
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size
    data_list = [None for _ in range(world_size)]
    data_list[rank] = data
    send_idx = rank
    # 0 -> 1; 1 -> 2; 2 -> 3; 3 -> 0 (blocks just received are starred)
    # turn 0: 0=[0,*3*] 1=[*0*,1] 2=[*1*,2] 3=[*2*,3]
    # turn 1: 0=[0,*2*,3] 1=[0,1,*3*] 2=[*0*,1,2] 3=[*1*,2,3]
    # turn 2: 0=[0,*1*,2,3] 1=[0,1,*2*,3] 2=[0,1,2,*3*] 3=[*0*,1,2,3]
    for _ in range(world_size - 1):
        pipes[(rank, send_rank)].send(data_list[send_idx])
        recv_data = pipes[(rank, recv_rank)].recv()
        send_idx = (send_idx - 1) % world_size
        data_list[send_idx] = recv_data
    return concat_data(data_list, axis=axis)
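# Cost note: each rank sends and receives world_size - 1 blocks, so the ring
# moves (n - 1)/n of the final result across every link exactly once, which
# is bandwidth-optimal for large payloads.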
def reduce_scatter(data, op: str = "sum", slice: bool = False):
    # Reduces data across ranks with the given op and scatters the reduced
    # chunks among ranks (ring algorithm).
    # TODO: Halving Doubling https://github.com/facebookincubator/gloo/blob/main/gloo/reduce_scatter.h
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    if world_size == 1:
        return [data] if slice else data
    recv_rank, to_rank = (rank - 1) % world_size, (rank + 1) % world_size
    chunk_size = math.ceil(len(data) / world_size)
    # "mean" is computed as a sum followed by division by world_size.
    func = elementwise_sum if op == "mean" else op_to_func(op)
    # Notation: r_c is rank r's local chunk c. Ranks send along the ring
    # 0 -> 1, 1 -> 2, 2 -> 3, 3 -> 0; each row shows what each rank receives.
    # recv [3_3]               [0_0]               [1_1]               [2_2]
    # recv [2_2 + 3_2]         [0_3 + 3_3]         [0_0 + 1_0]         [1_1 + 2_1]
    # recv [1_1 + 2_1 + 3_1]   [0_2 + 2_2 + 3_2]   [1_3 + 0_3 + 3_3]   [0_0 + 1_0 + 2_0]
    for round in range(world_size - 1):
        send_chunk = (rank - round) % world_size
        send_data = data[send_chunk * chunk_size: (send_chunk + 1) * chunk_size]
        pipes[(rank, to_rank)].send(send_data)
        recv_chunk = (send_chunk - 1) % world_size
        recv_data = pipes[(rank, recv_rank)].recv()
        origin_data = data[recv_chunk * chunk_size: (recv_chunk + 1) * chunk_size]
        data[recv_chunk * chunk_size: (recv_chunk + 1) * chunk_size] = func(recv_data, origin_data)
    if op == "mean":
        data = elementwise_div(data, world_size)
    if slice:
        # Exchange once more so that rank r ends up holding chunk r:
        # recv [0_0 + 1_0 + 2_0 + 3_0] [0_1 + 1_1 + 2_1 + 3_1] [0_2 + 1_2 + 2_2 + 3_2] [0_3 + 1_3 + 2_3 + 3_3]
        send_data = data[to_rank * chunk_size: (to_rank + 1) * chunk_size]
        pipes[(rank, to_rank)].send(send_data)
        return pipes[(rank, recv_rank)].recv()
    return data
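# After world_size - 1 rounds, rank r holds the fully reduced chunk
# (r + 1) % world_size; the optional slice=True exchange shifts chunks one
# more step along the ring so that rank r holds chunk r. For illustration,
# with world_size = 4, op = "sum", and rank r holding
# [r * 10 + i for i in range(8)] (chunk_size = 2), slice=True returns
# [60 + 8 * r, 64 + 8 * r] on rank r, since element i sums to 60 + 4 * i.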
def all_gather_(data):
    # In-place ring all-gather over the chunk layout produced by
    # reduce_scatter (rank r starts out owning chunk (r + 1) % world_size).
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    recv_rank, to_rank = (rank - 1) % world_size, (rank + 1) % world_size
    chunk_size = math.ceil(len(data) / world_size)
    # chunk sent by each rank per round:
    # send [1] [2] [3] [0]
    # send [0] [1] [2] [3]
    # send [3] [0] [1] [2]
    for round in range(world_size - 1):
        send_chunk = (rank + 1 - round) % world_size
        send_data = data[send_chunk * chunk_size: (send_chunk + 1) * chunk_size]
        pipes[(rank, to_rank)].send(send_data)
        recv_chunk = (send_chunk - 1) % world_size
        recv_data = pipes[(rank, recv_rank)].recv()
        data[recv_chunk * chunk_size: (recv_chunk + 1) * chunk_size] = recv_data
    return data
def all_reduce(data, op: str = "sum"):
    # ring all-reduce = ring reduce-scatter + ring all-gather
    # code: https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/allreduce.cc#L147
    # There is also a 2D-ring all-reduce (intra-node ring + inter-node ring),
    # like the bcube algorithm; see these papers:
    # https://arxiv.org/pdf/1807.11205
    # https://arxiv.org/pdf/1811.05233
    data = reduce_scatter(data, op=op)
    # After reduce_scatter, rank r owns the fully reduced chunk (r + 1) % world_size:
    # 0 [_, 0_1+1_1+2_1+3_1, _, _]
    # 1 [_, _, 0_2+1_2+2_2+3_2, _]
    # 2 [_, _, _, 0_3+1_3+2_3+3_3]
    # 3 [0_0+1_0+2_0+3_0, _, _, _]
    data = all_gather_(data)
    return data
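# Cost note: per rank, the ring moves 2 * (n - 1)/n of the data in total
# (n - 1 reduce-scatter rounds plus n - 1 all-gather rounds of one chunk
# each), approaching 2x the payload regardless of n, which is why ring
# all-reduce scales well in n for large payloads.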
def gather_(data, target_rank: int = 0):
    # Gathers the reduced chunks (in the reduce_scatter layout, rank r owns
    # chunk (r + 1) % world_size) onto target_rank.
    rank, world_size, pipes = get_rank(), get_world_size(), get_pipes()
    chunk_size = math.ceil(len(data) / world_size)
    chunk_idx = (rank + 1) % world_size
    if rank != target_rank:
        send_data = data[chunk_idx * chunk_size: (chunk_idx + 1) * chunk_size]
        pipes[(rank, target_rank)].send(send_data)
        return None
    for from_rank in range(world_size):
        if from_rank == rank:
            continue
        from_chunk_idx = (from_rank + 1) % world_size
        recv_data = pipes[(rank, from_rank)].recv()
        data[from_chunk_idx * chunk_size: (from_chunk_idx + 1) * chunk_size] = recv_data
    return data
def reduce(data, target_rank: int = 0, op: str = "sum"):
    # reduce = reduce_scatter + gather onto the root
    # https://github.com/facebookincubator/gloo/blob/81925d1c674c34f0dc34dd9a0f2151c1b6f701eb/gloo/reduce.cc#L21
    data = reduce_scatter(data, op=op)
    # 0 [_, 0_1+1_1+2_1+3_1, _, _]
    # 1 [_, _, 0_2+1_2+2_2+3_2, _]
    # 2 [_, _, _, 0_3+1_3+2_3+3_3]
    # 3 [0_0+1_0+2_0+3_0, _, _, _]
    return gather_(data, target_rank)
def mpi_frame(f, world_size: int = 4):
    shared_mem = multiprocessing.Value("i", 0)  # reserved, currently unused
    pipe_pairs = make_p2p_pipe(world_size)
    processes = []
    for rank in range(world_size):
        p = multiprocessing.Process(
            target=f,
            args=(rank, world_size, shared_mem, pipe_pairs),
        )
        processes.append(p)
        p.start()
    for p in processes:
        p.join()
def test_gather(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest gather")
    barrier()
    data = [rank * 10 + i for i in range(4)]
    print(f"Rank {rank} previous data: {data}")
    data = gather(data)
    print(f"Rank {rank} data: {data}")
def test_scatter(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest scatter")
    barrier()
    source_rank = 0
    data = [10 * x for x in range(world_size)] if rank == source_rank else None
    print(f"Rank {rank} previous data: {data}")
    data = scatter(data, source_rank=source_rank)
    print(f"Rank {rank} data: {data}")
def test_broadcast(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest broadcast")
    barrier()
    source_rank = 1
    data = [10 * x for x in range(world_size)] if rank == source_rank else None
    print(f"Rank {rank} previous data: {data}")
    data = broadcast(data, source_rank=source_rank)
    print(f"Rank {rank} data: {data}")
def test_all_to_all(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest all to all")
    barrier()
    data = [rank * 10 + x for x in range(world_size * 2)]
    print(f"Rank {rank} previous data: {data}")
    data = all_to_all(data)
    print(f"Rank {rank} data: {data}")
def test_all_gather(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest all gather")
    barrier()
    data = [rank * 10 + i for i in range(4)]
    print(f"Rank {rank} previous data: {data}")
    data = all_gather(data)
    print(f"Rank {rank} data: {data}")
def test_reduce_scatter(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest reduce scatter")
    barrier()
    data = [rank * 10 + x for x in range(world_size * 2)]
    print(f"Rank {rank} previous data: {data}")
    data = reduce_scatter(data, op="sum", slice=True)
    print(f"Rank {rank} data: {data}")
def test_all_reduce(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest all reduce")
    barrier()
    data = [rank * 10 + x for x in range(world_size * 2)]
    print(f"Rank {rank} previous data: {data}")
    data = all_reduce(data, op="mean")
    print(f"Rank {rank} data: {data}")
def test_reduce(rank, world_size, shared_mem, pipe_pairs):
    init_env(rank, world_size, shared_mem, pipe_pairs)
    log_rank("\nTest reduce")
    barrier()
    source_rank = 0
    data = [rank * 10 + x for x in range(world_size * 2)]
    print(f"Rank {rank} previous data: {data}")
    data = reduce(data, target_rank=source_rank, op="sum")
    print(f"Rank {rank} data: {data}")
if __name__ == "__main__":
    world_size = 4
    mpi_frame(test_gather, world_size=world_size)
    mpi_frame(test_scatter, world_size=world_size)
    mpi_frame(test_broadcast, world_size=world_size)
    mpi_frame(test_all_to_all, world_size=world_size)
    mpi_frame(test_all_gather, world_size=world_size)
    mpi_frame(test_reduce_scatter, world_size=world_size)
    mpi_frame(test_all_reduce, world_size=world_size)
    mpi_frame(test_reduce, world_size=world_size)