@3outeille
Created January 10, 2025 12:27
import os
import torch
import torch.distributed as dist
import lovely_tensors as lt; lt.monkey_patch()
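
# Overview, assuming the 2-GPU launch used below (torchrun --nproc_per_node=2):
# a linear layer computes Y = X @ W.t(), with W stored as (out_features, in_features).
#   * Column parallelism shards the output features: each rank keeps a row-slice of W,
#     computes its slice of Y locally, and the backward pass all-reduces the partial grad_X.
#   * Row parallelism shards the input features: each rank keeps a column-slice of W (and of X),
#     computes a partial Y, and the forward pass all-reduces the partial sums.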

def split_tensor(data: torch.Tensor, dim: int) -> torch.Tensor:
    """Return this rank's contiguous chunk of `data`, split along `dim` across the world size."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    chunks = torch.chunk(data, world_size, dim=dim)
    return chunks[rank].contiguous()

def column_linear_forward(X, local_W, group):
    # Each rank computes its own slice of the output features; no communication is needed.
    Y_local = X @ local_W.t()
    return Y_local

def column_linear_backward(local_grad_Y, X, local_W, group):
    # https://colossalai.org/docs/features/1D_tensor_parallel/#introduction
    # Each rank only produces a partial grad w.r.t. X (from its shard of W),
    # so the partial products are summed across ranks with an all-reduce.
    local_grad_X = local_grad_Y @ local_W
    dist.all_reduce(local_grad_X, group=group)
    grad_X = local_grad_X
    grad_W = local_grad_Y.t() @ X
    return grad_X, grad_W
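
# Shape check for column_linear_backward, assuming world_size=2 and the shapes used in
# example_column_linear(): X is (4, 2), W is (2, 2), so local_W is (1, 2) per rank.
#   local_grad_Y: (4, 1)                    -> grad of this rank's output slice
#   local_grad_X = local_grad_Y @ local_W:  (4, 2) partial grad; the all-reduce sums the two partials
#   grad_W = local_grad_Y.t() @ X:          (1, 2) grad of this rank's weight shard (no communication)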

def example_column_linear():
    group = dist.distributed_c10d._get_default_group()
    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    X_ref.retain_grad()
    W_ref.retain_grad()
    dist.broadcast(X_ref, src=0, group=group)
    dist.broadcast(W_ref, src=0, group=group)
    X = X_ref.clone()
    W = W_ref.clone()
    # Forward
    Y_ref = X_ref @ W_ref.t()
    # The matmul uses W.t(), so splitting the output features column-wise means
    # splitting the stored W row-wise (dim=0).
    Y_local = column_linear_forward(X, split_tensor(W, dim=0), group)
    torch.testing.assert_close(Y_local, split_tensor(Y_ref, dim=1), rtol=1e-5, atol=1e-5)
    # Backward
    Y_ref.sum().backward()
    local_grad_Y = torch.ones_like(Y_local)
    grad_X, grad_W = column_linear_backward(local_grad_Y, X, split_tensor(W, dim=0), group)
    torch.testing.assert_close(grad_X, X_ref.grad, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W, split_tensor(W_ref.grad, dim=0), rtol=1e-5, atol=1e-5)
    print(f"Rank {dist.get_rank()}: SUCCESS.")

def row_linear_forward(local_X, local_W, group):
    # Each rank computes a partial product over its shard of the input features;
    # the all-reduce sums the partial results into the full output.
    Y_local = local_X @ local_W.t()
    dist.all_reduce(Y_local, group=group)
    Y = Y_local
    return Y

def row_linear_backward_single(grad_Y, X, local_W, group):
    # Variant that rebuilds the full grad_X by all-gathering each rank's shard
    # (and computes the full grad_W from the full, unsharded X).
    local_grad_X = grad_Y @ local_W
    grad_X = [torch.zeros_like(local_grad_X) for _ in range(dist.get_world_size(group))]
    dist.all_gather(grad_X, local_grad_X, group=group)
    grad_X = torch.cat(grad_X, dim=1)
    grad_W = grad_Y.t() @ X
    return grad_X, grad_W

def row_linear_backward(grad_Y, X, local_W, group):
    # Variant that keeps grad_X sharded: when the preceding layer is column-parallel,
    # it only needs the local shard of the gradient anyway.
    local_grad_X = grad_Y @ local_W
    grad_W = grad_Y.t() @ X
    return local_grad_X, grad_W
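
# Shape check for the row-parallel backward, assuming world_size=2 and the shapes used in
# example_row_linear(): X is (4, 2), local_X is (4, 1), local_W is (2, 1) per rank.
#   local_grad_X = grad_Y @ local_W: (4, 1), this rank's shard of grad_X;
#   row_linear_backward_single all-gathers the shards into the full (4, 2) grad_X,
#   while row_linear_backward keeps it sharded for a preceding column-parallel layer.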

def example_row_linear():
    group = dist.distributed_c10d._get_default_group()
    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    X_ref.retain_grad()
    W_ref.retain_grad()
    dist.broadcast(X_ref, src=0, group=group)
    dist.broadcast(W_ref, src=0, group=group)
    X = X_ref.clone()
    W = W_ref.clone()
    Y_ref = X_ref @ W_ref.t()
    Y = row_linear_forward(split_tensor(X, dim=1), split_tensor(W, dim=1), group)
    torch.testing.assert_close(Y, Y_ref, rtol=1e-5, atol=1e-5)
    # Backward
    Y_ref.sum().backward()
    grad_Y = torch.ones_like(Y)
    grad_X, grad_W = row_linear_backward_single(grad_Y, X, split_tensor(W, dim=1), group)
    torch.testing.assert_close(grad_X, X_ref.grad, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W, W_ref.grad, rtol=1e-5, atol=1e-5)
    print(f"Rank {dist.get_rank()}: SUCCESS.")

def example_gelu():
    from torch.nn.functional import gelu
    X = torch.randn(4, 2, device="cuda", dtype=torch.float32)
    W = torch.randn(2, 2, device="cuda", dtype=torch.float32)
    W_0, W_1 = W.chunk(2, dim=1)
    # Column linear: applying GeLU per output shard matches applying it to the full output.
    y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1)
    y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1))
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)
    # Row linear: GeLU is not linear, so applying it before summing the partial
    # products does not match applying it after; this assertion is expected to fail.
    X_0, X_1 = X.chunk(2, dim=1)
    W_0, W_1 = W.chunk(2, dim=0)
    y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1)
    y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1)
    torch.testing.assert_close(y_row_1, y_row_2, rtol=1e-5, atol=1e-5)
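
# The two checks above are why a tensor-parallel MLP puts the column-parallel linear first
# (the activation can then be applied per shard with no communication) and the row-parallel
# linear second; example_column_row_linear() below chains the two layers in that order.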

def example_column_row_linear():
    # torchrun --nproc_per_node=2 tp_all_reduce.py
    group = dist.distributed_c10d._get_default_group()
    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref_layer1 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    W_ref_layer2 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2)
    X_ref.retain_grad()
    W_ref_layer1.retain_grad()
    W_ref_layer2.retain_grad()
    dist.broadcast(X_ref, src=0, group=group)
    dist.broadcast(W_ref_layer1, src=0, group=group)
    dist.broadcast(W_ref_layer2, src=0, group=group)
    X = X_ref.clone()
    W_layer1 = W_ref_layer1.clone()
    W_layer2 = W_ref_layer2.clone()
    # Forward
    Y_ref_linear1 = X_ref @ W_ref_layer1.t()
    Y_ref_linear1.retain_grad()
    # The matmul uses W.t(), so the column-parallel layer splits the stored W row-wise (dim=0).
    Y_local_linear1 = column_linear_forward(X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(Y_local_linear1, split_tensor(Y_ref_linear1, dim=1), rtol=1e-5, atol=1e-5)
    # The row-parallel layer consumes the sharded activations directly (no all-gather in between).
    Y_local_linear2 = row_linear_forward(Y_local_linear1, split_tensor(W_layer2, dim=1), group)
    Y_ref_linear2 = Y_ref_linear1 @ W_ref_layer2.t()
    torch.testing.assert_close(Y_local_linear2, Y_ref_linear2, rtol=1e-5, atol=1e-5)
    # Backward
    Y_ref_linear2.sum().backward()
    grad_Y = torch.ones_like(Y_ref_linear2)
    grad_X_linear2, grad_W_linear2 = row_linear_backward(grad_Y, Y_local_linear1, split_tensor(W_layer2, dim=1), group)
    torch.testing.assert_close(grad_X_linear2, split_tensor(Y_ref_linear1.grad, dim=1), rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W_linear2, split_tensor(W_ref_layer2.grad, dim=1), rtol=1e-5, atol=1e-5)
    grad_X, grad_W = column_linear_backward(grad_X_linear2, X, split_tensor(W_layer1, dim=0), group)
    torch.testing.assert_close(grad_X, X_ref.grad, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(grad_W, split_tensor(W_ref_layer1.grad, dim=0), rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
    dist.init_process_group("nccl", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    # example_column_linear()
    # example_row_linear()
    example_column_row_linear()