import torch
from torch import nn
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module


def dist_print(*args, **kwargs):
    # Only rank 0 prints, to avoid interleaved output from every process
    if dist.get_rank() == 0:
        print(*args, **kwargs)


class ColumnWiseParallel(torch.nn.Module):
    def __init__(self, layer, mesh, output_layout=[Shard(1)], use_local_output=True, bias=False):
        # GEMM: [m, n] @ [n, k] = [m, k]
        # Column-wise distributed GEMM: [m, n] @ [n, k/2] = [m, k/2] -> bias is [k/2]
        # The default output layout is Shard(1) because the local result is [m, k/2]
        # Could be cleaned up into a utility function called partition_input
        super().__init__()
        self.layer = layer
        self.weight_shard_size = self.layer.weight.shape[0] // dist.get_world_size()
        self.output_layout = output_layout
        self.use_local_output = use_local_output
        self.mesh = mesh
        # nn.Linear computes x @ W.T, so "column-wise" sharding slices W along dimension 0
        self.rank_weight = self.layer.weight[
            self.weight_shard_size * dist.get_rank() : self.weight_shard_size * (dist.get_rank() + 1), :
        ]
        self.weight = DTensor.from_local(self.rank_weight, self.mesh, [Shard(0)])
        dist_print(f"rank {dist.get_rank()} has {self.rank_weight.shape} weights")
        dist_print(f"rank {dist.get_rank()} has {self.weight.shape} weights")
        if bias:
            # The bias is sharded across ranks as well, since the local result is [m, k/2]
            self.bias_size = self.layer.bias.shape[0] // dist.get_world_size()
            self.bias_weight = self.layer.bias[
                self.bias_size * dist.get_rank() : self.bias_size * (dist.get_rank() + 1)
            ]
            self.bias = DTensor.from_local(self.bias_weight, self.mesh, [Shard(0)])
        else:
            self.bias = None

    def forward(self, x):
        if not isinstance(x, DTensor):
            x = DTensor.from_local(x, self.mesh, [Replicate()])
        # no-op if the input is already replicated
        x = x.redistribute(placements=[Replicate()])
        y = torch.matmul(x, self.weight.T)
        if self.bias is not None:
            y = y + self.bias
        z = y.redistribute(placements=self.output_layout)
        if self.use_local_output:
            return z.to_local()
        # otherwise hand back the DTensor so downstream parallel layers can keep working with it
        return z
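

# A hedged comparison sketch: PyTorch's built-in column-wise style,
# torch.distributed.tensor.parallel.ColwiseParallel, applied via parallelize_module,
# should give the same sharding as the hand-rolled class above. The helper name
# builtin_colwise is hypothetical, not part of any API.
def builtin_colwise(linear, mesh):
    from torch.distributed.tensor.parallel import ColwiseParallel

    # Replicate the output to mirror ColumnWiseParallel(..., output_layout=[Replicate()])
    return parallelize_module(linear, mesh, ColwiseParallel(output_layouts=Replicate()))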


if __name__ == "__main__":
    batch_size = 8
    # Initialize the process group
    dist.init_process_group("nccl", init_method="env://")
    # Create a device mesh spanning all ranks
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    torch.cuda.set_device(dist.get_rank())
    # Create a model
    model = torch.nn.Linear(8, 16, bias=False).cuda()
    # Broadcast so every rank starts from exactly the same weights
    weight = torch.randn_like(model.weight).cuda()
    dist.broadcast(weight, src=0)
    model.weight = torch.nn.Parameter(weight)
    # We want the output replicated across all ranks so we can compare against the reference
    model_parallel = ColumnWiseParallel(model, mesh, output_layout=[Replicate()])
    x = torch.ones(batch_size, 8).cuda()
    y = model_parallel(x)
    z = model(x)
    torch.testing.assert_close(y, z)
    dist.destroy_process_group()
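
# Launch sketch (assumptions: a single node with 2 GPUs, and a hypothetical
# filename column_parallel.py). torchrun sets up the environment variables that
# the env:// init method expects:
#   torchrun --nproc_per_node=2 column_parallel.py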