Skip to content

Instantly share code, notes, and snippets.

@etienne87
Created March 22, 2018 07:51
Show Gist options
  • Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
how to make a bilateral filter using torch
#!/usr/bin/python
# torch_bilateral: bi/trilateral filtering in torch
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import pdb
import time
def gkern2d(l=21, sig=3):
    """Return an unnormalized 2D Gaussian kernel of shape ``(l, l)``.

    Args:
        l: side length of the kernel, in pixels.
        sig: standard deviation of the Gaussian, in pixels.

    Returns:
        A float64 ``numpy.ndarray`` holding
        ``exp(-(x**2 + y**2) / (2 * sig**2))`` sampled on a grid centered on
        the kernel. The values are not normalized to sum to one; for odd
        ``l`` the center sample equals 1.
    """
    # Center the sampling grid for every size. The former
    # ``arange(-l // 2 + 1., l // 2 + 1.)`` was symmetric only for odd ``l``;
    # for odd ``l`` both forms give exactly the same samples.
    ax = np.arange(l, dtype=np.float64) - (l - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    return np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
class Shift(nn.Module):
    """Stack every spatial shift of a ``kernel_size`` window along channels.

    For an input of shape ``(N, C, H, W)`` the output has shape
    ``(N, C * kernel_size**2, H, W)``. Output channel
    ``i * kernel_size**2 + dy * kernel_size + dx`` holds input channel ``i``
    of the zero-padded image, cropped at offset ``(dy, dx)``, so for an odd
    ``kernel_size`` the middle channel of each group is the unshifted input.

    Args:
        in_planes: number of input channels ``C`` to shift.
        kernel_size: side length of the neighbourhood window. Odd sizes give
            a window centered on each pixel.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # Half-window padding for any size. This used to be defined only for
        # sizes 3, 5 and 7, which made forward() fail for every other size.
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return the shifted copies of ``x`` concatenated on dim 1."""
        _, _, h, w = x.size()
        k = self.kernel_size
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        # Row-major over the window: dy is the outer offset, dx the inner one.
        shifted = [
            x_pad[:, i:i + 1, dy:dy + h, dx:dx + w]
            for i in range(self.in_planes)
            for dy in range(k)
            for dx in range(k)
        ]
        return torch.cat(shifted, 1)
class BilateralFilter(nn.Module):
    r"""Brute-force bilateral filter.

    BilateralFilter computes:
    If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    where ``g`` is a spatial Gaussian over a ``k x k`` window, ``f`` is a
    Gaussian on the squared intensity difference summed over channels, and
    ``W`` is the sum of the weights ``f * g`` over the window.

    Args:
        channels: number of channels of the images passed to ``forward``.
        k: side length of the spatial window.
        height: unused; kept for backward compatibility. The spatial kernel
            broadcasts over any image size instead of being tiled to a fixed
            ``height x width``.
        width: unused; kept for backward compatibility (see ``height``).
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the intensity Gaussian.

    Attributes:
        gw: the ``k x k`` spatial kernel as a numpy array.
        g: the same kernel as a ``(1, 1, k*k, 1, 1)`` buffer.
        sigma_color: the precomputed denominator ``2 * sigma_color**2``.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        self.k = k
        # Spatial Gaussian kernel, flattened row-major to match the channel
        # order produced by Shift.
        # FIXME: do everything in torch
        self.gw = gkern2d(k, sigma_space)
        g = torch.from_numpy(self.gw.reshape(1, 1, k * k, 1, 1)).float()
        # A buffer follows .cuda()/.to() and is saved in the state dict, but
        # it is fixed rather than trained and does not force the output to
        # require grad.
        self.register_buffer('g', g)
        # shift
        self.shift = Shift(channels, k)
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        """Filter a batch of images.

        Args:
            I: tensor of shape ``(N, channels, H, W)``.

        Returns:
            Tensor of shape ``(N, H, W)`` when ``channels == 1`` (the shape
            this module has always returned), otherwise ``(N, channels, H, W)``.

        Raises:
            ValueError: if ``I`` does not have ``channels`` channels.
        """
        n, c, h, w = I.size()
        if c != self.shift.in_planes:
            raise ValueError(
                'expected %d input channels, got %d' % (self.shift.in_planes, c))
        # (N, C, k*k, H, W): every neighbour of every pixel, per channel.
        Is = self.shift(I).view(n, c, -1, h, w)
        # Squared colour distance, summed over channels so that all channels
        # of a pixel share one weight.
        D = ((Is - I.unsqueeze(2)) ** 2).sum(dim=1, keepdim=True)
        weights = torch.exp(-D / self.sigma_color) * self.g
        If = (weights * Is).sum(dim=2) / weights.sum(dim=2)
        return If[:, 0] if c == 1 else If
if __name__ == '__main__':
    # Demo: filter a grayscale image, print the time of one forward pass and
    # display the spatial kernel, the input and the output.
    # Usage: script.py [image_path]
    import sys
    import cv2

    # Integer division: sizes are passed to cv2.resize and reshape, which
    # reject the floats that `/` produces on Python 3.
    c, h, w = 1, 480 // 2, 640 // 2
    k = 5
    cuda = False
    path = sys.argv[1] if len(sys.argv) > 1 else '/home/eperot/Pictures/Lena.png'

    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError('could not read image: %s' % path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    im_in = im.reshape(1, 1, h, w)
    img = torch.from_numpy(im_in).float() / 255.0
    img_in = img.cuda() if cuda else img

    # Warm-up call so one-time setup cost is not included in the timing.
    y = bilat(img_in)
    if cuda:
        torch.cuda.synchronize()
    start = time.time()
    y = bilat(img_in)
    if cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    print(time.time() - start)

    img_out = y.cpu().numpy()[0]  # not counting the return transfer in timing!
    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))

    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)
    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)
    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)
@minh-nguyenhoang
Copy link

minh-nguyenhoang commented Sep 11, 2024

Another note: separable bilateral filtering is also an option. It is not equivalent to the normal (2-D) bilateral filter, but it still produces desirable results and scales to extremely large kernel sizes with much smaller computational overhead.

def gkern1d(l=21, sig=3, device='cpu'):
    """Return a normalized 1D Gaussian kernel as a column of shape ``(l, 1)``.

    Args:
        l: number of taps in the kernel.
        sig: standard deviation of the Gaussian, in samples.
        device: torch device on which the kernel is created.

    Returns:
        Tensor of shape ``(l, 1)`` whose entries sum to 1.
    """
    # Center the sampling grid for every size. The former
    # ``arange(-l // 2 + 1., l // 2 + 1.)`` was symmetric only for odd ``l``;
    # for odd ``l`` both forms give exactly the same samples.
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2.0
    kernel = torch.exp(-(ax ** 2) / (2. * sig ** 2)).view(l, 1)
    return kernel / kernel.sum()


class SeperableShift(nn.Module):
    """Stack the shifts of a 1-D window along one image axis as channels.

    Each input plane is convolved with ``kernel_size`` one-hot kernels, so
    output channel ``j`` of a plane is that plane shifted by
    ``j - kernel_size // 2`` samples, with zeros shifted in at the border.

    Args:
        kernel_size: window length. Even values are rounded up to the next
            odd value so the window is centered.
    """

    def __init__(self, kernel_size=3):
        super(SeperableShift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        # Fixed one-hot kernels of shape (k, 1, k, 1). They are a buffer, not
        # a Parameter: an optimizer must never update them, but they should
        # still follow .to()/.cuda() and be saved in the state dict.
        one_hot = torch.eye(kernel_size, dtype=torch.float32)
        self.register_buffer('kernels', one_hot.view(kernel_size, 1, kernel_size, 1))

    def forward(self, x: torch.Tensor, transpose=True):
        """Return the shifted copies of ``x``.

        Args:
            x: tensor of shape ``(B, C, H, W)``.
            transpose: if true, shift along the height axis; otherwise shift
                along the width axis.

        Returns:
            Tensor of shape ``(B, C * kernel_size, H, W)``; the
            ``kernel_size`` shifts of each input channel are adjacent.
        """
        B, C, H, W = x.shape
        # reshape (not view) so that non-contiguous inputs are accepted.
        planes = x.reshape(-1, 1, H, W)
        kernels = self.kernels.to(dtype=x.dtype)
        if transpose:
            shifted = F.conv2d(planes, kernels, padding=(self.pad, 0))
        else:
            shifted = F.conv2d(planes, kernels.transpose(-1, -2), padding=(0, self.pad))
        return shifted.view(B, -1, H, W)


class SeperableBilateralFilter(nn.Module):
    """Separable approximation of the bilateral filter.

    A 1-D bilateral filter is applied along the width axis and then along
    the height axis of its result. Each 1-D pass computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    over a window of ``k`` samples, where ``g`` is a 1-D spatial Gaussian and
    ``f`` a Gaussian on the intensity difference. Every channel is filtered
    independently.

    Args:
        k: window length; even values are rounded up to the next odd value.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the intensity Gaussian.
        device: device the module is moved to; ``None`` skips the move.
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(SeperableBilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        # Spatial Gaussian, shaped (1, k, 1, 1) so it serves both as a
        # broadcastable weight and as a 1x1 conv that sums over the k shifts.
        self.gw = nn.Parameter(gkern1d(k, sigma_space, device=device).view(1, -1, 1, 1))

        self.k = k
        # shift
        self.shift = SeperableShift(k)
        # Precomputed denominator of the intensity Gaussian exponent.
        self.sigma_color = 2 * sigma_color ** 2

        if device is not None:
            self.to(device=device)

    def _filter_1d(self, I, transpose):
        """Run one 1-D bilateral pass along the axis chosen by ``transpose``."""
        B, C, H, W = I.shape
        # (B*C, k, H, W): the k neighbours of every pixel along one axis.
        Is = self.shift(I, transpose=transpose).view(-1, self.k, H, W)
        Iex = I.reshape(-1, 1, H, W)
        De = torch.exp(-((Is - Iex) ** 2) / self.sigma_color)
        # Normalize by the total weight (intensity weight times spatial weight).
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        # The 1x1 conv with gw takes the spatially weighted sum over the k shifts.
        return F.conv2d(De * Is, self.gw).view(B, C, H, W)

    def forward(self, I: torch.Tensor):
        """Filter ``I`` of shape ``(B, C, H, W)``; the output has the same shape."""
        # The reshaping below needs contiguous memory (e.g. after a permute).
        I = I.contiguous()
        If = self._filter_1d(I, transpose=False)
        return self._filter_1d(If, transpose=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment