Gist by @soumith, created June 25, 2018

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as vision

class XlaMNIST(nn.Module):
    def __init__(self):
        super(XlaMNIST, self).__init__()
        # The real MNIST layers are kept commented out while debugging
        # the JIT passes on a minimal graph:
        # self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        # self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # self.fc1 = nn.Linear(320, 50)
        # self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        y = x * x
        z = y * 3 + x
        return z
        # x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # x = x.view(-1, 320)
        # x = F.relu(self.fc1(x))
        # x = self.fc2(x)
        # return F.log_softmax(x, dim=1)
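
# Note: the simplified forward computes z = 3*x*x + x elementwise, a tiny
# polynomial that keeps the printed graphs readable while still giving the
# differentiation passes in main() something nontrivial to differentiate.
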
def main():
    # Alternative test case: trace torchvision's ResNet-50 instead.
    # x = torch.randn(4, 3, 224, 224)
    # model = vision.models.resnet50()
    x = torch.randn(4, 1, 28, 28)
    model = XlaMNIST()

    # Pre-1.0 decorator-style tracing API: trace(example_inputs)(module).
    traced_model = torch.jit.trace(x)(model)
    fwd = traced_model._get_method('forward')

    # Internal pass: decompose addmm nodes into mm + add before
    # differentiating the graph.
    torch._C._jit_pass_decompose_addmm(fwd.graph)
    print(fwd.graph)

    # successfully run forward pass
    # out_xla = torch._C._to_xla_module(traced_model)(x)
    # print((out_xla - model(x)).abs().max().item())

    # Symbolically differentiate the forward graph with respect to every
    # input; the result exposes a forward graph `f` and a backward graph `df`.
    gradient = torch._C._jit_differentiate(fwd.graph, [True for i in fwd.graph.inputs()])
    print(gradient.f)
    print(gradient.df)
    print(len(gradient.df.inputs()))

    # Internal pass: specialize `df` given per-input undefined-ness flags
    # (all False here, i.e. every input assumed defined), pruning the
    # branches that guard against undefined gradients.
    torch._C._jit_pass_specialize_undef(gradient.df, [False for i in gradient.df.inputs()])
    print("Pruned")
    print(gradient.df)


if __name__ == '__main__':
    main()
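
# ----------------------------------------------------------------------
# For reference, a minimal sketch of the same trace-and-inspect flow on
# the modern public API (PyTorch >= 1.0), where torch.jit.trace takes
# (module, example_inputs) directly instead of the pre-1.0 decorator form
# used above. Left commented out so the script still matches the 2018-era
# internal torch._C passes it exercises.
#
#   traced = torch.jit.trace(XlaMNIST(), torch.randn(4, 1, 28, 28))
#   print(traced.graph)                      # inspect the traced IR
#   out = traced(torch.randn(4, 1, 28, 28))  # run the traced forward
# ----------------------------------------------------------------------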