@soumith
Created June 26, 2018 04:24
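# Symbolic differentiation of a traced model using PyTorch's private JIT
# internals (torch._C._jit_differentiate and related passes). These APIs are
# from the mid-2018 master branch and are not a stable public interface.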
import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        y = torch.sigmoid(x + x * x)
        return y
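
# For reference, the derivative the symbolic autodiff should recover:
#   y = sigmoid(u) with u = x + x*x, so dy/dx = sigmoid(u) * (1 - sigmoid(u)) * (1 + 2*x)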

def main():
    x = torch.randn(4, 1, 28, 28, requires_grad=True)
    model = Model()
    # old-style trace API (PyTorch ~0.4): trace(example_inputs) returns a
    # decorator, which is then applied to the module
    traced_model = torch.jit.trace(x)(model)
    fwd = traced_model._get_method('forward')
    # decompose fused addmm nodes into simpler primitives for the autodiff pass
    # (a no-op for this model, which has no linear layers)
    torch._C._jit_pass_decompose_addmm(fwd.graph)
    print("forward graph before symbolic diff", fwd.graph)
    # for now, all inputs require grad, not just parameters
    inputs_require_grad = [True for i in fwd.graph.inputs()]
    gradient = torch._C._jit_differentiate(fwd.graph, inputs_require_grad)
    print("forward graph after symbolic diff", gradient.f)
    defined_df_inputs = [True for i in gradient.df.inputs()]  # all df inputs are defined (usual case)
    torch._C._jit_pass_specialize_undef(gradient.df, defined_df_inputs)
    print("gradient graph", gradient.df)
    exec_f = torch._C.GraphExecutor(gradient.f, True)
    exec_df = torch._C.GraphExecutor(gradient.df, True)

    # forward function
    inputs = [x]
    raw_outputs = exec_f(*inputs)
    if isinstance(raw_outputs, torch.Tensor):
        raw_outputs = [raw_outputs]
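    # f emits its real outputs first, then intermediates captured for df;
    # only the first f_real_outputs entries are the model's actual outputs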
    outputs = raw_outputs[:gradient.f_real_outputs]

    # backward function
    grad_outputs = [torch.randn(4, 1, 28, 28)]  # random grad_output
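    # df's calling convention: grad_outputs for f's real outputs come first,
    # followed by whichever of f's inputs and outputs the backward graph captured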
    raw_grad_outputs = []
    raw_grad_outputs += grad_outputs
    raw_grad_outputs += [inputs[i] for i in gradient.df_input_captured_inputs]
    raw_grad_outputs += [raw_outputs[i] for i in gradient.df_input_captured_outputs]
    grad_input = exec_df(*raw_grad_outputs)

    # forward + backward with regular autograd / torch
    out_groundtruth = model(x)
    out_groundtruth.backward(*grad_outputs)

    # compare both
    print("output_jit - output: ", (outputs[0] - out_groundtruth).abs().max())
    print("gradinput_jit - gradinput: ", (grad_input - inputs[0].grad).abs().max())

if __name__ == '__main__':
    main()
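
As a sanity check independent of the JIT internals, the same gradient can be written out by hand and compared against eager autograd. A minimal sketch under the same shapes as main(); the variable names here are illustrative, not from the script above:

import torch

x = torch.randn(4, 1, 28, 28, requires_grad=True)
u = x + x * x
y = torch.sigmoid(u)
go = torch.randn_like(y)  # a random grad_output, as in main()
y.backward(go)

# chain rule: dy/dx = sigmoid(u) * (1 - sigmoid(u)) * (1 + 2*x), scaled by go
manual = go * torch.sigmoid(u) * (1 - torch.sigmoid(u)) * (1 + 2 * x)
print((x.grad - manual).abs().max())  # expected: ~0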