Skip to content

Instantly share code, notes, and snippets.

@wanchaol
Created January 29, 2020 18:11
Show Gist options
Save wanchaol/dd9613215d324dc2845c748213b6ccc5 to your computer and use it in GitHub Desktop.
import torch
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
def my_method(self):
# type: () -> Tensor
pass
class MyScriptModule(torch.jit.ScriptModule):
    """Script module (legacy ``ScriptModule``/``script_method`` API) holding ``a``.

    ``mod_init`` returns an instance of this class under the ``ModuleInterface``
    type, which relies on ``my_method`` matching that interface's signature:
    no arguments besides ``self``, returning a Tensor.
    """

    def __init__(self) -> None:
        super().__init__()
        # 1-D tensor of 10 samples from torch.randn, kept on the module so the
        # compiled my_method can hand it back.
        values = torch.randn(10)
        self.a = values

    @torch.jit.script_method
    def my_method(self) -> torch.Tensor:
        # Return the tensor stored by __init__ (comment, not docstring, so no
        # string constant is emitted into the compiled graph).
        return self.a
@torch.jit.script
class MyClass:
    """TorchScript class whose ``my_method`` builds a new random tensor.

    Not referenced by the other definitions in this script as shown.
    """

    def my_method(self) -> torch.Tensor:
        # Fresh 1-D tensor of 10 samples from torch.randn on every call.
        return torch.randn(10)
@torch.jit.ignore
def mod_init():
    # type: () -> ModuleInterface
    """Construct a ``MyScriptModule`` in plain Python.

    ``torch.jit.ignore`` leaves this function uncompiled, so scripted callers
    such as ``test_script_mod`` call back into Python for it. The signature
    comment on the first body line tells TorchScript to treat the result as a
    ``ModuleInterface``.
    """
    instance = MyScriptModule()
    return instance
# Scripted entry point: compiles a call to the ignored ``mod_init`` and
# returns its result as a ``ModuleInterface`` (per the signature comment on
# the first body line). The module-level print below shows this function's
# compiled graph, so the body is intentionally left minimal.
@torch.jit.script
def test_script_mod():
    # type: () -> ModuleInterface
    return mod_init()
# Print the TorchScript IR compiled for test_script_mod. As a module-level
# statement this runs whenever the file is executed or imported.
print(str(test_script_mod.graph))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment