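# A small repro script: wraps a toy module in DistributedDataParallel,
# compiles the training step with TorchDynamo's Inductor backend, and
# captures a profiler trace (see the usage sketch at the bottom).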
import torch
import torchdynamo
import argparse
import os
import logging
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
# torchdynamo.config.verbose = True
# torchdynamo.config.log_level = logging.DEBUG
parser = argparse.ArgumentParser()
parser.add_argument("--rank", type=int, required=True)
parser.add_argument("--world_size", type=int, required=True)
args = parser.parse_args()
def setup():
    # Initialize the default process group (NCCL backend for GPU collectives);
    # every rank must call this before any collective communication.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.distributed.init_process_group("nccl", rank=args.rank, world_size=args.world_size)

def cleanup():
    torch.distributed.destroy_process_group()

setup()
class MyModule(torch.nn.Module):
    def __init__(self, x):
        super().__init__()
        self.x = torch.nn.Parameter(x)
        self.r = torch.nn.ReLU()

    def forward(self, x):
        # Elementwise multiply by the learned parameter, ReLU, then cosine.
        y = self.r(self.x * x)
        return torch.cos(y)

unwrapped = MyModule(torch.rand((2, 2), device=f"cuda:{args.rank}"))
# DDP wrapper: gradients are all-reduced across ranks during backward.
wrapped = torch.nn.parallel.DistributedDataParallel(unwrapped, device_ids=[args.rank])
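# torchdynamo traces fn and compiles the captured graph(s) with the
# Inductor backend; anything it cannot trace falls back to eager execution.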
@torchdynamo.optimize("inductor")
def fn(data):
    for (x, y) in data:
        z = wrapped(x)  # part B
        (z - y).square().sum().backward()

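# Ten (input, target) pairs of random 2x2 tensors on this rank's GPU.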
data = [tuple(torch.rand((2, 2), device=f"cuda:{args.rank}") for _ in range(2)) for __ in range(10)]
# Warm-up iterations, so compilation happens outside the profiled region.
for i in range(10):
    fn(data)

# Profile ten more iterations and emit a gzipped trace for TensorBoard.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=True,
    on_trace_ready=tensorboard_trace_handler(
        "./profile_results",
        str(args.rank),  # worker_name expects a string; label each rank's trace
        use_gzip=True,
    ),
):
    for i in range(10):
        fn(data)
cleanup()
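
# Usage sketch (assumes a single machine with at least 2 GPUs; the filename
# repro.py is hypothetical):
#   python repro.py --rank 0 --world_size 2 &
#   python repro.py --rank 1 --world_size 2
# The gzipped traces land in ./profile_results and can be viewed with
# TensorBoard's PyTorch profiler plugin (torch-tb-profiler).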