@wanchaol
Created July 25, 2023 08:10
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class TestDTensorCompile(DTensorTestBase):
    def setUp(self):
        super().setUp()

    @property
    def world_size(self) -> int:
        return 2

    @with_comms
    def test_dtensor_fullgraph(self):
        class SimpleMLP(torch.nn.Module):
            def __init__(self, device):
                super().__init__()
                self.net1 = torch.nn.Linear(5, 1024, device=device)
                self.relu = torch.nn.ReLU()
                self.net2 = torch.nn.Linear(1024, 4, device=device)

            def forward(self, x):
                return self.net2(F.relu(self.net1(x)))

        # 1-D device mesh covering both ranks.
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        model = SimpleMLP(self.device_type)
        # Tensor-parallelize the two linear layers as a colwise/rowwise pair.
        model = parallelize_module(model, mesh, PairwiseParallel())
        inp = torch.rand(20, 5, device=self.device_type)
        out = model(inp)
        # Require a single graph (no graph breaks) when tracing the DTensor model.
        compiled_mod = torch.compile(model, backend="eager", fullgraph=True)
        compiled_out = compiled_mod(inp)
        self.assertEqual(compiled_out, out)
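

# Not part of the original gist: a minimal entry-point sketch, assuming the file
# is saved as a standalone test module. The standard PyTorch test runner discovers
# the test case, and DTensorTestBase spawns `world_size` processes for the
# @with_comms test, so it can be launched with a plain `python <file>.py`.
from torch.testing._internal.common_utils import run_tests

if __name__ == "__main__":
    run_tests()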