Created February 7, 2021 12:21
Save crowsonkb/6b4f089bb0139fc28403c5ca7eb74dc6 to your computer and use it in GitHub Desktop.
Applies a 2D soft pooling over an input signal composed of several input planes. See https://arxiv.org/abs/2101.00440
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from torch.nn import functional as F | |
class SoftPool2d(nn.Module):
    """Applies a 2D soft pooling over an input signal composed of several
    input planes. See https://arxiv.org/abs/2101.00440

    Each non-overlapping ``kernel_size`` window is reduced to
    ``sum_i x_i * softmax(x / temperature)_i``. Lower temperatures weight the
    largest elements more heavily; higher temperatures spread the weights
    more evenly. The stride equals the kernel size.

    Args:
        kernel_size: Window size, either an int or an ``(kh, kw)`` pair.
        ceil_mode: If True, partial windows at the bottom/right edges are
            kept and pooled over their in-bounds elements only. If False,
            trailing rows/columns that do not fill a whole window are dropped.
        temperature: Softmax temperature applied to the window logits.

    Shape:
        Input: ``(*, H, W)``, e.g. ``(N, C, H, W)`` or ``(C, H, W)``.
        Output: ``(*, H_out, W_out)`` where ``H_out = floor(H / kh)``, or
        ``ceil(H / kh)`` when ``ceil_mode`` is True (likewise for ``W``).
    """

    def __init__(self, kernel_size, ceil_mode=False, temperature=1.):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.ceil_mode = ceil_mode
        self.temperature = temperature

    def extra_repr(self):
        return f'kernel_size={self.kernel_size}, ' \
               f'ceil_mode={self.ceil_mode}, ' \
               f'temperature={self.temperature:g}'

    def forward(self, input):
        kh, kw = self.kernel_size
        *lead, h, w = input.shape
        if self.ceil_mode:
            # Pad up to the next multiple of the kernel size; -h % kh is 0
            # when h already divides evenly. Values are padded with 0 and
            # logits with -inf, so padded cells get zero softmax weight and
            # do not dilute the partial edge windows. Every window keeps at
            # least one real element because each pad amount is < kernel size.
            pad = (0, -w % kw, 0, -h % kh)
            values = F.pad(input, pad)
            logits = F.pad(input / self.temperature, pad, value=float('-inf'))
        else:
            # Drop trailing rows/columns that do not fill a complete window.
            values = input[..., :h - h % kh, :w - w % kw]
            logits = values / self.temperature
        out_h = values.shape[-2] // kh
        out_w = values.shape[-1] // kw
        window_shape = [*lead, out_h, kh, out_w, kw]
        # (*, out_h, kh, out_w, kw) -> (*, out_h, out_w, kh * kw).
        # reshape (not view) because the floor-mode slice may be non-contiguous.
        values = values.reshape(window_shape).movedim(-3, -2).flatten(-2)
        logits = logits.reshape(window_shape).movedim(-3, -2).flatten(-2)
        return (values * F.softmax(logits, dim=-1)).sum(dim=-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.