Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Last active December 8, 2022 18:52
Show Gist options
  • Save davidberard98/482a413f7b01e4873e4c9ced2e1bafe4 to your computer and use it in GitHub Desktop.
Save davidberard98/482a413f7b01e4873e4c9ced2e1bafe4 to your computer and use it in GitHub Desktop.
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models
import logging
# Turn on debug-level logging for TorchDynamo.
torch._dynamo.config.log_level = logging.DEBUG
# REPLACEABLE COMMENT FOR TESTING PURPOSES
# Input specs, one per positional argument of Repro.forward, each laid out as
# (shape, stride, dtype, device, requires_grad).
args = [
    # self_model_lstm_lstm_flat_weights_0_ .. _7_
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    # self_model_lstm_lstm_flat_weights_8_ .. _15_
    ((8192, 4096), (4096, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192, 4096), (4096, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    # permute: non-contiguous strides
    ((92, 4, 2048), (1, 188416, 92), torch.float32, 'cuda', True),
    # zeros: the only input that does not require grad
    ((4, 4, 2048), (8192, 2048, 1), torch.float32, 'cuda', False),
]
# Replace every spec with a random tensor of that shape/stride/dtype/device,
# then set its requires_grad flag in place.
args = [
    rand_strided(shape, stride, dtype, dev).requires_grad_(needs_grad)
    for shape, stride, dtype, dev, needs_grad in args
]
from torch.nn import *
class Repro(torch.nn.Module):
    """Minified module whose forward pass is a single ``torch.lstm`` call.

    The module owns no parameters or buffers; every weight is supplied by the
    caller as an argument to :meth:`forward`.
    """

    def forward(self, self_model_lstm_lstm_flat_weights_0_ : torch.Tensor, self_model_lstm_lstm_flat_weights_1_ : torch.Tensor, self_model_lstm_lstm_flat_weights_2_ : torch.Tensor, self_model_lstm_lstm_flat_weights_3_ : torch.Tensor, self_model_lstm_lstm_flat_weights_4_ : torch.Tensor, self_model_lstm_lstm_flat_weights_5_ : torch.Tensor, self_model_lstm_lstm_flat_weights_6_ : torch.Tensor, self_model_lstm_lstm_flat_weights_7_ : torch.Tensor, self_model_lstm_lstm_flat_weights_8_ : torch.Tensor, self_model_lstm_lstm_flat_weights_9_ : torch.Tensor, self_model_lstm_lstm_flat_weights_10_ : torch.Tensor, self_model_lstm_lstm_flat_weights_11_ : torch.Tensor, self_model_lstm_lstm_flat_weights_12_ : torch.Tensor, self_model_lstm_lstm_flat_weights_13_ : torch.Tensor, self_model_lstm_lstm_flat_weights_14_ : torch.Tensor, self_model_lstm_lstm_flat_weights_15_ : torch.Tensor, permute, zeros):
        """Run ``torch.lstm`` over ``permute`` with the given flat weights.

        Args:
            self_model_lstm_lstm_flat_weights_0_ .. _15_: the sixteen weight
                tensors, forwarded to ``torch.lstm`` as its parameter list in
                index order.
            permute: the LSTM input tensor.
            zeros: first element of the ``hx`` pair handed to ``torch.lstm``.

        Returns:
            A one-element tuple wrapping the value returned by ``torch.lstm``.
        """
        # Second element of the hx pair. It is always created as a fresh
        # float32 zero tensor of shape (4, 4, 2048) on cuda:0, independent of
        # the tensors passed in.
        zeros_1 = torch.zeros(
            4, 4, 2048,
            dtype=torch.float32,
            device=torch.device(type='cuda', index=0),
        )
        flat_weights = [
            self_model_lstm_lstm_flat_weights_0_,
            self_model_lstm_lstm_flat_weights_1_,
            self_model_lstm_lstm_flat_weights_2_,
            self_model_lstm_lstm_flat_weights_3_,
            self_model_lstm_lstm_flat_weights_4_,
            self_model_lstm_lstm_flat_weights_5_,
            self_model_lstm_lstm_flat_weights_6_,
            self_model_lstm_lstm_flat_weights_7_,
            self_model_lstm_lstm_flat_weights_8_,
            self_model_lstm_lstm_flat_weights_9_,
            self_model_lstm_lstm_flat_weights_10_,
            self_model_lstm_lstm_flat_weights_11_,
            self_model_lstm_lstm_flat_weights_12_,
            self_model_lstm_lstm_flat_weights_13_,
            self_model_lstm_lstm_flat_weights_14_,
            self_model_lstm_lstm_flat_weights_15_,
        ]
        # The trailing flags are passed positionally. NOTE(review): per the
        # aten ``lstm.input`` schema they should map to has_biases,
        # num_layers, dropout, train, bidirectional, batch_first -- confirm
        # against the installed torch version.
        lstm = torch.lstm(
            permute,
            (zeros, zeros_1),
            flat_weights,
            True,
            2,
            0.0,
            True,
            True,
            False,
        )
        return (lstm,)
# Instantiate the minified module that the checks below call.
mod = Repro()
# original minified test
# (kept for reference as an inert string literal; it is never executed here)
'''
opt_mod = torch._dynamo.optimize("aot_eager")(mod)
from contextlib import nullcontext
# with torch.cuda.amp.autocast(enabled=False):
with nullcontext():
ref = run_fwd_maybe_bwd(mod, args)
res = run_fwd_maybe_bwd(opt_mod, args)
'''
# Single FakeTensorMode instance; its from_tensor() is what converts the real
# tensors into fake tensors for both the sanity check and the LSTM call.
mode = torch._subclasses.FakeTensorMode()
def convert_args(args):
    """Return a list with each element of ``args`` passed through ``mode.from_tensor``.

    ``mode`` is the module-level ``FakeTensorMode`` instance; order is preserved.
    """
    return list(map(mode.from_tensor, args))
# Announce the sanity check that is run before the LSTM test.
print(" Sanity check (torch.add(x, y) for fake tensor x, y)")
# Two real 4x4 random tensors, created on CPU and then moved to CUDA.
fn_args = [torch.rand(4, 4).cuda(), torch.rand(4, 4).cuda()]
def fn(x, y):
    """Return ``torch.add(x, y)``."""
    total = torch.add(x, y)
    return total
# Sanity check: call fn on fake-tensor versions of fn_args. Reaching the next
# print means the call did not raise.
fn(*convert_args(fn_args))
print(" Sanity check passed")

print(" Test torch.lstm")
# Call the module on fake-tensor versions of args. The failure message is
# printed only when the call actually raises, and the exception is re-raised so
# the traceback is still shown; a clean return is reported as a pass.
try:
    mod(*convert_args(args))
except Exception:
    print(" torch.lstm failed")
    raise
print(" torch.lstm passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment