@soumith
Created August 6, 2018
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class XlaMNIST(nn.Module):
    def __init__(self):
        super(XlaMNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = self.bn2(x)
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
        # return F.log_softmax(x, dim=1)
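# Why view(-1, 320): for a 1x28x28 input, conv1 (k=5) gives 24x24, the 2x2
# max-pool halves it to 12x12, conv2 (k=5) gives 8x8, and the second pool
# gives 4x4, so the flatten size is 20 channels * 4 * 4 = 320, matching
# fc1's in_features. The assertion below is an illustrative sanity check,
# not part of the original gist.
assert XlaMNIST().fc1.in_features == 20 * 4 * 4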
# input = torch.randn(1, 4, requires_grad=True)
# model = nn.Linear(4, 20)
# input = torch.randn(4, 3, 224, 224, requires_grad=True)
# model = torchvision.models.resnet50()
input = torch.randn(4, 1, 28, 28, requires_grad=True)
model = XlaMNIST()
# basic conversion
input_xla = torch._C.XLATensor(input)
print(type(input_xla))
#print('printing XLA Tensor: ')
#print(input_xla)
print('')
input_back = input_xla.to_tensor()
print("difference of transfer + back: " , (input - input_back).abs().max().item())
# build xla model
# note: this is the pre-1.0 decorator-style API, where torch.jit.trace(inputs)
# returns a tracer that is then applied to the module; newer PyTorch uses
# torch.jit.trace(model, inputs) instead
traced_model = torch.jit.trace(input)(model)
xla_model = torch._C.XlaModule(traced_model, [input])
# run forward
output_xla = xla_model(input_xla)
output = model(input)
print("difference of output: " , (output - output_xla.to_tensor()).abs().max().item())
# run backward
grad_output = torch.randn(*output.shape) # random gradients
grad_output_xla = torch._C.XLATensor(grad_output)
output.backward(grad_output)
xla_model.backward(grad_output_xla)
diff = (input.grad - input_xla.grad.to_tensor())
print("difference of grad_input: " , 'absmax: ', diff.abs().max().item(),
'min: ', diff.min().item(), 'max: ', diff.max().item(),
'mean: ', diff.mean().item(),
'median: ', diff.median().item(),
'stdv: ', diff.std().item())
params = list(model.parameters())
params_xla = xla_model.parameters()
for param, param_xla in zip(params, params_xla):
    print('param diff: ', (param.grad - param_xla.grad.to_tensor()).abs().max().item())
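# One-shot variant of the loop above (illustrative, not in the original
# gist): report whether every parameter gradient matches within a loose
# float32 tolerance. Fresh parameters() calls are used here in case the
# experimental XlaModule returns an iterator that the loop above consumed.
print("all param grads allclose (atol=1e-4):",
      all(torch.allclose(p.grad, px.grad.to_tensor(), atol=1e-4)
          for p, px in zip(model.parameters(), xla_model.parameters())))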