Created November 20, 2017 20:26
Save zou3519/ae272c821f4433c83b7890af51084c91 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# Reproduction: double backward (``create_graph=True``) through ``x * 2``
# followed by a 1x1 ``nn.Conv2d``. The ``gradient.sum().backward()`` call
# below was originally reported to raise an "undefined tensor" error; whether
# it still does depends on the installed PyTorch version.


def f1(x):
    """Scale the input by 2 (first stage of the composed function)."""
    return x * 2


def _input_gradient(x, module, upstream):
    """Return d(module(f1(x)))/dx, keeping the graph for a second backward.

    ``upstream`` is passed as ``grad_outputs`` for ``module``'s output.
    ``create_graph=True`` makes the returned gradient itself differentiable,
    which the double-backward step relies on. (``only_inputs`` is omitted:
    it already defaults to True and is deprecated/ignored in newer PyTorch.)
    """
    return autograd.grad(outputs=module(f1(x)), inputs=x,
                         grad_outputs=upstream,
                         create_graph=True, retain_graph=True)[0]


v_in = Variable(torch.Tensor([0.1, 0.1]).view(1, 2, 1, 1), requires_grad=True)
f2 = nn.Conv2d(2, 1, 1, 1)  # in_channels=2, out_channels=1, kernel 1, stride 1
grad_out = Variable(torch.ones(1, 1, 1, 1))
gradient = _input_gradient(v_in, f2, grad_out)
gradient.sum().backward()  # originally reported: "undefined tensor" error

# Does changing ``v_in`` change ``gradient``? The original check built a new
# ``nn.Conv2d`` here. That draws fresh random weights, so any difference it
# observed came from the re-initialised weights rather than from the new
# input. Reusing ``f2`` makes the input the only thing that varies.
#
# With a 1x1 kernel on a 1x1 input, f2(f1(x)) = sum_c W[0, c] * 2 * x[c] + b,
# which is affine in x. Its input gradient (for grad_out of ones) is
# therefore 2 * W and does not depend on x. Hence d(gradient)/d(v_in) is
# zero, and an absent or zero ``v_in.grad`` after the backward above is the
# expected result, not a symptom of the bug.
first_gradient = gradient.data.clone()
v_in = Variable(torch.Tensor([0.2, 0.1]).view(1, 2, 1, 1), requires_grad=True)
gradient = _input_gradient(v_in, f2, grad_out)
gradient_depends_on_input = not torch.equal(first_gradient, gradient.data)

if __name__ == "__main__":
    print("gradient depends on v_in:", gradient_depends_on_input)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.