@jamesr66a
Created June 21, 2018 22:24
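
```python
import torch


class TracedModule(torch.nn.Module):
    def forward(self, x):
        # Cast the input to float32 before the floating-point math.
        x = x.type(torch.float32)
        return torch.floor(torch.sqrt(x) / 5.)


# Tracing runs forward once on an example input and records the ops it executes.
tm = torch.jit.trace(TracedModule(), torch.rand(5))
```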
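A minimal usage sketch, assuming a current PyTorch install: the traced module replays the recorded ops on new inputs of the same shape, so on straight-line code like this it should agree with eager execution. The input values are chosen only so the arithmetic stays exact.

```python
import torch


class TracedModule(torch.nn.Module):
    def forward(self, x):
        x = x.type(torch.float32)
        return torch.floor(torch.sqrt(x) / 5.)


# Trace once on a random example input of shape (5,).
tm = torch.jit.trace(TracedModule(), torch.rand(5))

# Perfect squares: sqrt -> [2, 3, 4, 5, 10], / 5 -> [0.4, 0.6, 0.8, 1.0, 2.0],
# floor -> [0, 0, 0, 1, 2].
x = torch.tensor([4.0, 9.0, 16.0, 25.0, 100.0])
out = tm(x)
print(out)                                  # tensor([0., 0., 0., 1., 2.])
print(torch.equal(out, TracedModule()(x)))  # True
```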
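```python
class ScriptModule(torch.jit.ScriptModule):
    # script_method compiles forward with TorchScript, so the data-dependent
    # `if` is kept as real control flow and evaluated on every call.
    # The tensor condition must hold a single element (e.g. a 0-d tensor).
    @torch.jit.script_method
    def forward(self, x):
        r = -x
        if torch.fmod(x, 2.0) == 0.0:
            r = x / 2.0
        return r
```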
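The branch in `ScriptModule.forward` is why it is scripted rather than traced. The sketch below, assuming a current PyTorch install, compiles the same logic both ways. It swaps the gist's `ScriptModule` subclass for `torch.jit.script` on a free function, and the name `halve_even_else_negate` is hypothetical. Tracing evaluates the `if` in Python on the example input and records only the branch taken; scripting keeps both branches.

```python
import warnings

import torch


def halve_even_else_negate(x):
    # Same logic as ScriptModule.forward; x must hold a single element.
    r = -x
    if torch.fmod(x, 2.0) == 0.0:
        r = x / 2.0
    return r


scripted = torch.jit.script(halve_even_else_negate)

with warnings.catch_warnings():
    # Tracing converts the tensor condition to a Python bool and emits a
    # TracerWarning that the trace may not generalize to other inputs.
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(halve_even_else_negate, torch.tensor(4.0))

odd = torch.tensor(3.0)
print(scripted(odd))  # tensor(-3.): branch chosen at run time
print(traced(odd))    # tensor(1.5000): the even branch from tracing is baked in
```

Because the example input 4.0 is even, the traced graph always divides by 2, while the scripted version still negates odd inputs.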