Created August 17, 2019 18:44
Save pclucas14/1834b5c437ed08a437be475f27f06f82 to your computer and use it in GitHub Desktop.
Convolutions with `as_strided` and `einsum`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class einsum_conv(nn.Module):
    """2-D convolution (no padding, no bias) built from ``as_strided`` + ``einsum``.

    The input is re-viewed as a 6-D tensor of sliding ``ks x ks`` windows.
    That tensor is then contracted against the kernel with a single
    ``torch.einsum`` call.

    Args:
        kernel_size: side length ``ks`` of the square kernel.
        stride: step between neighbouring windows, applied to both spatial axes.
    """

    def __init__(self, kernel_size, stride=1):
        super().__init__()
        self.ks = kernel_size
        self.stride = stride

    def forward(self, x, kernel):
        """Convolve ``x`` with ``kernel``.

        Args:
            x: input of shape ``(bs, in_c, h, w)``. A 3-D ``(in_c, h, w)``
                input is treated as a batch of one.
            kernel: weights of shape ``(out_c, in_c, ks, ks)``.

        Returns:
            Tensor of shape ``(bs, out_c, out_h, out_w)``, where
            ``out_h = (h - ks) // stride + 1`` and likewise for ``out_w``.
            A 3-D input still produces a 4-D output.

        Raises:
            ValueError: if ``x`` is not 3-D or 4-D, or if the kernel is larger
                than the spatial extent of ``x``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if x.dim() != 4:
            raise ValueError('need bs x c x h x w format')
        bs, in_c, h, w = x.size()
        ks, step = self.ks, self.stride
        if ks > h or ks > w:
            raise ValueError(
                f'kernel size {ks} exceeds input spatial size {h}x{w}'
            )
        out_h = (h - ks) // step + 1
        out_w = (w - ks) // step + 1
        # Use the tensor's real strides (measured in elements) instead of
        # assuming a contiguous layout. This way transposed, sliced or
        # expanded inputs are windowed correctly. as_strided keeps x's
        # storage offset when none is given.
        s_b, s_c, s_h, s_w = x.stride()
        windows = x.as_strided(
            (bs, in_c, out_h, out_w, ks, ks),
            (s_b, s_c, step * s_h, step * s_w, s_h, s_w),
        )
        # Index letters:
        #   b: batch, i: in-channel, o: out-channel
        #   (h, w): output position
        #   (k, l): offset inside the window
        return torch.einsum('bihwkl,oikl->bohw', windows, kernel)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
very helpful, thank you!
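If anyone is trying to use this with NumPy's `as_strided`, note that NumPy measures strides in bytes rather than elements. You will need to build the strides from the array's own `.strides` (or multiply each element stride by `x.itemsize`). The strides specified here are specific to PyTorch, whose strides count elements.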