Skip to content

Instantly share code, notes, and snippets.

@YouJiacheng
Created November 2, 2024 16:42
Show Gist options
  • Save YouJiacheng/d075abc0840cd56155dc8a56abde3317 to your computer and use it in GitHub Desktop.
Save YouJiacheng/d075abc0840cd56155dc8a56abde3317 to your computer and use it in GitHub Desktop.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
# Reproducer: give each of NUM_PARAMS (ROWS, COLS) tensors to exactly one rank
# of a 1-D "shard" mesh (round-robin by index). That rank holds every row and
# the other ranks hold an empty (0, COLS) slice. The tensor is then described
# to DTensor as Shard(0) with an explicit global shape/stride.
NUM_PARAMS = 8
ROWS, COLS = 768, 3072

mesh_1d = init_device_mesh("cuda", (4,), mesh_dim_names=("shard",))
# This process's coordinate along the (single) mesh dimension.
rank = mesh_1d.get_local_rank()
# Number of ranks on the mesh; derived from the mesh rather than hard-coded so
# the round-robin owner below always agrees with the mesh shape.
world_size = mesh_1d.size()

dtensors: list[DTensor] = []
for i in range(NUM_PARAMS):
    # Owner of tensor i gets all rows; every other rank contributes 0 rows.
    local_rows = ROWS if i % world_size == rank else 0
    # NOTE(review): Shard(0) is documented to follow torch.chunk semantics
    # (an even split is ROWS // world_size rows per rank), so this
    # all-rows-on-one-rank layout does not match what DTensor assumes.
    # Redistribution ops such as full_tensor() may fail or return wrong
    # data. Confirm against the PyTorch version in use.
    # Allocate directly on the mesh's device type; from_local would otherwise
    # copy a CPU tensor there.
    local = torch.ones(local_rows, COLS, device=mesh_1d.device_type)
    dtensor = DTensor.from_local(
        local,
        device_mesh=mesh_1d,
        placements=[Shard(0)],
        run_check=True,  # cross-rank sanity check of local metadata
        shape=(ROWS, COLS),
        stride=(COLS, 1),
    )
    dtensors.append(dtensor)

# full_tensor() gathers across the mesh, so every rank must call it.
print(dtensors[0].full_tensor())

# Tear down the process group created by init_device_mesh before exiting.
torch.distributed.destroy_process_group()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment