Skip to content

Instantly share code, notes, and snippets.

@soumith
Created February 27, 2018 09:20
Show Gist options
  • Save soumith/c6dc258fb9bce1afe54cf3d2f61d90e4 to your computer and use it in GitHub Desktop.
Save soumith/c6dc258fb9bce1afe54cf3d2f61d90e4 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from torch.autograd import Variable
# Consistency check between volumetric (5-D) and spatial (4-D) grid_sample.
#
# With a depth of D == 1 and every grid z-coordinate set to 0, each 3-D sample
# lands exactly on the single depth slice, so the 3-D result should reduce to
# the 2-D one. The two outputs should match, and the x/y components of the 3-D
# grid gradient should match the 2-D grid gradient. (This is a manual
# comparison, not torch.autograd.gradcheck.)
torch.manual_seed(1)
N = 1
C = 1
D = 1  # must stay 1: the 2-D tensors below are views that drop the depth axis
H = 2
W = 1

# Volumetric case: input is (N, C, D, H, W); grid is (N, D, H, W, 3) holding
# (x, y, z) sampling coordinates.
input_3d = Variable(torch.ones(N, C, D, H, W), requires_grad=False)
# Zero the z-coordinate *before* wrapping the tensor as a grad-requiring leaf.
# Writing in place into a leaf that already requires grad raises RuntimeError.
grid_3d_data = torch.randn(N, D, H, W, 3)
grid_3d_data[..., 2] = 0
grid_3d = Variable(grid_3d_data, requires_grad=True)
out_3d = nn.functional.grid_sample(input_3d, grid_3d, padding_mode='zeros')
print(out_3d)
out_3d.sum().backward()

# Spatial case built from the same data: drop the (size-1) depth axis of the
# input and keep only the (x, y) coordinates of the grid.
input_2d = Variable(input_3d.data.view(N, C, H, W), requires_grad=False)
grid_2d = Variable(grid_3d.data.view(N, H, W, 3)[..., :2].clone(), requires_grad=True)
out_2d = nn.functional.grid_sample(input_2d, grid_2d, padding_mode='zeros')
print(out_2d)
out_2d.sum().backward()

# Expected: grid_3d.grad[..., :2] (reshaped to (N, H, W, 2)) equals grid_2d.grad.
print(grid_3d.grad)
print(grid_2d.grad)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment