Skip to content

Instantly share code, notes, and snippets.

@wanchaol
Created May 11, 2019 00:09
Show Gist options
  • Save wanchaol/e0abd4dac98325fcee2790ada8a4e21b to your computer and use it in GitHub Desktop.
Save wanchaol/e0abd4dac98325fcee2790ada8a4e21b to your computer and use it in GitHub Desktop.
class VAE(nn.Module):
    """Minimal two-layer network used as the model under test.

    Despite its name this is not a full variational autoencoder: ``forward``
    flattens the input, applies ``fc1`` and then ``fc21`` with no activation
    in between, and squashes the result with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Keep the creation order (fc1 before fc21) so parameter
        # initialisation under a fixed seed is unchanged.
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)

    def forward(self, x):
        """Map a batch of 784-element inputs to 20 sigmoid outputs per row.

        ``x`` is reshaped with ``view(-1, 784)``, so its element count must be
        a multiple of 784 (this file feeds ``(128, 1, 28, 28)``). Because
        ``view`` is used, ``x`` must also be viewable in that shape
        (presumably contiguous) — TODO confirm callers guarantee this.
        """
        flat = x.view(-1, 784)
        hidden = self.fc1(flat)
        logits = self.fc21(hidden)
        return torch.sigmoid(logits)
# helper function to get sum of List[Tensor]
def _sum_of_list(tensorlist):
s = 0
for t in tensorlist:
if isinstance(t, torch.Tensor):
s += t.sum()
return s
def clone_inputs(arg):
    """Return a fresh gradient-tracking leaf copy of ``arg``, twice.

    The copy is detached from ``arg``'s autograd history, has its own storage
    (via ``clone``), and has ``requires_grad`` enabled. The identical object is
    returned in both positions of the ``(inputs, tensors)`` pair: in this file
    the first is fed to the traced model and the second is passed as the
    ``inputs`` of ``torch.autograd.grad``.
    """
    leaf = arg.detach().clone()
    leaf.requires_grad_()
    return leaf, leaf
# Trace the model on a random (128, 1, 28, 28) batch
# (presumably MNIST-sized images — only the shape is fixed here).
input_tensor = torch.rand(128, 1, 28, 28, requires_grad=True)
traced = torch.jit.trace(VAE(), (input_tensor,))

# Run the traced graph on a fresh leaf copy of the batch so gradients can be
# taken with respect to it.
recording_inputs, recording_tensors = clone_inputs(input_tensor)
outputs = traced(recording_inputs)

# First-order gradient of the summed output; create_graph=True keeps the
# result differentiable so it can be backpropagated through again.
l1 = _sum_of_list(outputs)
grads = torch.autograd.grad(
    outputs=l1,
    inputs=recording_tensors,
    allow_unused=True,
    create_graph=True,
)

# Second-order pass: differentiate (sum of first-order grads) * l1 again.
l2 = _sum_of_list(grads) * l1
grads2 = torch.autograd.grad(
    outputs=l2,
    inputs=recording_tensors,
    allow_unused=True,
    create_graph=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment