# Simple tracing JIT example
# GitHub Gist ezyang/baa09264096f4f09901a17f241e054e5 by @ezyang, created July 11, 2017.
import torch
from torch.autograd import Variable, Function
import torch._C as _C
import sys
class Add(Function):
    """Autograd Function that adds two inputs, written in the legacy style.

    ``forward`` and ``backward`` are ordinary instance methods here (the old
    ``Add()(x, y)`` calling protocol). NOTE(review): newer PyTorch releases
    require ``@staticmethod`` forward/backward with a ``ctx`` argument and
    reject calling legacy Functions — confirm against the torch version used.
    """

    def forward(self, a, b):
        """Return the element-wise sum ``a.add(b)``."""
        total = a.add(b)
        return total

    def backward(self, grad_output):
        """Route the incoming gradient unchanged to both operands.

        d(a + b)/da == d(a + b)/db == 1, so each input receives
        ``grad_output`` as-is. No reduction is performed, so this assumes
        ``a`` and ``b`` were not broadcast against each other.
        """
        return (grad_output,) * 2
# Two leaf inputs that require gradients, so backward() fills in their .grad.
x = Variable(torch.Tensor([4]), requires_grad=True)
y = Variable(torch.Tensor([7]), requires_grad=True)

# Record the operations applied to (x, y) between enter and exit as a trace.
# NOTE(review): _tracer_enter/_tracer_exit are private torch._C hooks from an
# experimental 2017 PyTorch build; they are not public API and may not exist
# in the installed torch version — confirm before running.
torch._C._tracer_enter((x, y))
# Alternative: trace through the custom Add Function instead of the builtin op:
# z = Add()(x, y)
z = x * y
trace = torch._C._tracer_exit((z,))
print(trace)

# Backward pass of sum(z) w.r.t. x and y. backward() returns None, so its
# result is intentionally not bound to a name.
z.sum().backward()
print(x.grad)
print(y.grad)

# Overwrite the original inputs' stored values after tracing.
# NOTE(review): presumably this demonstrates that the trace does not capture
# x/y's values, since the replay below uses fresh inputs — confirm intent.
x.data[0] = 2
y.data[0] = 3

# Fresh inputs for replaying the recorded trace.
x2 = Variable(torch.Tensor([4]), requires_grad=True)
y2 = Variable(torch.Tensor([7]), requires_grad=True)
# NOTE(review): run_forward presumably re-executes `trace` on (x2, y2) and
# returns its outputs as a sequence. The result gets its own name rather than
# reusing `z`, so the traced output and the replayed output stay distinct.
(z2,) = z._execution_engine.run_forward(trace, (x2, y2))
z2.sum().backward()
print(x2.grad)
print(y2.grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment