Last active
September 29, 2022 04:07
-
-
Save davidberard98/3c746cd0c8bd79d40353bb0b263f9518 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchdynamo | |
import argparse | |
import os | |
import logging | |
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler | |
# Debug toggles: uncomment for verbose TorchDynamo logging.
# torchdynamo.config.verbose = True
# torchdynamo.config.log_level = logging.DEBUG

# Command-line interface: every process of the job is launched separately
# with its own --rank and the shared --world_size.
parser = argparse.ArgumentParser(
    description="Profile forward/backward of a TorchDynamo-compiled DDP module."
)
parser.add_argument(
    "--rank", type=int, required=True, help="rank of this process, in [0, world_size)"
)
parser.add_argument(
    "--world_size", type=int, required=True, help="total number of processes in the job"
)
args = parser.parse_args()
# Fail fast with a clear CLI error instead of an obscure failure later in
# process-group initialisation or CUDA device selection (f"cuda:{rank}").
if not 0 <= args.rank < args.world_size:
    parser.error(
        f"--rank must be in [0, --world_size); got rank={args.rank}, "
        f"world_size={args.world_size}"
    )
def setup():
    """Initialise the default NCCL process group for this process.

    Rendezvous uses the MASTER_ADDR / MASTER_PORT environment variables,
    which are (re)set here to localhost:12355. Rank and world size come from
    the ``--rank`` / ``--world_size`` command-line arguments (global ``args``).
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # Make cuda:<rank> the current device before creating the NCCL group, as
    # the torch.distributed docs recommend, so that device-less CUDA calls in
    # this process do not default to cuda:0 on every rank.
    torch.cuda.set_device(args.rank)
    torch.distributed.init_process_group(
        "nccl", rank=args.rank, world_size=args.world_size
    )


def cleanup():
    """Destroy the default process group created by ``setup()``."""
    torch.distributed.destroy_process_group()


# Join the process group at import time; everything below assumes it exists.
setup()
class MyModule(torch.nn.Module):
    """Toy module computing ``cos(relu(param * input))`` elementwise.

    Args:
        x: initial value of the learnable parameter, exposed as ``self.x``.
            Its shape must broadcast against the inputs given to ``forward``.
    """

    def __init__(self, x):
        super().__init__()
        self.x = torch.nn.Parameter(x)
        self.r = torch.nn.ReLU()

    def forward(self, x):
        """Return ``cos(relu(self.x * x))``."""
        return torch.cos(self.r(self.x * x))
# One model replica per rank, with its random 2x2 parameter on cuda:<rank>,
# wrapped in DDP so gradients are synchronised across ranks in backward().
unwrapped = MyModule(torch.rand((2, 2), device=f"cuda:{args.rank}"))
wrapped = torch.nn.parallel.distributed.DistributedDataParallel(module=unwrapped)
@torchdynamo.optimize("inductor")
def fn(data):
    """Run forward and backward through ``wrapped`` for each pair in ``data``.

    Args:
        data: iterable of ``(input, target)`` tensor pairs.

    Uses a summed squared-error loss. Nothing here zeroes gradients, so they
    accumulate in the parameters' ``.grad`` across iterations and calls.
    Returns None.
    """
    for inp, target in data:
        out = wrapped(inp)  # part B
        loss = (out - target).square().sum()
        loss.backward()
# Ten (input, target) pairs of random 2x2 tensors, allocated on cuda:<rank>.
data = [
    (
        torch.rand((2, 2), device=f"cuda:{args.rank}"),
        torch.rand((2, 2), device=f"cuda:{args.rank}"),
    )
    for _ in range(10)
]
try:
    # Warm-up calls, presumably so one-time TorchDynamo/Inductor compilation
    # (performed lazily on the first call of `fn`) stays out of the profile.
    for _ in range(10):
        fn(data)
    # Wait for queued warm-up work on this rank's GPU to finish, so those
    # kernels do not bleed into the profiled window.
    torch.cuda.synchronize(args.rank)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
        on_trace_ready=tensorboard_trace_handler(
            "./profile_results",
            # worker_name is an optional string; a falsy value is treated as
            # unset and replaced by "<hostname>_<pid>". Name traces by rank.
            worker_name=f"rank{args.rank}",
            use_gzip=True,
        ),
    ):
        for _ in range(10):
            fn(data)
finally:
    # Tear down the process group even if a step above raised.
    cleanup()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment