b: Batch size
i: Input features (Linear)
o: Output features
c_in: Input channels
c_out: Output channels
n: Input/output length (1D sequence, often number of tokens)
n_in: Input length (1D sequence)
n_out: Output length (1D sequence)
h_in: Input height
w_in: Input width
h_out: Output height
w_out: Output width
k: Kernel size (1D)
kh: Kernel height
kw: Kernel width
Linear layer
X = rearrange(X, 'b i -> b i 1')
W = rearrange(W, 'o i -> 1 i o')
Y = (X * W).sum(dim=1) # + bias
Conv1d layer
unfolded_X = X.unfold(dimension=2, size=kernel_size, step=stride)
unfolded_X = rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')
W = rearrange(W, 'c_out c_in k -> 1 c_out c_in 1 k')
Y = (unfolded_X * W).sum(dim=(2, 4))
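# With dilation: unfold with size=dilation * (kernel_size - 1) + 1 so that k taps remain after slicing
Y = (unfolded_X[..., ::dilation] * W).sum(dim=(2, 4))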
# With groups (c_in and c_out are per-group channel counts here)
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) n_out k -> b groups 1 c_in n_out k', groups=groups)
W = rearrange(W, '(groups c_out) c_in k -> 1 groups c_out c_in 1 k', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 5))
Y = rearrange(Y, 'b groups c_out ... -> b (groups c_out) ...')
Conv2d layer
unfolded_X = X.unfold(2, kernel_size[0], stride).unfold(3, kernel_size[1], stride)
unfolded_X = rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')
W = rearrange(W, 'c_out c_in kh kw -> 1 c_out c_in 1 1 kh kw')
Y = (unfolded_X * W).sum(dim=(2, 5, 6))
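# With dilation: unfold each spatial dim with size=dilation * (k - 1) + 1 (k = kh or kw) so that kh x kw taps remain
Y = (unfolded_X[..., ::dilation, ::dilation] * W).sum(dim=(2, 5, 6))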
# With groups (c_in and c_out are per-group channel counts here)
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) h_out w_out kh kw -> b groups 1 c_in h_out w_out kh kw', groups=groups)
W = rearrange(W, '(groups c_out) c_in kh kw -> 1 groups c_out c_in 1 1 kh kw', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 6, 7))
Y = rearrange(Y, 'b groups c_out ... -> b (groups c_out) ...')
Side-note: pad
X_padded = nn.functional.pad(X, (left, right, top, bottom))
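For a 4D input, the padding tuple is read from the last dimension backwards: the first pair pads the width, the second pair pads the height. The sketch below uses made-up sizes. It checks the resulting shape, and it checks that padding by hand and then convolving matches the `padding` argument of `F.conv2d`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(1, 2, 5, 6)                  # b c h w, illustrative sizes
left, right, top, bottom = 1, 2, 3, 4

# Pads width by (left, right) and height by (top, bottom), with zeros by default.
X_padded = F.pad(X, (left, right, top, bottom))
assert X_padded.shape == (1, 2, 5 + top + bottom, 6 + left + right)

# Symmetric padding of 1 on every side is what conv2d's padding=1 does.
W = torch.randn(3, 2, 3, 3)
Y_manual = F.conv2d(F.pad(X, (1, 1, 1, 1)), W)
Y_builtin = F.conv2d(X, W, padding=1)
assert torch.allclose(Y_manual, Y_builtin, atol=1e-5)
```

Padding first and then applying the unfold-based formulas is one way to cover padded convolutions without changing those formulas.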
Assuming both arguments have the same number of dimensions, and that number is at least 2:
A = rearrange(A, '... m n -> ... m n 1')
B = rearrange(B, '... n p -> ... 1 n p')
C = (A * B).sum(dim=-2) # ... m p
b. bmm
A = rearrange(A, 'b m n -> b m n 1')
B = rearrange(B, 'b n p -> b 1 n p')
C = (A * B).sum(dim=-2) # b m p
d. mm
A = rearrange(A, 'm n -> m n 1')
B = rearrange(B, 'n p -> 1 n p')
C = (A * B).sum(dim=-2) # m p
# X = rearrange(X, 'b c1 c2 h w -> b (c1 c2) h w')
X = X.view(X.size(0), -1, *X.shape[-2:])
# X = rearrange(X, 'b (c1 c2) h w -> b c1 c2 h w', c1=c1, c2=c2)
X = X.view(X.size(0), c1, c2, *X.shape[-2:])
# X = rearrange(X, 'b c h w -> b w h c')
X = X.permute(0, 3, 2, 1)
# X = rearrange(X, 'b i 1 -> b i')
X = X.squeeze(-1)
# X = rearrange(X, 'b i -> b i 1')
X = X.unsqueeze(-1)
X = X[..., None]
https://einops.rocks/pytorch-examples.html