Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Created February 22, 2021 18:08
Show Gist options
  • Save crowsonkb/f618af622d99a5e4dc9715f856f86cbd to your computer and use it in GitHub Desktop.
Save crowsonkb/f618af622d99a5e4dc9715f856f86cbd to your computer and use it in GitHub Desktop.
Binomial2Pool2d
import torch
from torch import nn
from torch.nn import functional as F
class Binomial2Pool2d(nn.Module):
    """Downsample by 2 with a 3x3 binomial filter (anti-aliased pooling).

    Every channel is filtered independently with the kernel
    ``outer([1, 2, 1], [1, 2, 1]) / 16`` at stride 2, after reflect padding.
    The kernel sums to 1, so constant regions pass through unchanged.

    Args:
        ceil_mode: If True, each output spatial size is ``ceil(size / 2)``;
            otherwise it is ``floor(size / 2)``.

    Shape:
        Input: ``(*, H, W)``. This is any number of leading dims, typically
        ``(N, C, H, W)``. Floating point is expected. ``H`` and ``W`` must
        each be at least 2, because reflect padding needs a neighbouring
        row/column to mirror.
        Output: ``(*, H_out, W_out)`` with the same leading dims.
    """

    def __init__(self, ceil_mode=False):
        super().__init__()
        self.ceil_mode = ceil_mode
        # Shape (out_channels=1, in_channels=1, 3, 3). forward() folds the
        # channels into the batch dim, so this single filter applies to
        # every channel.
        kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]]
        # Non-persistent: it is rebuilt from constants, so keep it out of
        # state_dict.
        self.register_buffer('kernel', torch.tensor(kernel), persistent=False)

    def extra_repr(self):
        return f'ceil_mode={self.ceil_mode}'

    def forward(self, input):
        if input.dim() < 2:
            raise ValueError(
                f'expected input with at least 2 dims (*, H, W), got shape '
                f'{tuple(input.shape)}')
        batch_shape = input.shape[:-2]
        h, w = input.shape[-2:]
        # Fold all leading dims (batch, channels, ...) into one batch dim.
        # Use reshape rather than view: view fails on non-contiguous inputs
        # such as channels_last tensors.
        x = input.reshape(batch_shape.numel(), 1, h, w)
        # Always pad 1 on the leading side. On the trailing side, pad 1 in
        # ceil mode; otherwise pad 1 only for even sizes. With a 3x3 window at
        # stride 2 this gives ceil(size / 2) or floor(size / 2) outputs.
        pad_h = 1 if self.ceil_mode else 1 - h % 2
        pad_w = 1 if self.ceil_mode else 1 - w % 2
        x = F.pad(x, (1, pad_w, 1, pad_h), mode='reflect')
        # Match the input's dtype/device so that e.g. float64 inputs work even
        # when the module itself has not been converted.
        kernel = self.kernel.to(device=input.device, dtype=input.dtype)
        x = F.conv2d(x, kernel, stride=2)
        return x.reshape(batch_shape + x.shape[-2:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment