
@soumith
Created July 16, 2018 14:25
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # self.conv1 = nn.Conv2d(1, 31, 3, 2)
        self.bn1 = nn.BatchNorm2d(1)

    def forward(self, x):
        # y = self.conv1(torch.sigmoid(x + x * x))
        # y = F.max_pool2d(x, 2)
        # y = self.conv1(x)
        y = self.bn1(x)
        return y


def main():
    inputs = [torch.randn(4, 1, 28, 28, requires_grad=True)]
    # model = nn.BatchNorm2d(1)  # Model()

    # switch to resnet50 with ImageNet-sized inputs (overrides the toy setup above)
    import torchvision
    inputs = [torch.randn(4, 3, 224, 224, requires_grad=True)]
    model = torchvision.models.resnet50()

    # the traced forward graph takes inputs, parameters and buffers as explicit arguments
    inputs_params = inputs + list(model.parameters()) + list(model._all_buffers())

    # pre-1.0 tracing API: torch.jit.trace(*example_inputs) returns a decorator applied to the module
    traced_model = torch.jit.trace(*inputs)(model)
    fwd = traced_model._get_method('forward')

    # clean up the traced graph before differentiation
    torch._C._jit_pass_decompose_addmm(fwd.graph)
    torch._C._jit_pass_dce(fwd.graph)
    torch._C._jit_pass_constant_fold(fwd.graph)
    torch._C._jit_pass_dce(fwd.graph)
    # print("forward graph before unwrap", fwd.graph)
    torch._C._jit_pass_unwrap_buffered_functions(fwd.graph)
    print("forward graph before symbolic diff", fwd.graph)

    # for now, all inputs require grad, not just parameters
    inputs_require_grad = [i.requires_grad for i in inputs_params]
    # symbolic differentiation: gradient.f is the rewritten forward graph,
    # gradient.df is the backward (gradient) graph
    gradient = torch._C._jit_differentiate(fwd.graph, inputs_require_grad)
    # print(list(gradient.df.inputs()))
    defined_df_inputs = [True for i in gradient.df.inputs()]  # all df inputs are defined (usual case)
    torch._C._jit_pass_specialize_undef(gradient.df, defined_df_inputs)
    torch._C._jit_pass_dce(gradient.f)
    torch._C._jit_pass_dce(gradient.df)
    print("forward graph after symbolic diff", gradient.f)
    print("gradient graph", gradient.df)

    exec_f = torch._C.GraphExecutor(gradient.f, False)
    exec_df = torch._C.GraphExecutor(gradient.df, False)

    # forward function
    # print('inputs_params_sizes: ', [i.size() for i in inputs_params])
    raw_outputs = exec_f(*inputs_params)
    if isinstance(raw_outputs, torch.Tensor):
        raw_outputs = [raw_outputs]
    # print('raw_output_sizes: ', [i.size() for i in raw_outputs])
    # the first f_real_outputs are the user-visible outputs; the rest are captured for df
    outputs = raw_outputs[:gradient.f_real_outputs]
    # print([i.size() for i in outputs])

    # backward function
    grad_outputs = []
    for o in raw_outputs:
        if o.dtype == torch.float32 or o.dtype == torch.float64:
            grad_outputs += [torch.randn(*o.shape, dtype=o.dtype)]
        elif o.dtype == torch.int64:
            # TODO remove this, we shouldn't need to pass grad_output for long types
            grad_outputs += [torch.empty_like(o)]
        else:
            raise RuntimeError("Unsupported type: ", o.dtype)

    # df inputs are: grad_outputs, then captured forward inputs, then captured forward outputs
    raw_grad_outputs = []
    raw_grad_outputs += grad_outputs
    raw_grad_outputs += [inputs_params[i] for i in gradient.df_input_captured_inputs]
    # print(gradient.df_input_captured_outputs)
    # print(len(raw_outputs))
    raw_grad_outputs += [raw_outputs[i] for i in gradient.df_input_captured_outputs]
    grad_inputs = exec_df(*raw_grad_outputs)

    # forward + backward with regular autograd / torch
    out_groundtruth = model(*inputs)
    out_groundtruth.backward(*grad_outputs[:gradient.f_real_outputs])

    # compare both
    print("output_jit - output: ", (outputs[0] - out_groundtruth).abs().max().item())
    print("gradinput_jit - gradinput: ", (grad_inputs[0] - inputs[0].grad).abs().max().item())


if __name__ == '__main__':
    main()
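
The script above relies on private torch._C JIT passes as they existed in mid-2018 nightlies, and those entry points may not be available under the same names in later releases. The sketch below is an addition, not part of the original gist: it checks the same end-to-end property using only public APIs, by tracing the model, running the traced and eager versions on the same input, backpropagating the same grad_output through both, and comparing outputs and input gradients.

import torch
import torchvision

# eval() so BatchNorm running stats don't change between the two forward passes
model = torchvision.models.resnet50().eval()
example = torch.randn(4, 3, 224, 224, requires_grad=True)

# public tracing API (PyTorch >= 1.0): trace(module, example_inputs)
traced = torch.jit.trace(model, example)

out_eager = model(example)
out_traced = traced(example)
print("output_traced - output_eager:", (out_traced - out_eager).abs().max().item())

# backpropagate the same grad_output through both versions and compare input gradients
grad_out = torch.randn_like(out_eager)
(g_eager,) = torch.autograd.grad(out_eager, example, grad_out)
(g_traced,) = torch.autograd.grad(out_traced, example, grad_out)
print("gradinput_traced - gradinput_eager:", (g_traced - g_eager).abs().max().item())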