Created
January 10, 2025 12:27
-
-
Save 3outeille/ff624bd02debaf795fd109d57ac59d31 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.distributed as dist | |
import lovely_tensors as lt; lt.monkey_patch() | |
def split_tensor(data: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    """Return this rank's contiguous shard of ``data`` along ``dim``.

    The tensor is cut with ``torch.chunk`` into ``world_size`` pieces and the
    piece at index ``rank`` is returned.

    Args:
        data: tensor to shard (every rank is expected to hold the same values).
        dim: dimension to split along.
        group: process group whose rank / world size define the sharding.
            ``None`` (the default) means the default process group, which
            matches the previous behavior.

    Returns:
        A contiguous copy/view of this rank's chunk.

    Raises:
        ValueError: if this process is not a member of ``group``, or if
            ``data.size(dim)`` is too small for ``torch.chunk`` to give this
            rank a chunk.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)
    # get_rank returns -1 for processes outside ``group``; indexing with -1
    # would silently hand back the last shard.
    if rank < 0:
        raise ValueError("current process is not a member of the given process group")
    chunks = torch.chunk(data, world_size, dim=dim)
    # torch.chunk may return fewer than ``world_size`` chunks (e.g. size 5 split
    # into 4 yields 3 chunks of size 2, 2, 1), leaving the higher ranks without a shard.
    if rank >= len(chunks):
        raise ValueError(
            f"cannot shard dim {dim} of size {data.size(dim)} across {world_size} ranks: "
            f"torch.chunk produced only {len(chunks)} chunks"
        )
    return chunks[rank].contiguous()
def column_linear_forward(X, local_W, group):
    """Forward pass of a column-parallel linear layer on one rank.

    Computes ``X @ local_W.t()``, i.e. this rank's slice of the output
    columns. No communication happens here, so ``group`` is accepted only for
    signature symmetry with the other TP helpers and is not used.
    """
    return torch.matmul(X, local_W.t())
def column_linear_backward(local_grad_Y, X, local_W, group):
    """Backward pass of a column-parallel linear layer on one rank.

    Reference: https://colossalai.org/docs/features/1D_tensor_parallel/#introduction

    Each rank's ``local_grad_Y @ local_W`` is only a partial contribution to the
    input gradient, because X is used by every rank. These partials are summed
    in place with an all-reduce over ``group``. The weight gradient needs no
    communication.

    Returns:
        ``(grad_X, grad_W)``, where ``grad_X`` is all-reduced and ``grad_W`` is
        the gradient of this rank's weight shard.
    """
    partial_input_grad = torch.matmul(local_grad_Y, local_W)
    dist.all_reduce(partial_input_grad, group=group)
    shard_weight_grad = torch.matmul(local_grad_Y.t(), X)
    return partial_input_grad, shard_weight_grad
def example_column_linear():
    """Check column-parallel linear forward/backward against a single-device reference.

    The layer computes ``X @ W.t()``, so W is sharded along dim 0 (its output
    rows), and each rank's output must equal the matching column slice of the
    reference output. Requires an initialized process group and CUDA tensors.
    """
    pg = dist.distributed_c10d._get_default_group()
    tol = {"rtol": 1e-5, "atol": 1e-5}

    x_full = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    w_full = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    for ref in (x_full, w_full):
        ref.retain_grad()
    # Make sure every rank starts from rank 0's values.
    for ref in (x_full, w_full):
        dist.broadcast(ref, src=0, group=pg)

    x_copy = x_full.clone()
    w_copy = w_full.clone()

    # Forward: local output shard vs. column slice of the reference output.
    y_full = x_full @ w_full.t()
    w_shard = split_tensor(w_copy, dim=0)
    y_shard = column_linear_forward(x_copy, w_shard, pg)
    torch.testing.assert_close(y_shard, split_tensor(y_full, dim=1), **tol)

    # Backward: autograd on the reference vs. the manual TP gradients.
    y_full.sum().backward()
    dx, dw_shard = column_linear_backward(torch.ones_like(y_shard), x_copy, w_shard, pg)
    torch.testing.assert_close(dx, x_full.grad, **tol)
    torch.testing.assert_close(dw_shard, split_tensor(w_full.grad, dim=0), **tol)
    print(f"Rank {dist.get_rank()}: SUCCESS.")
def row_linear_forward(local_X, local_W, group):
    """Forward pass of a row-parallel linear layer.

    Each rank multiplies its input-column shard by its weight shard, giving a
    partial sum of the full output. An in-place all-reduce over ``group``
    turns that partial sum into the complete output on every rank.
    """
    partial = torch.matmul(local_X, local_W.t())
    dist.all_reduce(partial, group=group)
    return partial
def row_linear_backward_single(grad_Y, X, local_W, group):
    """Backward pass of a row-parallel linear layer that reassembles the full input gradient.

    Each rank computes the gradient for its slice of the input columns
    (``grad_Y @ local_W``). These slices are all-gathered over ``group`` and
    concatenated along dim 1 in rank order.

    Args:
        grad_Y: gradient of the all-reduced output Y.
        X: input used for the weight gradient ``grad_Y.t() @ X``. Passing the
            full, unsharded X (as ``example_row_linear`` does) yields the
            gradient of the full weight rather than of this rank's shard.
        local_W: this rank's weight shard.
        group: process group used for the all-gather.

    Returns:
        ``(grad_X, grad_W)``, with ``grad_X`` gathered from all ranks.
    """
    weight_grad = torch.matmul(grad_Y.t(), X)
    shard_grad = torch.matmul(grad_Y, local_W)
    pieces = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size(group))]
    dist.all_gather(pieces, shard_grad, group=group)
    return torch.cat(pieces, dim=1), weight_grad
def row_linear_backward(grad_Y, X, local_W, group):
    """Backward pass of a row-parallel linear layer, keeping gradients sharded.

    Returns ``(grad_Y @ local_W, grad_Y.t() @ X)``: the gradient for this
    rank's input slice and the gradient of the weight computed against X as
    passed. No communication is performed, so ``group`` is unused.
    """
    input_shard_grad = torch.matmul(grad_Y, local_W)
    weight_grad = torch.matmul(grad_Y.t(), X)
    return input_shard_grad, weight_grad
def example_row_linear():
    """Check row-parallel linear forward/backward against a single-device reference.

    X is sharded along its columns and W along dim 1 (input features), so
    the all-reduced product must equal the full reference output. The
    backward uses ``row_linear_backward_single`` with the full X, so both
    gradients are compared against the full reference gradients. Requires an
    initialized process group and CUDA tensors.
    """
    pg = dist.distributed_c10d._get_default_group()
    tol = {"rtol": 1e-5, "atol": 1e-5}

    x_full = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    w_full = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    for ref in (x_full, w_full):
        ref.retain_grad()
    # Make sure every rank starts from rank 0's values.
    for ref in (x_full, w_full):
        dist.broadcast(ref, src=0, group=pg)

    x_copy = x_full.clone()
    w_copy = w_full.clone()

    # Forward: all-reduced partial sums vs. the full reference output.
    y_full = x_full @ w_full.t()
    x_shard = split_tensor(x_copy, dim=1)
    w_shard = split_tensor(w_copy, dim=1)
    y = row_linear_forward(x_shard, w_shard, pg)
    torch.testing.assert_close(y, y_full, **tol)

    # Backward: autograd on the reference vs. the manual TP gradients.
    y_full.sum().backward()
    dx, dw = row_linear_backward_single(torch.ones_like(y), x_copy, w_shard, pg)
    torch.testing.assert_close(dx, x_full.grad, **tol)
    torch.testing.assert_close(dw, w_full.grad, **tol)
    print(f"Rank {dist.get_rank()}: SUCCESS.")
def example_gelu(device="cuda"):
    """Show which tensor-parallel layouts commute with an element-wise GeLU.

    Column-parallel: each shard produces whole output columns, so applying
    GeLU per shard and concatenating equals GeLU of the concatenation. The
    activation can therefore run before any communication.

    Row-parallel: each shard produces a partial sum of the full output, and
    GeLU is non-linear, so GeLU(a) + GeLU(b) != GeLU(a + b) in general. The
    partial sums must be combined (all-reduced) before the activation.

    Args:
        device: device for the random tensors. Defaults to "cuda", as before;
            pass "cpu" to run without a GPU.

    Raises:
        AssertionError: if the column-parallel results disagree, or if the
            row-parallel results unexpectedly agree.
    """
    from torch.nn.functional import gelu

    X = torch.randn(4, 2, device=device, dtype=torch.float32)
    W = torch.randn(2, 2, device=device, dtype=torch.float32)

    # Column linear: split W by output columns -> GeLU commutes with concatenation.
    W_0, W_1 = W.chunk(2, dim=1)
    y_col_1 = torch.cat([gelu(X @ W_0), gelu(X @ W_1)], dim=1)
    y_col_2 = gelu(torch.cat([X @ W_0, X @ W_1], dim=1))
    torch.testing.assert_close(y_col_1, y_col_2, rtol=1e-5, atol=1e-5)

    # Row linear: split X by columns and W by rows -> each product is a partial sum.
    X_0, X_1 = X.chunk(2, dim=1)
    W_0, W_1 = W.chunk(2, dim=0)
    y_row_1 = gelu(X_0 @ W_0) + gelu(X_1 @ W_1)
    y_row_2 = gelu(X_0 @ W_0 + X_1 @ W_1)
    # GeLU does not distribute over the sum, so these are expected to differ.
    # Use an explicit raise so the check survives ``python -O``.
    if torch.allclose(y_row_1, y_row_2, rtol=1e-5, atol=1e-5):
        raise AssertionError("GeLU unexpectedly commuted with the row-parallel partial sum")
def example_column_row_linear():
    """Chain a column-parallel linear into a row-parallel linear and check both passes.

    Layer 1 shards W1 along dim 0, so each rank holds a column slice of Y1.
    That slice is exactly the input shard that layer 2 (W2 sharded along dim 1)
    needs. The forward therefore communicates only in ``row_linear_forward``
    (one all-reduce), and the backward only in ``column_linear_backward``
    (one all-reduce). Results are compared against autograd on unsharded
    reference tensors.

    Requires an initialized process group and CUDA tensors.
    """
    # torchrun --nproc_per_node=2 tp_all_reduce.py
    group = dist.distributed_c10d._get_default_group()
    tol = {"rtol": 1e-5, "atol": 1e-5}

    X_ref = torch.arange(4 * 2, device="cuda", dtype=torch.float32, requires_grad=True).reshape(4, 2)
    W_ref_layer1 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2) * 10
    W_ref_layer2 = torch.arange(1, 5, device="cuda", dtype=torch.float32, requires_grad=True).reshape(2, 2)
    X_ref.retain_grad()
    W_ref_layer1.retain_grad()
    W_ref_layer2.retain_grad()
    dist.broadcast(X_ref, src=0, group=group)
    dist.broadcast(W_ref_layer1, src=0, group=group)
    dist.broadcast(W_ref_layer2, src=0, group=group)

    # The tensor-parallel path works on detached copies. The manual
    # forward/backward then never builds graphs on (or runs in-place
    # collectives over) tensors that require grad, and never touches the
    # reference tensors.
    X = X_ref.detach().clone()
    W_layer1 = W_ref_layer1.detach().clone()
    W_layer2 = W_ref_layer2.detach().clone()
    local_W_layer1 = split_tensor(W_layer1, dim=0)  # layer computes X @ W.t(): shard output rows
    local_W_layer2 = split_tensor(W_layer2, dim=1)  # shard input features to match Y1's column shard

    # Forward (reference)
    Y_ref_linear1 = X_ref @ W_ref_layer1.t()
    Y_ref_linear1.retain_grad()
    Y_ref_linear2 = Y_ref_linear1 @ W_ref_layer2.t()

    # Forward (tensor parallel)
    Y_local_linear1 = column_linear_forward(X, local_W_layer1, group)
    torch.testing.assert_close(Y_local_linear1, split_tensor(Y_ref_linear1, dim=1), **tol)
    Y_linear2 = row_linear_forward(Y_local_linear1, local_W_layer2, group)
    torch.testing.assert_close(Y_linear2, Y_ref_linear2, **tol)

    # Backward (reference)
    Y_ref_linear2.sum().backward()

    # Backward (tensor parallel), layer 2 then layer 1
    grad_Y = torch.ones_like(Y_ref_linear2)
    grad_X_linear2, grad_W_linear2 = row_linear_backward(grad_Y, Y_local_linear1, local_W_layer2, group)
    torch.testing.assert_close(grad_X_linear2, split_tensor(Y_ref_linear1.grad, dim=1), **tol)
    torch.testing.assert_close(grad_W_linear2, split_tensor(W_ref_layer2.grad, dim=1), **tol)

    grad_X, grad_W = column_linear_backward(grad_X_linear2, X, local_W_layer1, group)
    torch.testing.assert_close(grad_X, X_ref.grad, **tol)
    torch.testing.assert_close(grad_W, split_tensor(W_ref_layer1.grad, dim=0), **tol)
    print(f"Rank {dist.get_rank()}: SUCCESS.")
if __name__ == "__main__":
    # Launch one process per GPU, e.g.:
    #   torchrun --nproc_per_node=2 tp_all_reduce.py [column|row|column_row|gelu]
    # With no argument the column->row example runs, as before.
    import sys

    examples = {
        "column": example_column_linear,
        "row": example_row_linear,
        "column_row": example_column_row_linear,
        "gelu": example_gelu,
    }
    choice = sys.argv[1] if len(sys.argv) > 1 else "column_row"
    if choice not in examples:
        raise SystemExit(f"unknown example {choice!r}; choose one of: {', '.join(examples)}")

    # Bind this process to its GPU before creating the NCCL group, so NCCL
    # sets up its communicators on the intended device.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"]))
    try:
        examples[choice]()
    finally:
        # Release the process group (and its NCCL resources) even if a check fails.
        dist.destroy_process_group()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment