Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Last active December 9, 2022 21:38
Show Gist options
  • Save davidberard98/982591d42acb9c39e0b8ecfd10ac7dda to your computer and use it in GitHub Desktop.
Save davidberard98/982591d42acb9c39e0b8ecfd10ac7dda to your computer and use it in GitHub Desktop.
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models
import logging
torch._dynamo.config.log_level = logging.DEBUG
# REPLACEABLE COMMENT FOR TESTING PURPOSES

# One (shape, stride, dtype, device, requires_grad) spec per tensor, in the
# positional order expected by Repro.forward: sixteen LSTM flat weights,
# then the LSTM input (`permute`), then the first initial state (`zeros`).
args = [
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192, 4096), (4096, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192, 4096), (4096, 1), torch.float32, 'cuda', True),
    ((8192, 2048), (2048, 1), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    ((8192,), (1,), torch.float32, 'cuda', True),
    # LSTM input: note the non-contiguous strides for this shape.
    ((92, 4, 2048), (1, 188416, 92), torch.float32, 'cuda', True),
    # First initial-state tensor; the only spec that does not require grad.
    ((4, 4, 2048), (8192, 2048, 1), torch.float32, 'cuda', False),
]

# Materialise every spec as a real tensor with the requested layout and
# autograd flag, replacing the spec list under the same name.
args = [
    rand_strided(shape, stride, dtype, dev).requires_grad_(needs_grad)
    for (shape, stride, dtype, dev, needs_grad) in args
]

# Second initial-state tensor, passed as the last positional argument.
zeros_1 = torch.zeros(
    4, 4, 2048, dtype=torch.float32, device=device(type='cuda', index=0)
)
args.append(zeros_1)
from torch.nn import *
class Repro(torch.nn.Module):
    """Minimal reproducer module that wraps a single ``torch.lstm`` call.

    The module owns no parameters or buffers: every tensor the LSTM needs
    (weights, input and both initial states) is passed to ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        self_model_lstm_lstm_flat_weights_0_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_1_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_2_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_3_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_4_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_5_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_6_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_7_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_8_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_9_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_10_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_11_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_12_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_13_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_14_ : torch.Tensor,
        self_model_lstm_lstm_flat_weights_15_ : torch.Tensor,
        permute,
        zeros,
        zeros_1,
    ):
        """Run the LSTM once and return its result wrapped in a 1-tuple.

        Args:
            self_model_lstm_lstm_flat_weights_0_ .. _15_: the sixteen weight
                and bias tensors, forwarded to ``torch.lstm`` in this order.
            permute: the LSTM input tensor.
            zeros: first element of the initial-state pair.
            zeros_1: second element of the initial-state pair. It is supplied
                by the caller rather than created inside ``forward``.

        Returns:
            ``(lstm,)`` where ``lstm`` is whatever ``torch.lstm`` returns.
        """
        flat_weights = [
            self_model_lstm_lstm_flat_weights_0_,
            self_model_lstm_lstm_flat_weights_1_,
            self_model_lstm_lstm_flat_weights_2_,
            self_model_lstm_lstm_flat_weights_3_,
            self_model_lstm_lstm_flat_weights_4_,
            self_model_lstm_lstm_flat_weights_5_,
            self_model_lstm_lstm_flat_weights_6_,
            self_model_lstm_lstm_flat_weights_7_,
            self_model_lstm_lstm_flat_weights_8_,
            self_model_lstm_lstm_flat_weights_9_,
            self_model_lstm_lstm_flat_weights_10_,
            self_model_lstm_lstm_flat_weights_11_,
            self_model_lstm_lstm_flat_weights_12_,
            self_model_lstm_lstm_flat_weights_13_,
            self_model_lstm_lstm_flat_weights_14_,
            self_model_lstm_lstm_flat_weights_15_,
        ]
        # Flags stay positional so the traced call keeps the same argument
        # layout; the trailing names follow the aten::lstm.input schema.
        lstm = torch.lstm(
            permute,
            (zeros, zeros_1),
            flat_weights,
            True,   # has_biases
            2,      # num_layers
            0.0,    # dropout
            True,   # train
            True,   # bidirectional
            False,  # batch_first
        )
        return (lstm,)
# Instance of the reproducer module; it is symbolically traced further down.
mod = Repro()
# Single FakeTensorMode shared by the whole script: the convert_* helpers call
# mode.from_tensor() on it, and both checks enter it with `with mode:`.
mode = torch._subclasses.FakeTensorMode()
def convert_args(args):
    """Convert each tensor in ``args`` with the module-level ``mode``.

    Args:
        args: iterable of tensors to convert.

    Returns:
        A new list holding ``mode.from_tensor(t)`` for every ``t`` in
        ``args``, in the same order.
    """
    return [mode.from_tensor(t) for t in args]
def convert_args_functional(args):
    """Convert each tensor with ``mode`` and wrap it as a functional tensor.

    Every input first goes through ``mode.from_tensor`` (the module-level
    fake-tensor mode) and the result is passed to
    ``torch._to_functional_tensor``.

    Args:
        args: iterable of tensors to convert.

    Returns:
        A new list with one wrapped tensor per input, in the same order.
    """
    return [torch._to_functional_tensor(mode.from_tensor(t)) for t in args]
# Trace the module with FX and print the captured graph for inspection.
traced_mod = torch.fx.symbolic_trace(mod)
print(traced_mod.graph)

# Baseline: call the traced module on the outputs of convert_args while the
# fake-tensor mode is active. Reaching the PASSED line means no exception.
print("~~~~~~~~LSTM sanity check (FakeTensors)")
converted_args = convert_args(args)
with mode:
    traced_mod(*converted_args)
print("~~~~~~~~LSTM sanity check PASSED")

# Case under test: the same call, but every argument has additionally been
# wrapped by torch._to_functional_tensor (see convert_args_functional).
print("~~~~~~~~LSTM check with functional tensors")
converted_args = convert_args_functional(args)
with mode:
    traced_mod(*converted_args)
print("~~~~~~~~LSTM with functional tensors PASSED")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment