@S1ro1
Created July 29, 2025 00:40
import torch
from torch import nn
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module


def dist_print(*args, **kwargs):
    if dist.get_rank() == 0:
        print(*args, **kwargs)


class ColumnWiseParallel(torch.nn.Module):
    def __init__(self, layer, mesh, output_layout=[Shard(1)], use_local_output=True, bias=False):
        # GEMM: [m, n] [n, k] = [m, k]
        # Column-wise dist GEMM: [m, n] [n, k/2] = [m, k/2] -> bias is [k/2]
        # Output layout is Shard(1), as the result is [m, k/2]
        # Can clean this up into a util function called partition_input
        super().__init__()
        self.layer = layer
        self.weight_shard_size = self.layer.weight.shape[0] // dist.get_world_size()
        self.output_layout = output_layout
        self.use_local_output = use_local_output
        self.mesh = mesh
        self.use_bias = bias
        # Torch computes x @ W.T, so we actually shard W on dimension 0
        self.rank_weight = self.layer.weight[
            self.weight_shard_size * dist.get_rank() : self.weight_shard_size * (dist.get_rank() + 1), :
        ]
        self.weight = DTensor.from_local(self.rank_weight, self.mesh, [Shard(0)])
        dist_print(f"rank {dist.get_rank()} has {self.rank_weight.shape} weights")
        dist_print(f"rank {dist.get_rank()} has {self.weight.shape} weights")
        self.bias = None
        if self.use_bias:
            # The bias is also sharded across ranks, as the result is [m, k/2]
            self.bias_size = self.layer.bias.shape[0] // dist.get_world_size()
            self.bias_weight = self.layer.bias[
                self.bias_size * dist.get_rank() : self.bias_size * (dist.get_rank() + 1)
            ]
            self.bias = DTensor.from_local(self.bias_weight, self.mesh, [Shard(0)])

    def forward(self, x):
        if not isinstance(x, DTensor):
            x = DTensor.from_local(x, self.mesh, [Replicate()])
        # no-op if already replicated
        x = x.redistribute(placements=[Replicate()])
        y = torch.matmul(x, self.weight.T)
        if self.bias is not None:
            y = y + self.bias
        z = y.redistribute(placements=self.output_layout)
        if self.use_local_output:
            return z.to_local()
        # Ask Tianyu what's up with this
        return z


if __name__ == "__main__":
    batch_size = 8

    # Initialize the process group
    dist.init_process_group("nccl", init_method="env://")

    # Create a device mesh
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    torch.cuda.set_device(dist.get_rank())

    # Create a model
    model = torch.nn.Linear(8, 16, bias=False).cuda()

    # Broadcast so that every rank has exactly the same weights
    weight = torch.randn_like(model.weight).cuda()
    dist.broadcast(weight, src=0)
    model.weight = torch.nn.Parameter(weight)

    # We want the output replicated across all ranks so we can compare against the single-GPU reference
    model_parallel = ColumnWiseParallel(model, mesh, output_layout=[Replicate()])

    x = torch.ones(batch_size, 8).cuda()
    y = model_parallel(x)
    z = model(x)
    torch.testing.assert_close(y, z)

    dist.destroy_process_group()
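

# For comparison: PyTorch ships a built-in column-wise plan, ColwiseParallel, used together with
# parallelize_module (imported above but unused in this gist). The sketch below is a minimal,
# untested alternative to the hand-rolled ColumnWiseParallel class; the function name
# `builtin_colwise_reference` is made up for illustration, and it assumes the same `model` and
# `mesh` as in __main__.
def builtin_colwise_reference(model, mesh):
    from torch.distributed.tensor.parallel import ColwiseParallel

    # Shard Linear.weight on dim 0 (column-wise w.r.t. x @ W.T), replicate the output,
    # and return plain torch.Tensors from forward, mirroring use_local_output=True above.
    return parallelize_module(
        model,
        mesh,
        ColwiseParallel(output_layouts=Replicate(), use_local_output=True),
    )


# Either version is launched with torchrun, e.g. `torchrun --nproc_per_node=2 <this_file>.py`
# (the file name is whatever you saved the gist as).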