Last active
July 15, 2020 22:34
-
-
Save redwrasse/bad95645c8a539ff022b7d9bcf6a5551 to your computer and use it in GitHub Desktop.
verifies output length and partial derivatives of standard and left-padded 1d convolutions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| verifies expected output length and partial derivatives of | |
| a) standard 1d convolution | |
| b) left-padded 1d convolution | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| # convolution definition: | |
| # s_i = x_l w_{i-l} | |
| # standard convolution = keep only those outputs s_i | |
| # for which the sum exists for all weights | |
| # | |
| # standard convolution in zero based indexing with kernel size k is | |
| # | |
| # s_i = x_l w_{i + k - l - 1} | |
| # | |
| # convolution of kernel size 3, | |
| # on a 1d input of length 10 | |
| # => standard convolution output | |
| # should be of length n - k + 1 = 8 | |
| # | |
| # partial derivatives | |
| # D_ij := del s_i del x_j = w_{i + k - j - 1} | |
| # | |
| # left-padded convolution same expression | |
| # s_i = x_l w_{i-l} | |
| # but with 'negative index' 0s left-padded to x | |
| # hence | |
| # D_ij := del s_i del x_j = w_{i - j} | |
| # for i, j = 0,..., n - 1 | |
| # | |
# input length n and kernel size k shared by every test below
n, k = 10, 3
def partial_derivative(y, x, i, j):
    """Return the scalar partial derivative del y_i / del x_j.

    torch.autograd.grad computes z_i grad_j y_i for a vector z supplied
    via grad_outputs, so a one-hot z at position i extracts the Jacobian
    row grad_j y_i for that fixed i.  retain_graph=True lets the caller
    query several (i, j) pairs from the same forward pass.
    """
    one_hot = torch.zeros(y.shape)
    one_hot[:, :, i] = 1.
    jac_row = torch.autograd.grad(outputs=y,
                                  inputs=x,
                                  grad_outputs=one_hot,
                                  retain_graph=True)[0]
    # x is (1, 1, n); pick out the single entry at position j
    return jac_row[0, 0, j].item()
def left_pad_k(x, m):
    """Pad the last dimension of tensor x with m zeros on the left."""
    # F.pad's pad tuple is (left, right) for the final dimension
    return F.pad(x, pad=(m, 0), mode='constant', value=0)
def test_standard_convolution():
    """Verify output length and Jacobian of an unpadded 1d convolution.

    With every conv weight set to 1, the partial derivative
    del y_i / del x_j equals w_{i + k - j - 1}: exactly 1.0 when
    0 <= i + k - j - 1 <= k - 1, and exactly 0.0 otherwise.
    Raises AssertionError on any mismatch.
    """
    x = torch.randn(1, 1, n)
    x.requires_grad = True
    layer = torch.nn.Conv1d(
        in_channels=1,
        out_channels=1,
        kernel_size=k
    )
    # all-ones weights make every in-window partial derivative exactly 1.0
    # (the bias does not affect derivatives w.r.t. the input)
    layer.weight = torch.nn.Parameter(torch.ones_like(layer.weight))
    y = layer(x)
    # was hard-coded to 8; expressed as n - k + 1 so it tracks n and k
    assert (y.shape[2] == n - k + 1), 'standard convolution output should be of length n - k + 1'
    for i in range(n - k + 1):
        for j in range(n):
            # nonzero values are those for which w_{i + k - j - 1}
            # exists, i.e. 0 <= i + k - j - 1 <= k - 1 (see notes at top)
            pd = partial_derivative(y, x, i, j)
            in_range = 0 <= i - j + k - 1 <= k - 1
            if in_range:
                assert (pd == 1.0), 'unexpected standard convolution partial derivative'
            else:
                assert (pd == 0.0), 'unexpected standard convolution partial derivative'
def test_left_padded_convolution():
    """Verify output length and Jacobian of a left-padded 1d convolution.

    Left-padding with k - 1 zeros gives an output of length n, and the
    partial derivative del y_i / del x_j equals w_{i - j}: exactly 1.0
    (all-ones weights) when 0 <= i - j <= k - 1, and 0.0 otherwise.
    Raises AssertionError on any mismatch.
    """
    x = torch.randn(1, 1, n)
    x.requires_grad = True
    # standard left-padding is with k - 1 zeros to produce matched output length
    x2 = left_pad_k(x, k - 1)
    layer = torch.nn.Conv1d(
        in_channels=1,
        out_channels=1,
        kernel_size=k
    )
    # all-ones weights make every in-window partial derivative exactly 1.0
    layer.weight = torch.nn.Parameter(torch.ones_like(layer.weight))
    y = layer(x2)
    assert (y.shape[2] == n), 'left-padded convolution output should be of length n'
    # was range(n - k + 1), which skipped the last k - 1 output positions;
    # the identity D_ij = w_{i - j} holds for every i = 0, ..., n - 1
    # (see notes at top), so check the full output range
    for i in range(n):
        for j in range(n):
            # nonzero values are those for which w_{i - j} exists,
            # i.e. 0 <= i - j <= k - 1 (see notes at top)
            pd = partial_derivative(y, x, i, j)
            in_range = 0 <= i - j <= k - 1
            if in_range:
                assert (pd == 1.0), 'unexpected left-padded convolution partial derivative'
            else:
                assert (pd == 0.0), 'unexpected left-padded convolution partial derivative'
def tests():
    # Run both verification checks; each raises AssertionError on failure.
    test_standard_convolution()
    test_left_padded_convolution()


# executed at import/script load time: the checks run immediately
tests()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment