Skip to content

Instantly share code, notes, and snippets.

@joecummings
Last active February 8, 2023 18:26
Show Gist options
  • Save joecummings/60f413cc0400a501d55c8672d8b5b393 to your computer and use it in GitHub Desktop.
Save joecummings/60f413cc0400a501d55c8672d8b5b393 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
# LM example
class EncoderDecoderLanguageModel(nn.Module):
def __init__(self):
super().__init__()
def prepare_inputs(self, var1, var2):
return {
"var1": var1,
"var2": var2,
}
def forward(self, var1 = None, var2 = None):
return var1*var2
# Parent class example
class Parent(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self):
kwargs = self.model.prepare_inputs(1, 0)
return self.model(**kwargs) ### <---- this is the issue
# Initialization of Torchscriptable model
lm = EncoderDecoderLanguageModel()
# # Initialization of parent class that will call forward on torchscriptable model
parent_class = Parent(model=lm)
torch.jit.script(parent_class)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment