

wanchaol, created July 26, 2022 20:35
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
placements: describe how a tensor is placed onto the device mesh, one entry per mesh dim
global tensor: torch.randn(12, 8)
device mesh: 2 dims, shape (2, 2)
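To see which ranks each placement acts on, it helps to list the rank groups along each mesh dim. A minimal plain-Python sketch, assuming mesh dim 0 indexes rows and mesh dim 1 indexes columns of the nested list; `mesh_dim_groups` is a hypothetical helper, not part of any library:

```python
def mesh_dim_groups(mesh):
    """For a 2-D mesh (nested list of ranks), return the rank groups along each mesh dim."""
    rows, cols = len(mesh), len(mesh[0])
    # Along mesh dim 0 the row index varies while the column is fixed.
    dim0 = [[mesh[i][j] for i in range(rows)] for j in range(cols)]
    # Along mesh dim 1 the column index varies while the row is fixed.
    dim1 = [list(row) for row in mesh]
    return dim0, dim1

dim0, dim1 = mesh_dim_groups([[0, 1], [2, 3]])
print(dim0)  # [[0, 2], [1, 3]]
print(dim1)  # [[0, 1], [2, 3]]
```

So a Shard on mesh dim 0 splits data between ranks 0 and 2 (and between 1 and 3), while a Replicate on mesh dim 1 gives ranks 0 and 1 (and 2 and 3) identical copies.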
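placements = [Shard(0), Replicate()]  # mesh dim 0 splits tensor dim 0; mesh dim 1 replicates
torch.randn(6, 8) -> ranks [0, 1]  (rows 0-5 of the global tensor, same copy on both ranks)
torch.randn(6, 8) -> ranks [2, 3]  (rows 6-11, same copy on both ranks)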
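The index arithmetic behind this layout can be sketched in plain Python. `Shard`, `Replicate`, and `local_ranges` below are hypothetical stand-ins, not torch's classes. The sketch assumes a 2-D mesh, that every sharded tensor dim divides evenly, and that shards on successive mesh dims nest (mesh dim 0 is applied first):

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the DTensor placement types (not torch's classes).
@dataclass(frozen=True)
class Shard:
    dim: int  # tensor dim that is split across one mesh dim

@dataclass(frozen=True)
class Replicate:
    pass  # every rank along this mesh dim holds the same data

def local_ranges(global_shape, mesh, placements):
    """Return {rank: [(start, stop) per tensor dim]}; placements[m] describes mesh dim m."""
    mesh_shape = (len(mesh), len(mesh[0]))
    result = {}
    for i, row in enumerate(mesh):
        for j, rank in enumerate(row):
            coord = (i, j)  # this rank's index along mesh dims 0 and 1
            ranges = [(0, n) for n in global_shape]
            for m, p in enumerate(placements):
                if isinstance(p, Shard):
                    start, stop = ranges[p.dim]
                    if (stop - start) % mesh_shape[m]:
                        raise ValueError("sharded dim must divide evenly in this sketch")
                    size = (stop - start) // mesh_shape[m]
                    lo = start + coord[m] * size
                    ranges[p.dim] = (lo, lo + size)
                # Replicate: the range along every tensor dim is left untouched.
            result[rank] = ranges
    return result

r = local_ranges((12, 8), [[0, 1], [2, 3]], [Shard(0), Replicate()])
for rank, ranges in sorted(r.items()):
    print(rank, ranges, tuple(stop - start for start, stop in ranges))
# 0 [(0, 6), (0, 8)] (6, 8)
# 1 [(0, 6), (0, 8)] (6, 8)
# 2 [(6, 12), (0, 8)] (6, 8)
# 3 [(6, 12), (0, 8)] (6, 8)
```

Passing `[Shard(1), Replicate()]` instead gives each rank a (12, 4) range, with ranks 2 and 3 holding columns 4 to 7.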
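placements = [Shard(1), Replicate()]  # mesh dim 0 splits tensor dim 1; mesh dim 1 replicates
torch.randn(12, 4) -> ranks [0, 1]  (columns 0-3 of the global tensor, same copy on both ranks)
torch.randn(12, 4) -> ranks [2, 3]  (columns 4-7, same copy on both ranks)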
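The same layout can be reproduced with concrete arrays. This sketch uses NumPy arrays as a stand-in for torch tensors (an assumption for illustration; the notes themselves use torch): split tensor dim 1 into one chunk per mesh row, then copy each chunk to every rank in that row.

```python
import numpy as np

mesh = [[0, 1], [2, 3]]
global_t = np.arange(12 * 8).reshape(12, 8)  # stands in for torch.randn(12, 8)

# Shard(1) on mesh dim 0: split tensor dim 1 into one chunk per mesh row.
chunks = np.split(global_t, len(mesh), axis=1)

# Replicate() on mesh dim 1: every rank in a mesh row gets the same chunk.
local = {rank: chunks[i] for i, row in enumerate(mesh) for rank in row}

for rank in sorted(local):
    print(rank, local[rank].shape)  # each is (12, 4)

# Concatenating one rank per mesh row along dim 1 recovers the global tensor.
rebuilt = np.concatenate([local[row[0]] for row in mesh], axis=1)
assert np.array_equal(rebuilt, global_t)
assert np.array_equal(local[0], local[1])  # replicas along mesh dim 1
```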
# XLA
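sharding_spec = [None, 0]  # one entry per tensor dim: the mesh dim it is sharded over (None = not sharded); same layout as [Shard(1), Replicate()]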
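The two notations index in opposite directions: DTensor placements have one entry per mesh dim naming a tensor dim, while the XLA-style spec is assumed here to have one entry per tensor dim naming a mesh dim (the convention PyTorch/XLA uses for its partition spec). A minimal conversion sketch; `Shard`, `Replicate`, and `to_xla_spec` are hypothetical stand-ins, and sharding one tensor dim over several mesh dims is deliberately left unsupported:

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the DTensor placement types (not torch's classes).
@dataclass(frozen=True)
class Shard:
    dim: int

@dataclass(frozen=True)
class Replicate:
    pass

def to_xla_spec(placements, tensor_ndim):
    """Invert a per-mesh-dim placement list into a per-tensor-dim spec.

    Entry d of the result is the mesh dim that tensor dim d is sharded over,
    or None if tensor dim d is not sharded.
    """
    spec = [None] * tensor_ndim
    for mesh_dim, p in enumerate(placements):
        if isinstance(p, Shard):
            if spec[p.dim] is not None:
                # A single entry per tensor dim cannot name two mesh dims.
                raise ValueError(f"tensor dim {p.dim} is sharded over more than one mesh dim")
            spec[p.dim] = mesh_dim
    return spec

print(to_xla_spec([Shard(1), Replicate()], 2))  # [None, 0]
print(to_xla_spec([Shard(0), Replicate()], 2))  # [0, None]
```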