Skip to content

Instantly share code, notes, and snippets.

@soumith
Last active June 25, 2018 20:19
Show Gist options
  • Save soumith/93c7e839786c78215a43132b17715974 to your computer and use it in GitHub Desktop.
Save soumith/93c7e839786c78215a43132b17715974 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision
class XlaMNIST(nn.Module):
    """Parameter-free toy module that evaluates ``3 * x * x + x``.

    The module registers no layers or parameters; ``forward`` is built only
    from ``*`` and ``+`` applied to its input.
    """

    def __init__(self):
        """Initialise the base ``nn.Module``; no submodules are created."""
        super(XlaMNIST, self).__init__()

    def forward(self, x):
        """Return ``(x * x) * 3 + x``.

        Args:
            x: The input value; anything supporting ``*`` and ``+``
                (``main`` in this file passes a ``torch`` tensor).

        Returns:
            The result of ``(x * x) * 3 + x``.
        """
        y = x * x
        z = y * 3 + x
        return z
def main():
    """Trace ``XlaMNIST`` and print its forward and gradient graphs.

    Steps, in order:
      1. Trace the model on a random ``(4, 1, 28, 28)`` input.
      2. Run the addmm-decomposition pass on the forward graph and print it.
      3. Differentiate the forward graph and print the resulting ``f`` and
         ``df`` graphs.
      4. Run the specialize-undef pass on ``df`` and print it again under a
         ``"Pruned"`` heading.

    NOTE(review): this relies on the curried ``torch.jit.trace(x)(model)``
    call form and on private ``torch._C`` entry points; these look tied to a
    specific (older) PyTorch release -- confirm against the installed version
    before running.
    """
    x = torch.randn(4, 1, 28, 28)
    model = XlaMNIST()
    traced_model = torch.jit.trace(x)(model)
    fwd = traced_model._get_method('forward')

    # The pass mutates fwd.graph in place (its return value is not used).
    torch._C._jit_pass_decompose_addmm(fwd.graph)
    print(fwd.graph)

    # One ``True`` flag per input of the forward graph.
    fwd_flags = [True for _ in fwd.graph.inputs()]
    gradient = torch._C._jit_differentiate(fwd.graph, fwd_flags)
    print(gradient.f)
    print(gradient.df)

    # One ``False`` flag per input of the derivative graph.
    df_flags = [False for _ in gradient.df.inputs()]
    torch._C._jit_pass_specialize_undef(gradient.df, df_flags)
    print("Pruned")
    print(gradient.df)
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment