Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Last active February 3, 2021 19:41
Show Gist options
  • Save crowsonkb/f76ca87df4d22c1464ac56d4ad86e3a9 to your computer and use it in GitHub Desktop.
Save crowsonkb/f76ca87df4d22c1464ac56d4ad86e3a9 to your computer and use it in GitHub Desktop.
Better image downsampling (factor of 2) in PyTorch
from math import ceil
import torch
from torch import nn
from torch.nn import functional as F
class Downsample2d(nn.Module):
    """Downsample a batch of images by a factor of 2 with a low-pass kernel.

    The input is reflect-padded, filtered with a normalized kernel, and
    decimated with stride 2 along both spatial axes. Every channel is
    filtered independently with the same kernel.

    Args:
        kernel: Either the name of a preset in ``Downsample2d.kernels`` or a
            1-D sequence/tensor of filter taps. The taps are normalized to
            sum to 1; the caller's data is never modified.
        separate: If true, keep the 1-D kernel and apply it as two 1-D
            convolutions (one per spatial axis). Otherwise the outer product
            of the kernel with itself is applied as a single 2-D convolution.

    Raises:
        KeyError: If ``kernel`` is a string that is not a preset name.
        ValueError: If the kernel is not one-dimensional.
    """

    # Preset 1-D filter taps, keyed by name.
    kernels = {
        'binomial2': [0.25, 0.5, 0.25],
        'lanczos2': [-0.0315, 0, 0.2839, 0.4952, 0.2839, 0, -0.0315],
        'lanczos3': [0.0122, 0, -0.0677, 0, 0.3048, 0.5014, 0.3048, 0,
                     -0.0677, 0, 0.0122],
    }

    def __init__(self, kernel='binomial2', separate=False):
        super().__init__()
        if isinstance(kernel, str):
            kernel = self.kernels[kernel]
        kernel = torch.as_tensor(kernel)
        # Explicit raise rather than assert: asserts vanish under ``python -O``.
        if kernel.ndim != 1:
            raise ValueError(
                f'kernel must be one-dimensional, got {kernel.ndim} dimensions'
            )
        # Integer taps (e.g. [1, 2, 1]) cannot hold the normalized values.
        if not (kernel.is_floating_point() or kernel.is_complex()):
            kernel = kernel.to(torch.get_default_dtype())
        # Out-of-place division: ``as_tensor`` may share memory with the
        # caller's tensor/array, which must not be modified.
        kernel = kernel / kernel.sum()
        if not separate:
            kernel = kernel[:, None] @ kernel[None, :]
        self.register_buffer('kernel', kernel)

    def forward(self, input):
        """Return ``input`` downsampled by 2 along its last two dimensions.

        Args:
            input: A 4-D tensor unpacked as ``(n, c, h, w)``.

        Returns:
            A tensor of shape ``(n, c, ceil(h / 2), ceil(w / 2))``.
        """
        n, c, h, w = input.shape
        kernel = self.kernel
        # conv2d requires matching dtypes; follow the input when it is a
        # floating-point tensor of a different precision than the buffer.
        if input.is_floating_point() and kernel.dtype != input.dtype:
            kernel = kernel.to(input.dtype)
        # Fold channels into the batch so one single-channel kernel filters
        # every channel. ``reshape`` (not ``view``) also accepts inputs whose
        # batch/channel dims cannot be merged without a copy.
        x = input.reshape(n * c, 1, h, w)
        size = kernel.shape[0]
        start_pad = (size - 1) // 2
        end_pad = size // 2
        x = F.pad(x, (start_pad, end_pad, start_pad, end_pad), 'reflect')
        if kernel.ndim == 1:
            # Separable path: filter along width, then along height.
            x = F.conv2d(x, kernel[None, None, None, :], stride=(1, 2))
            x = F.conv2d(x, kernel[None, None, :, None], stride=(2, 1))
        else:
            x = F.conv2d(x, kernel[None, None, :, :], stride=2)
        return x.reshape(n, c, ceil(h / 2), ceil(w / 2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment