Skip to content

Instantly share code, notes, and snippets.

@etienne87
Created March 22, 2018 07:51
Show Gist options
  • Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
How to make a bilateral filter using PyTorch
#!/usr/bin/python
# torch_bilateral: bi/trilateral filtering in torch
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import pdb
import time
def gkern2d(l=21, sig=3):
    """Build an unnormalised ``l x l`` Gaussian kernel as a numpy array.

    Taps sit at integer offsets around index ``(l - 1) // 2``. The centre
    value is exactly 1.0, and the kernel is NOT rescaled to sum to one.
    """
    offsets = np.arange(l, dtype=float) - (l - 1) // 2
    squared = offsets ** 2
    # The outer sum of squared offsets gives the squared radial distance of every tap.
    return np.exp(-np.add.outer(squared, squared) / (2. * sig ** 2))
class Shift(nn.Module):
    """Expand each input channel into its ``kernel_size**2`` shifted copies.

    For an input of shape ``(n, in_planes, h, w)`` the output has shape
    ``(n, in_planes * kernel_size**2, h, w)``. Output channel
    ``i * kernel_size**2 + dy * kernel_size + dx`` holds input channel ``i``
    read at offset ``(dy - pad, dx - pad)``, with zeros outside the image.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        # Kept as a public attribute; forward() does not need it.
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # "Same" padding for any kernel size. Previously only 3, 5 and 7 were
        # handled, and other sizes left `pad` undefined.
        self.pad = kernel_size // 2

    def forward(self, x):
        h, w = x.shape[-2:]
        p = self.pad
        k = self.kernel_size
        x_pad = F.pad(x, (p, p, p, p))  # zero padding on H and W
        # Channel-major over inputs, row-major over kernel taps.
        taps = [
            x_pad[:, i:i + 1, dy:dy + h, dx:dx + w]
            for i in range(self.in_planes)
            for dy in range(k)
            for dx in range(k)
        ]
        return torch.cat(taps, 1)
class BilateralFilter(nn.Module):
    r"""BilateralFilter computes:
    If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    f is a Gaussian on intensity differences (``sigma_color``). g is a Gaussian
    on spatial offsets (``sigma_space``) over a ``k x k`` window. Each channel
    is filtered independently: the range term uses per-channel differences,
    not a joint colour norm.

    Args:
        channels: number of input channels.
        k: window size.
        height, width: spatial size that the spatial weights are tiled to.
            Inputs must have this size.
        sigma_space: standard deviation of g, in pixels.
        sigma_color: standard deviation of f, in intensity units.

    forward() returns ``(n, h, w)`` when the input has one channel, and
    ``(n, channels, h, w)`` otherwise.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # numpy copy of the k x k spatial kernel (centre value 1.0).
        # FIXME: do everything in torch
        self.gw = gkern2d(k, sigma_space)
        # One spatial weight per tap, shared by all channels through broadcasting:
        # shape (1, k*k, height, width). Reshaping the k*k kernel to
        # (channels, k*k, ...) only worked when channels == 1.
        gw = np.tile(self.gw.reshape(1, k * k, 1, 1), (1, 1, int(height), int(width)))
        self.g = Parameter(torch.from_numpy(gw).float())
        self.k = k
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the range Gaussian exp(-d^2 / (2 * sigma_color^2)).
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        n, c, h, w = I.shape
        kk = self.k * self.k
        # Shift output is channel-major: (n, c*k*k, h, w) -> (n, c, k*k, h, w).
        # It is detached, as before, so no gradient flows through the taps.
        Is = self.shift(I).detach().view(n, c, kk, h, w)
        Iex = I.unsqueeze(2)  # (n, c, 1, h, w), broadcast against every tap
        D = (Is - Iex) ** 2  # per-channel squared range distance
        De = torch.exp(-D / self.sigma_color)
        Dd = De * self.g.detach().unsqueeze(1)  # g: (1, 1, k*k, h, w)
        # Strictly positive: the centre tap has D == 0 and a spatial weight of 1.
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        # Single-channel inputs keep the original (n, h, w) output shape.
        return If.squeeze(1) if c == 1 else If
if __name__ == '__main__':
    # Demo: filter a grayscale image at half resolution. OpenCV windows show the
    # spatial kernel, the input and the filtered output.
    import sys
    import cv2

    # Integer division: reshape/cv2.resize need ints (480/2 is a float in Python 3).
    c, h, w = 1, 480 // 2, 640 // 2
    k = 5
    cuda = False
    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()
    # The image path may be given on the command line; it defaults to the original file.
    path = sys.argv[1] if len(sys.argv) > 1 else '/home/eperot/Pictures/Lena.png'
    im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if im is None:  # cv2.imread returns None instead of raising on failure
        sys.exit('could not read image: %s' % path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    img = torch.from_numpy(im.reshape(1, 1, h, w)).float() / 255.0
    img_in = img.cuda() if cuda else img
    with torch.no_grad():
        y = bilat(img_in)  # warm-up run, excluded from the timing
        if cuda:
            torch.cuda.synchronize()
        start = time.time()
        y = bilat(img_in)
        if cuda:
            torch.cuda.synchronize()  # CUDA kernels run asynchronously
        print(time.time() - start)
    img_out = y.cpu().numpy()[0]  # not counting the return transfer in timing!
    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))
    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)
    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)
    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)
@minh-nguyenhoang
Copy link

I've made some changes to the layer:

  • Made the shift operator parallel by implementing it as a convolution with one-hot kernels (which proved to be better optimized than unfold)
  • Redesigned the bilateral filter around convolutions for speed (unfortunately, this still does not solve its memory inefficiency)
  • Made the layer channel-agnostic (it should work with images that have any number of channels)
  • Removed unnecessary inputs such as image width, height, and channel count
import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Return an ``l x l`` Gaussian kernel as a torch tensor normalised to sum to 1.

    Args:
        l: kernel side length (number of taps per axis).
        sig: standard deviation of the Gaussian, in pixels.
        device: torch device on which the kernel is created.
    """
    ax = torch.arange(-l // 2 + 1., l // 2 + 1., device=device)
    # The outer sum of squared offsets equals xx**2 + yy**2 from a meshgrid.
    # This avoids the torch.meshgrid call that warns when `indexing` is omitted.
    sq = ax ** 2
    kernel = torch.exp(-(sq[:, None] + sq[None, :]) / (2. * sig ** 2))
    return kernel / kernel.sum()  # Early normalization


class Shift(nn.Module):
    """Expand every channel into its ``kernel_size**2`` spatially shifted copies.

    This is implemented as one ``F.conv2d`` with one-hot kernels. For an input
    ``(B, C, H, W)``, output channel ``c * k*k + dy * k + dx`` holds channel
    ``c`` read at offset ``(dy - k//2, dx - k//2)``, zero-padded at the border.
    Even kernel sizes are rounded up to the next odd size.
    """

    def __init__(self, kernel_size=3):
        super(Shift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        taps = kernel_size * kernel_size
        one_hot = F.one_hot(torch.arange(0, taps)).view(taps, 1, kernel_size, kernel_size).float()
        # A buffer, not a Parameter. The one-hot kernels must never be trained,
        # but they still follow .to(device) and are saved in state_dict under
        # the same "kernels" key.
        self.register_buffer('kernels', one_hot)

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        # Use reshape rather than view so that non-contiguous inputs
        # (e.g. permuted HWC -> CHW images) are accepted.
        x_ = x.reshape(-1, 1, H, W)
        x_ = F.conv2d(x_, self.kernels, padding=self.pad)
        return x_.view(B, -1, H, W)

class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    Each channel is filtered independently: the range term f uses the
    per-channel difference, not a joint colour distance.

    Args:
        k: spatial window size (rounded up to the next odd number).
        sigma_space: standard deviation of the spatial Gaussian g, in pixels.
        sigma_color: standard deviation of the range Gaussian f, in intensity units.
        device: device the module is moved to (``None`` leaves it in place).
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        # Spatial Gaussian g as a (1, k*k, 1, 1) 1x1-conv weight. It is a buffer,
        # so it moves with .to() and is saved in state_dict, but an optimiser
        # never updates it.
        self.register_buffer('gw', gkern2d(k, sigma_space, device=device).view(1, -1, 1, 1))

        self.k = k
        # shift
        self.shift = Shift(k)
        # Denominator of the range Gaussian exp(-d^2 / (2 * sigma_color^2)).
        self.sigma_color = 2 * sigma_color ** 2

        if device is not None:
            self.to(device=device)

    def forward(self, I: torch.Tensor):
        B, C, H, W = I.shape
        # (B*C, k*k, H, W): every shifted tap of every channel. reshape accepts
        # non-contiguous inputs.
        Is = self.shift(I).reshape(-1, self.k * self.k, H, W)
        Iex = I.reshape(-1, 1, H, W)
        D = (Is - Iex) ** 2  # per-channel squared range distance
        De = torch.exp(-D / self.sigma_color)
        # Normalise so that the combined weights De * g sum to 1 at every pixel.
        # The centre tap has D == 0, so its De is 1.
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        # A 1x1 convolution with weight g is the g-weighted sum over the k*k taps.
        If = F.conv2d(De * Is, self.gw)
        return If.view(B, C, H, W)


if __name__ == '__main__':
    # Demo: apply the bilateral filter repeatedly to the scikit-image astronaut
    # picture. The input is shown first, followed by each iteration's output.
    from skimage import data as ski_data
    import time
    import matplotlib.pyplot as plt

    k = 5
    n_iter = 3
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    img_ = ski_data.astronaut()  # HWC image; the /255 below assumes 8-bit values
    img = torch.tensor(img_).unsqueeze(0).permute(0, 3, 1, 2).float() / 255
    img = img.to(device)

    # One panel for the input plus one per iteration, 4 panels per row.
    n_cols = min(1 + n_iter, 4)
    n_rows = 1 + n_iter // 4
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    ax[0, 0].imshow(img_, cmap='gray')

    bilat = BilateralFilter(k, sigma_color=0.1, device=device)

    with torch.no_grad():
        for i in range(n_iter):
            start_time = time.time()
            img = bilat(img)
            if device == 'cuda':
                torch.cuda.synchronize()  # CUDA kernels run asynchronously
            print("Duration: ", time.time() - start_time)
            img_out = img[0].cpu().permute(1, 2, 0).numpy().squeeze()
            ax[(1 + i) // 4, (1 + i) % 4].imshow(img_out, cmap='gray')

    plt.show()

@minh-nguyenhoang
Copy link

minh-nguyenhoang commented Sep 11, 2024

Another note: separable bilateral filtering is not equivalent to the full bilateral filter, but it still produces good results. It also scales to extremely large kernel sizes with much smaller computational overhead, because each of its two 1-D passes uses k taps per pixel instead of k*k.

def gkern1d(l=21, sig=3, device='cpu'):
    """Return a 1-D Gaussian kernel as an ``(l, 1)`` column tensor normalised to sum to 1.

    Args:
        l: number of taps.
        sig: standard deviation of the Gaussian, in pixels.
        device: torch device on which the kernel is created.
    """
    ax = torch.arange(-l // 2 + 1., l // 2 + 1., device=device)
    kernel = torch.exp(-(ax ** 2) / (2. * sig ** 2)).view(l, 1)
    return kernel / kernel.sum()


class SeperableShift(nn.Module):
    """Expand each channel into ``kernel_size`` copies shifted along one axis.

    ``forward(x, transpose=True)`` shifts along H (vertical taps).
    ``forward(x, transpose=False)`` shifts along W (horizontal taps).
    For an input ``(B, C, H, W)``, output channel ``c * k + j`` holds channel
    ``c`` read at offset ``j - k//2`` along that axis, zero-padded at the
    border. Even kernel sizes are rounded up to the next odd size.
    """

    def __init__(self, kernel_size=3):
        super(SeperableShift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        one_hot = F.one_hot(torch.arange(0, kernel_size)).view(kernel_size, 1, kernel_size, 1).float()
        # A buffer, not a Parameter. The one-hot kernels must never be trained,
        # but they still follow .to(device) and are saved in state_dict under
        # the same "kernels" key.
        self.register_buffer('kernels', one_hot)

    def forward(self, x: torch.Tensor, transpose=True):
        B, C, H, W = x.shape
        # reshape (not view) so that non-contiguous inputs are accepted.
        x_ = x.reshape(-1, 1, H, W)
        if transpose:
            # (k, 1, k, 1) kernels: taps run along H.
            x_ = F.conv2d(x_, self.kernels, padding=(self.pad, 0))
        else:
            # (k, 1, 1, k) kernels: taps run along W.
            x_ = F.conv2d(x_, self.kernels.transpose(-1, -2), padding=(0, self.pad))
        return x_.view(B, -1, H, W)


class SeperableBilateralFilter(nn.Module):
    """Separable approximation of a bilateral filter.

    A 1-D bilateral pass along W (horizontal) is followed by a 1-D pass along
    H (vertical) on its result. Each pass computes, per channel,
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    over a 1-D window of ``k`` taps, so the cost grows with ``k`` rather than
    ``k*k``. The result is not identical to the full 2-D bilateral filter.

    Args:
        k: window length (rounded up to the next odd number).
        sigma_space: standard deviation of the spatial Gaussian g, in pixels.
        sigma_color: standard deviation of the range Gaussian f, in intensity units.
        device: device the module is moved to (``None`` leaves it in place).
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(SeperableBilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        # 1-D spatial Gaussian g as a (1, k, 1, 1) 1x1-conv weight. It is a
        # buffer, so it follows .to() and is saved in state_dict, but an
        # optimiser never updates it.
        self.register_buffer('gw', gkern1d(k, sigma_space, device=device).view(1, -1, 1, 1))

        self.k = k
        # shift
        self.shift = SeperableShift(k)
        # Denominator of the range Gaussian exp(-d^2 / (2 * sigma_color^2)).
        self.sigma_color = 2 * sigma_color ** 2

        if device is not None:
            self.to(device=device)

    def _bilateral_pass(self, I, transpose):
        """Run one 1-D bilateral pass (``transpose=True`` along H, ``False`` along W)."""
        B, C, H, W = I.shape
        Is = self.shift(I, transpose=transpose).reshape(-1, self.k, H, W)
        Iex = I.reshape(-1, 1, H, W)
        D = (Is - Iex) ** 2
        De = torch.exp(-D / self.sigma_color)
        # Normalise so that the combined weights De * g sum to 1 at every pixel.
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        return F.conv2d(De * Is, self.gw).view(B, C, H, W)

    def forward(self, I: torch.Tensor):
        # Horizontal pass first, then the vertical pass on its output.
        return self._bilateral_pass(self._bilateral_pass(I, transpose=False), transpose=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment