Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Created February 7, 2021 12:21
Show Gist options
  • Save crowsonkb/6b4f089bb0139fc28403c5ca7eb74dc6 to your computer and use it in GitHub Desktop.
Save crowsonkb/6b4f089bb0139fc28403c5ca7eb74dc6 to your computer and use it in GitHub Desktop.
Applies a 2D soft pooling over an input signal composed of several input planes. See https://arxiv.org/abs/2101.00440
from torch import nn
from torch.nn import functional as F
class SoftPool2d(nn.Module):
    """Applies a 2D soft pooling over an input signal composed of several
    input planes. See https://arxiv.org/abs/2101.00440

    The input is split into non-overlapping ``kh x kw`` windows (the stride
    equals the kernel size). Each window is reduced to the weighted sum of
    its elements, with weights ``softmax(x / temperature)`` taken over the
    window.

    Args:
        kernel_size: Window size, an ``int`` or an ``(kh, kw)`` pair.
        ceil_mode: If ``True``, windows that run past the bottom/right edge
            are kept and their out-of-range positions are ignored (they get
            zero weight). If ``False``, trailing rows/columns that do not
            fill a whole window are dropped.
        temperature: Divisor applied before the softmax; must be nonzero.
            Small positive values approach max pooling, large values
            approach average pooling, and negative values favour the
            window minimum.

    Shape:
        - Input: ``(*, H, W)`` (any number of leading dimensions).
        - Output: ``(*, H_out, W_out)`` with ``H_out = ceil(H / kh)`` if
          ``ceil_mode`` else ``floor(H / kh)`` (likewise for ``W``).
    """

    def __init__(self, kernel_size, ceil_mode=False, temperature=1.):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.ceil_mode = ceil_mode
        self.temperature = temperature

    def extra_repr(self):
        return (f'kernel_size={self.kernel_size}, '
                f'ceil_mode={self.ceil_mode}, '
                f'temperature={self.temperature:g}')

    def _windows(self, x):
        """Reshape ``(*, H, W)`` into ``(*, H // kh, W // kw, kh * kw)``.

        H and W must already be multiples of kh and kw respectively.
        """
        kh, kw = self.kernel_size
        *batch, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are handled too.
        x = x.reshape(*batch, h // kh, kh, w // kw, kw)
        # (*, H', kh, W', kw) -> (*, H', W', kh, kw) -> (*, H', W', kh*kw)
        return x.movedim(-3, -2).flatten(-2)

    def forward(self, input):
        kh, kw = self.kernel_size
        h, w = input.shape[-2:]
        logits = input / self.temperature
        if self.ceil_mode:
            # Pad up to the next multiple of the kernel size; -h % kh is 0
            # when h is already a multiple (h % kh would over-pad).
            pad = (0, -w % kw, 0, -h % kh)
            # Padded values are 0 and padded logits are -inf, so softmax
            # assigns the padding exactly zero weight. Padding the logits
            # (after dividing) keeps this right for negative temperatures.
            # Every window keeps at least one real element because each
            # pad is smaller than the kernel.
            input = F.pad(input, pad)
            logits = F.pad(logits, pad, value=float('-inf'))
        else:
            h_keep, w_keep = h - h % kh, w - w % kw
            input = input[..., :h_keep, :w_keep]
            logits = logits[..., :h_keep, :w_keep]
        weights = F.softmax(self._windows(logits), dim=-1)
        return (self._windows(input) * weights).sum(dim=-1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment