b: Batch size
i: Input features (Linear)
o: Output features
c_in: Input channels
c_out: Output channels
n: Input/output length (1D sequence, often number of tokens)
n_in: Input length (1D sequence)
n_out: Output length (1D sequence)
h_in: Input height
w_in: Input width
h_out: Output height
w_out: Output width
k: Kernel size (1D)
kh: Kernel height
kw: Kernel width
Linear layer
X = rearrange(X, 'b i -> b i 1')
W = rearrange(W, 'o i -> 1 i o')  # nn.Linear stores its weight as (out_features, in_features)
Y = (X * W).sum(dim=1)  # b o (add the bias afterwards)
Conv1d layer
unfolded_X = X.unfold(dimension=2, size=kernel_size, step=stride)
unfolded_X = rearrange(unfolded_X, 'b c_in n_out k -> b 1 c_in n_out k')
W = rearrange(W, 'c_out c_in k -> 1 c_out c_in 1 k')
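Y = (unfolded_X * W).sum(dim=(2, 4))  # b c_out n_out
With dilation, unfold windows of dilation * (kernel_size - 1) + 1 elements instead of kernel_size, then keep every dilation-th element:
Y = (unfolded_X[..., ::dilation] * W).sum(dim=(2, 4))
With groups, start again from the raw output of X.unfold (b (groups c_in) n_out k) and the original weight:
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) n_out k -> b groups 1 c_in n_out k', groups=groups)
W = rearrange(W, '(groups c_out) c_in k -> 1 groups c_out c_in 1 k', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 5))  # b groups c_out n_out
Y = rearrange(Y, 'b groups c_out n_out -> b (groups c_out) n_out')
Conv2d layer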
unfolded_X = X.unfold(2, kernel_size[0], stride).unfold(3, kernel_size[1], stride)
unfolded_X = rearrange(unfolded_X, 'b c_in h_out w_out kh kw -> b 1 c_in h_out w_out kh kw')
W = rearrange(W, 'c_out c_in kh kw -> 1 c_out c_in 1 1 kh kw')
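Y = (unfolded_X * W).sum(dim=(2, 5, 6))  # b c_out h_out w_out
With dilation, unfold windows of dilation * (kernel_size[i] - 1) + 1 elements along each spatial dim instead of kernel_size[i], then keep every dilation-th element:
Y = (unfolded_X[..., ::dilation, ::dilation] * W).sum(dim=(2, 5, 6))
With groups, start again from the raw output of the two unfolds (b (groups c_in) h_out w_out kh kw) and the original weight:
unfolded_X = rearrange(unfolded_X, 'b (groups c_in) h_out w_out kh kw -> b groups 1 c_in h_out w_out kh kw', groups=groups)
W = rearrange(W, '(groups c_out) c_in kh kw -> 1 groups c_out c_in 1 1 kh kw', groups=groups)
Y = (unfolded_X * W).sum(dim=(3, 6, 7))  # b groups c_out h_out w_out
Y = rearrange(Y, 'b groups c_out h_out w_out -> b (groups c_out) h_out w_out')
Side-note: pad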
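A quick check of the padding order. torch.nn.functional.pad takes (before, after) pairs starting from the last dimension. For a (b, c, h, w) tensor, (left, right) therefore pads the width and (top, bottom) pads the height. Padding first and then convolving without padding gives the same result as the padding argument of F.conv2d. That is why the unfold-based layers can leave padding out. The sizes below are arbitrary.

```python
import torch
import torch.nn.functional as F

X = torch.ones(1, 1, 2, 3)  # b c h w
left, right, top, bottom = 1, 2, 3, 4
X_padded = F.pad(X, (left, right, top, bottom))  # zero padding by default
print(X_padded.shape)  # torch.Size([1, 1, 9, 6])

# The original values sit at rows top:top+h and columns left:left+w.
assert torch.equal(X_padded[..., top:top + 2, left:left + 3], X)
assert X_padded.sum().item() == X.sum().item()

# Explicit padding followed by an unpadded conv matches conv2d with padding=1.
torch.manual_seed(0)
W = torch.randn(2, 1, 3, 3)
print(torch.allclose(F.conv2d(F.pad(X, (1, 1, 1, 1)), W), F.conv2d(X, W, padding=1), atol=1e-5))  # True
```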
X_padded = nn.functional.pad(X, (left, right, top, bottom))  # pads the last dim (w) by left/right and the second-to-last (h) by top/bottom
Assuming both arguments have the same number of dims, and that number is at least 2:
A = rearrange(A, '... m n -> ... m n 1')
B = rearrange(B, '... n p -> ... 1 n p')
C = (A * B).sum(dim=-2) # ... m p
b. bmm
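The broadcast-and-sum matmul can be checked against torch.bmm in the 3-D case, and against torch.matmul for more leading dims. This sketch uses indexing with None instead of rearrange, so it depends only on PyTorch. The helper name matmul_via_broadcast is hypothetical, and the shapes are arbitrary.

```python
import torch


def matmul_via_broadcast(A, B):
    # Hypothetical helper. A: (..., m, n), B: (..., n, p), same number of dims.
    A = A[..., :, :, None]  # ... m n 1
    B = B[..., None, :, :]  # ... 1 n p
    return (A * B).sum(dim=-2)  # ... m p


torch.manual_seed(0)
A = torch.randn(7, 4, 5)
B = torch.randn(7, 5, 6)
C = matmul_via_broadcast(A, B)
print(C.shape)  # torch.Size([7, 4, 6])
print(torch.allclose(C, torch.bmm(A, B), atol=1e-5))  # True

A4 = torch.randn(2, 3, 4, 5)
B4 = torch.randn(2, 3, 5, 6)
print(torch.allclose(matmul_via_broadcast(A4, B4), torch.matmul(A4, B4), atol=1e-5))  # True
```

The (..., m, n, p) intermediate is fully materialized, so memory grows with m * n * p per batch element. torch.bmm and torch.matmul do not build that tensor.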
A = rearrange(A, 'b m n -> b m n 1')
B = rearrange(B, 'b n p -> b 1 n p')
C = (A * B).sum(dim=-2) # b m p
d. mm
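For the 2-D case, the same pattern can be checked against torch.mm. This sketch spells the rearranges with None indexing so that it depends only on PyTorch. It uses torch.einsum as a second reference for the same contraction over n. The shapes are arbitrary.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 5)  # m n
B = torch.randn(5, 6)  # n p
C = (A[:, :, None] * B[None, :, :]).sum(dim=-2)  # (m, n, 1) * (1, n, p) -> (m, n, p) -> (m, p)
print(C.shape)  # torch.Size([4, 6])
print(torch.allclose(C, torch.mm(A, B), atol=1e-5))  # True

# The same contraction written as an einsum.
print(torch.allclose(C, torch.einsum('mn,np->mp', A, B), atol=1e-5))  # True
```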
A = rearrange(A, 'm n -> m n 1')
B = rearrange(B, 'n p -> 1 n p')
C = (A * B).sum(dim=-2) # m p

# X = rearrange(X, 'b c1 c2 h w -> b (c1 c2) h w')
X = X.view(X.size(0), -1, *X.shape[-2:])

# X = rearrange(X, 'b (c1 c2) h w -> b c1 c2 h w', c1=c1, c2=c2)
X = X.view(X.size(0), c1, c2, *X.shape[-2:])

# X = rearrange(X, 'b c h w -> b w h c')
X = X.permute(0, 3, 2, 1)

# X = rearrange(X, 'b i 1 -> b i')
X = X.squeeze(-1)

# X = rearrange(X, 'b i -> b i 1')
X = X.unsqueeze(-1)
X = X[..., None]  # equivalent to the unsqueeze above
TODO (WIP / incomplete / wrong / unrefined)
Depthwise separable
ConvTranspose2d with kernel_size=(2,2) and groups=4*... can be rewritten as Conv2d followed by PixelShuffle:
More efficient dilation
Simple Attention
Multihead Attention Mechanism