# conv bwd implemented with fwd functions
import torch
import torch.nn.functional as F

def dconv2d(grad, x, w, stride, padding, groups):
    k = w.shape[-1]  # assumes a square kernel

    # differentiating w.r.t. x: conv_transpose2d(grad, w) expressed with a
    # forward conv2d: convolve the (dilated) incoming gradient with the
    # spatially flipped, channel-transposed weight at "full" padding
    gpad = (k - 1) - (stride - 1) - padding
    dxgrad = grad
    if stride > 1:  # manually dilate the incoming gradient
        # expand every element into a (stride, stride) block holding the
        # value in its bottom-right corner, then flatten the blocks back
        # into the spatial dims; the stride-1 leading zeros this creates
        # are absorbed into gpad above
        dxgrad = dxgrad.reshape(*dxgrad.shape, 1, 1)
        dxgrad = F.pad(dxgrad, (stride - 1, 0, stride - 1, 0)).transpose(3, 4)
        dxgrad = dxgrad.reshape(*grad.shape[:-2], *[stride * d for d in grad.shape[-2:]])
        dxgrad = F.pad(dxgrad, (0, stride - 1, 0, stride - 1))
    dxw = w.flip([2, 3])
    if groups > 1:  # transpose in/out channels within each group
        dxw = dxw.reshape(groups, dxw.shape[0] // groups, *dxw.shape[1:])
        dxw = dxw.transpose(1, 2)
        dxw = dxw.reshape(-1, *dxw.shape[2:])
    else:
        dxw = dxw.transpose(0, 1)
    dx = torch.conv2d(dxgrad, dxw, padding=gpad, groups=groups)

    # differentiating w.r.t. w: convolve x with the incoming gradient as
    # the kernel, swapping batch and channel dims; the forward stride
    # shows up here as dilation
    dwgrad = grad.transpose(0, 1)
    if groups > 1:  # fold the group dim into the batch ("channel") dim
        dwx = x.reshape(x.shape[0], groups, x.shape[1] // groups, *x.shape[2:])
        dwx = dwx.transpose(0, 2)
        dwx = dwx.reshape(dwx.shape[0], -1, *dwx.shape[3:])
    else:
        dwx = x.transpose(0, 1)
    dw = torch.conv2d(dwx, dwgrad, padding=padding, dilation=stride, groups=groups)
    dw = dw.transpose(0, 1)
    return dx, dw
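
# An extra cross-check sketch: PyTorch also exposes reference gradient
# helpers in torch.nn.grad, so dconv2d can be compared against them
# directly, in addition to the autograd-based tests below. The helper name
# check_against_reference is hypothetical, not part of any API.
import torch.nn.grad

def check_against_reference(grad, x, w, stride, padding, groups):
    dx, dw = dconv2d(grad, x, w, stride, padding, groups)
    ref_dx = torch.nn.grad.conv2d_input(
        list(x.shape), w, grad, stride=stride, padding=padding, groups=groups
    )
    ref_dw = torch.nn.grad.conv2d_weight(
        x, list(w.shape), grad, stride=stride, padding=padding, groups=groups
    )
    torch.testing.assert_close(ref_dx, dx)
    torch.testing.assert_close(ref_dw, dw)

# e.g. check_against_reference(torch.randn(1, 1, 2, 2),
#                              torch.randn(1, 1, 4, 4),
#                              torch.randn(1, 1, 3, 3), 1, 0, 1)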

def simple():
    print("simple")
    x = torch.randn(1, 1, 4, 4)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    grad = torch.randn(1, 1, 2, 2)
    y = torch.conv2d(x, w)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 1, 0, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def padded():
    print("padded")
    x = torch.randn(1, 1, 4, 4)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, padding=1)
    grad = torch.randn(1, 1, 4, 4)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 1, 1, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided():
    print("strided")
    x = torch.randn(1, 1, 5, 5)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2)
    grad = torch.randn(1, 1, 2, 2)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 0, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided_padded():
    print("strided/padded")
    x = torch.randn(8, 2, 5, 5)
    x.requires_grad = True
    w = torch.randn(4, 2, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2, padding=1)
    grad = torch.randn(8, 4, 3, 3)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 1, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

def strided_padded_grouped():
    print("strided/padded/grouped")
    x = torch.randn(7, 4, 5, 5)
    x.requires_grad = True
    w = torch.randn(6, 2, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=2, padding=1, groups=2)
    grad = torch.randn(7, 6, 3, 3)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 2, 1, 2)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

simple()
padded()
strided()
strided_padded()
strided_padded_grouped()
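
# An extra check: stride=3 exercises the stride-generalized dilation
# reshape in dconv2d; the sizes are chosen so (H - k) divides evenly by
# the stride and no output padding is needed.
def stride3():
    print("stride=3")
    x = torch.randn(1, 1, 6, 6)
    x.requires_grad = True
    w = torch.randn(1, 1, 3, 3)
    w.requires_grad = True
    y = torch.conv2d(x, w, stride=3)
    grad = torch.randn(1, 1, 2, 2)
    y.backward(grad)
    dx, dw = dconv2d(grad, x, w, 3, 0, 1)
    torch.testing.assert_close(x.grad, dx)
    torch.testing.assert_close(w.grad, dw)
    print("pass")

stride3()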