@pclucas14
Created August 17, 2019 18:44
Convolutions with `as_strided` and `einsum`
```python
import torch
import torch.nn as nn


class einsum_conv(nn.Module):
    def __init__(self, kernel_size, stride=1):
        super(einsum_conv, self).__init__()
        self.ks = kernel_size
        self.stride = stride

    def forward(self, x, kernel):
        if len(x.size()) == 3:
            x = x.unsqueeze(0)
        assert len(x.size()) == 4, 'need bs x c x h x w format'

        bs, in_c, h, w = x.size()
        ks = self.ks

        # Build a 6-D view of x: bs x in_c x out_h x out_w x ks x ks.
        # Strides are given in elements (PyTorch convention) and assume
        # x is contiguous.
        strided_x = x.as_strided(
            (bs, in_c, (h - ks) // self.stride + 1, (w - ks) // self.stride + 1, ks, ks),
            (in_c * h * w, h * w, self.stride * w, self.stride, w, 1))

        # Contract the kernel-window and input-channel dims against the kernel.
        out = torch.einsum('bihwkl,oikl->bohw', strided_x, kernel)
        return out
```
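The strided view plus `einsum` computes an ordinary (cross-correlation) convolution, so it can be sanity-checked against `torch.nn.functional.conv2d`. A minimal sketch (my example, not part of the gist; shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

bs, in_c, out_c, h, w, ks, stride = 2, 3, 5, 8, 8, 3, 1
x = torch.randn(bs, in_c, h, w)
kernel = torch.randn(out_c, in_c, ks, ks)  # out_c x in_c x kh x kw

# Same strided view + einsum contraction as in the class above.
strided_x = x.as_strided(
    (bs, in_c, (h - ks) // stride + 1, (w - ks) // stride + 1, ks, ks),
    (in_c * h * w, h * w, stride * w, stride, w, 1))
out = torch.einsum('bihwkl,oikl->bohw', strided_x, kernel)

# Reference: PyTorch's built-in convolution.
ref = F.conv2d(x, kernel, stride=stride)
print(torch.allclose(out, ref, atol=1e-4))
```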
slvrfn commented Feb 19, 2021

very helpful, thank you!

If anyone is trying to use this with NumPy's `as_strided`: the strides specified here are specific to PyTorch (they count elements), whereas NumPy expects strides in bytes, so you will need to build them from the array's `.strides` (or scale the element strides by `itemsize`).
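A NumPy sketch of the same idea (my example, under my reading of the comment above): rather than hand-computing byte strides, the input array's own `.strides` tuple can be reused directly.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

bs, in_c, out_c, h, w, ks, stride = 2, 3, 4, 6, 6, 3, 1
x = np.random.randn(bs, in_c, h, w).astype(np.float32)
kernel = np.random.randn(out_c, in_c, ks, ks).astype(np.float32)

# NumPy strides are in bytes; reuse x's own byte strides per axis.
sb, sc, sh, sw = x.strides
out_h = (h - ks) // stride + 1
out_w = (w - ks) // stride + 1
strided_x = as_strided(
    x,
    shape=(bs, in_c, out_h, out_w, ks, ks),
    strides=(sb, sc, stride * sh, stride * sw, sh, sw))

out = np.einsum('bihwkl,oikl->bohw', strided_x, kernel)
```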
