@davidberard98
Last active September 23, 2022 01:22
import torch
import torchdynamo
import os
import logging

torchdynamo.config.verbose = True
torchdynamo.config.log_level = logging.DEBUG


def setup():
    # Initialize a single-process NCCL process group so DDP can be constructed.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("nccl", rank=0, world_size=1)


def cleanup():
    torch.distributed.destroy_process_group()


setup()


class MyModule(torch.nn.Module):
    def __init__(self, x):
        super(MyModule, self).__init__()
        self.x = torch.nn.Parameter(x)
        self.r = torch.nn.ReLU()

    def forward(self, x):
        y = self.r(self.x * x)
        print("Hello world")  # line A
        return torch.cos(y)


unwrapped = MyModule(torch.rand((2, 2), device="cuda"))
# Wrap the module in DistributedDataParallel.
wrapped = torch.nn.parallel.distributed.DistributedDataParallel(unwrapped)


# Compile fn with TorchDynamo using the aot_eager backend.
@torchdynamo.optimize("aot_eager")
def fn(data):
    for i, (x, y) in enumerate(data):
        print("i: ", i)
        z = wrapped(x)  # part B
        (z - y).square().sum().backward()


# Each sample is an (input, target) pair of random tensors.
data = [tuple(torch.rand((2, 2), device="cuda") for _ in range(2)) for __ in range(10)]
fn(data)
cleanup()