Last active
February 28, 2022 14:25
-
-
Save torridgristle/68ce54f562ff46fd0bbb0381ea4ff243 to your computer and use it in GitHub Desktop.
Kaiser-window lowpass filter module for PyTorch. Torchvision's Gaussian blur uses the "reflect" padding mode, but I'm not sure that makes sense, so this module defaults to "replicate" padding instead.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.nn as nn
import torch.nn.functional as F

# Preferred compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class KaiserLowpass(nn.Module):
    """Separable 2-D lowpass filter built from a 1-D Kaiser window.

    The same 1-D window is convolved along the width and then along the
    height of a 4-D ``(N, C, H, W)`` tensor, one channel at a time
    (depthwise convolution). The input is padded beforehand so the output
    keeps the input's spatial size.

    Args:
        width: Number of taps in the Kaiser window (kernel size per axis).
        beta: Shape parameter passed to ``torch.kaiser_window``.
        periodic: ``periodic`` flag passed to ``torch.kaiser_window``.
        padding_mode: Padding mode passed to ``F.pad``.
        normalize: When true, scale the window so its taps sum to 1, which
            leaves constant regions unchanged. Defaults to ``False``, i.e.
            the raw window is used.
    """

    def __init__(self, width=7, beta=11, periodic=False, padding_mode='replicate',
                 normalize=False):
        super().__init__()
        self.padding_mode = padding_mode
        # F.pad expects (left, right, top, bottom). Both sides together must
        # add up to width - 1 for H and W to be preserved; with an even width
        # the extra sample goes on the right/bottom.
        before, after = (width - 1) // 2, width // 2
        self.padding = [before, after, before, after]

        window = torch.kaiser_window(width, periodic, beta)
        if normalize:
            window = window / window.sum()
        # A non-persistent buffer follows module.to()/.cuda()/.double() but is
        # kept out of state_dict, so saved checkpoints stay free of it.
        self.register_buffer('kernel', window.reshape(1, 1, 1, width), persistent=False)

    def forward(self, x):
        """Filter ``x`` along width and height and return the result.

        Args:
            x: 4-D tensor laid out as ``(N, C, H, W)``.

        Returns:
            The filtered tensor, with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got {x.dim()}-D input")
        channels = x.shape[1]
        # Match the input so CPU/GPU and float16/32/64 inputs all work; this
        # is a no-op when the buffer already matches.
        taps = self.kernel.to(device=x.device, dtype=x.dtype)

        x = F.pad(x, self.padding, self.padding_mode)
        # Horizontal pass: (C, 1, 1, width) depthwise kernel.
        x = F.conv2d(x, taps.expand(channels, -1, -1, -1), groups=channels)
        # Vertical pass: the same taps turned into a (C, 1, width, 1) kernel.
        x = F.conv2d(x, taps.transpose(2, 3).expand(channels, -1, -1, -1), groups=channels)
        return x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment