Skip to content

Instantly share code, notes, and snippets.

@davidberard98
Created February 8, 2023 21:20
Show Gist options
  • Save davidberard98/d242fce6e6e0401579ca96dd638a6f5c to your computer and use it in GitHub Desktop.
Save davidberard98/d242fce6e6e0401579ca96dd638a6f5c to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
# LM example
class EncoderDecoderLanguageModel(nn.Module):
    """Minimal stand-in for an encoder-decoder language model, meant to be scripted.

    ``prepare_inputs`` carries ``@torch.jit.export`` so it is compiled together
    with ``forward`` and stays callable on the resulting ScriptModule.
    """

    # No custom __init__: the inherited nn.Module.__init__ is all the setup needed.

    @torch.jit.export
    def prepare_inputs(self, var1: torch.Tensor, var2: torch.Tensor):
        """Bundle both tensors into a dict keyed by ``forward``'s parameter names."""
        packed = {"var1": var1, "var2": var2}
        return packed

    def forward(self, var1=None, var2=None):
        """Return ``var1 * var2``.

        Neither parameter is annotated. When scripted, TorchScript should type them
        as Tensor, which is its default for unannotated parameters.
        NOTE(review): the ``None`` defaults look unusable, because omitting either
        argument puts ``None`` into the multiplication. Confirm they are needed.
        """
        product = var1 * var2
        return product
# Parent class example
class Parent(nn.Module):
    """Eager wrapper that scripts ``model`` and runs it from its own ``forward``.

    ``model`` must be scriptable and expose an exported ``prepare_inputs`` method
    that returns the keyword arguments for its ``forward``.
    """

    def __init__(self, model):
        super().__init__()
        # Compile the child eagerly; forward() then calls into a ScriptModule.
        self.model = torch.jit.script(model)
        # Plain tensor attributes, not registered parameters or buffers.
        # NOTE(review): under torch.jit.trace these are likely captured as graph
        # constants. Confirm that is intended.
        self.x = torch.rand((2, 2))
        self.y = torch.rand((2, 2))

    def forward(self):
        """Ask the child to build its inputs, then call the child with them."""
        model_inputs = self.model.prepare_inputs(self.x, self.y)
        # This is the behavior under investigation: calling the scripted submodule
        # with **-unpacked keyword arguments while this module is being traced.
        return self.model(**model_inputs)
# Build the TorchScript-compatible child model.
lm = EncoderDecoderLanguageModel()
# Wrap it in the eager parent, whose forward() calls into the scripted child.
parent_class = Parent(lm)
# Trace the parent with an empty example-input tuple, since its forward() takes no
# arguments. The **kwargs call into the scripted child inside Parent.forward is
# what this script is meant to exercise.
parent_class_traced = torch.jit.trace(parent_class, example_inputs=())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment