Skip to content

Instantly share code, notes, and snippets.

@etienne87
Created March 22, 2018 07:51
Show Gist options
  • Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
how to make a bilateral filter using torch
#!/usr/bin/python
# torch_bilateral: bi/trilateral filtering in torch
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import pdb
import time
def gkern2d(l=21, sig=3):
    """Return an unnormalised ``l`` x ``l`` 2D Gaussian kernel.

    The kernel is sampled on a grid that is symmetric around zero for every
    ``l``.  For odd ``l`` the offsets are the integers ``-(l // 2) .. l // 2``
    and the centre value is exactly 1.0; for even ``l`` the offsets are
    half-integers, so the kernel stays symmetric instead of being shifted by
    one sample.

    Args:
        l: side length of the kernel in samples.
        sig: standard deviation of the Gaussian, in samples.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(l, l)``.
    """
    # ``2.0`` keeps the division floating point under Python 2 as well.
    ax = np.arange(l, dtype=np.float64) - (l - 1) / 2.0
    sq = ax ** 2
    # Broadcasting a column against a row replaces the explicit meshgrid.
    return np.exp(-(sq[:, None] + sq[None, :]) / (2. * sig ** 2))
class Shift(nn.Module):
    """Stack every spatially shifted copy of each input channel.

    For a ``k x k`` window the output holds, for each of the first
    ``in_planes`` channels in turn, the ``k * k`` zero-padded translations of
    that channel, ordered row-major over the window offsets.  Output channel
    ``i * k * k + dy * k + dx`` is input channel ``i`` read at offset
    ``(dy - k // 2, dx - k // 2)``.

    Args:
        in_planes: number of input channels to expand.
        kernel_size: odd window side length ``k``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got %r" % (kernel_size,))
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # "Same" padding for any odd window, not only the sizes 3, 5 and 7.
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return a ``(n, in_planes * k * k, h, w)`` tensor of shifted views."""
        h, w = x.shape[-2:]
        k = self.kernel_size
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        # Channel-major, then row-major over the window offsets.
        windows = [
            x_pad[:, ch:ch + 1, dy:dy + h, dx:dx + w]
            for ch in range(self.in_planes)
            for dy in range(k)
            for dx in range(k)
        ]
        return torch.cat(windows, 1)
class BilateralFilter(nn.Module):
    r"""BilateralFilter computes:
    If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    ``g`` is a fixed spatial Gaussian over a ``k x k`` window and ``f`` is a
    Gaussian on the squared intensity difference, summed over all channels so
    that every channel of a pixel shares one set of weights.

    Args:
        channels: number of channels of the images that will be filtered.
        k: window side length (forwarded to ``Shift`` and ``gkern2d``).
        height, width: accepted for backward compatibility only; the spatial
            kernel is broadcast, so it no longer depends on the image size.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity
            units.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # Spatial Gaussian, kept as a NumPy array so callers can display it.
        self.gw = gkern2d(k, sigma_space)
        # A (1, 1, k*k, 1, 1) buffer broadcasts over batch, channel and pixels
        # and follows the module across .cuda()/.to() without being trainable.
        g = torch.from_numpy(np.ascontiguousarray(self.gw)).float()
        self.register_buffer('g', g.reshape(1, 1, k * k, 1, 1))
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the range Gaussian exponent: 2 * sigma^2.
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        """Filter a ``(n, c, h, w)`` batch.

        Returns a ``(n, h, w)`` tensor for single-channel input (the shape the
        original implementation produced) and ``(n, c, h, w)`` otherwise.
        """
        n, c, h, w = I.shape
        # Neighbour values are detached, as the original did through ``.data``.
        Is = self.shift(I).detach().reshape(n, c, -1, h, w)
        # Squared colour distance to the centre pixel, summed over channels.
        D = ((Is - I.unsqueeze(2)) ** 2).sum(dim=1, keepdim=True)
        Dd = torch.exp(-D / self.sigma_color) * self.g
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If.squeeze(1) if c == 1 else If
if __name__ == '__main__':
    # Demo: filter a grayscale image, time one forward pass and display it.
    import matplotlib.pyplot as plt
    import cv2

    # Integer division: under Python 3 ``480/2`` is a float, which both
    # reshape() and cv2.resize() reject.
    c, h, w = 1, 480 // 2, 640 // 2
    k = 5
    cuda = False
    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    image_path = '/home/eperot/Pictures/Lena.png'
    im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise IOError('could not read image: %s' % image_path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    im_in = im.reshape(1, 1, h, w)
    img = torch.from_numpy(im_in).float() / 255.0
    img_in = img.cuda() if cuda else img

    # Warm-up call so that one-off initialisation is not part of the timing.
    y = bilat(img_in)
    if cuda:
        # CUDA kernels run asynchronously; wait before reading the clock.
        torch.cuda.synchronize()
    start = time.time()
    y = bilat(img_in)
    if cuda:
        torch.cuda.synchronize()
    print(time.time() - start)
    img_out = y.cpu().numpy()[0]  # not counting the return transfer in timing!

    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))
    # diff = np.abs(img_out - img[0,0])
    # diff = (diff - diff.min()) / (diff.max() - diff.min())
    # cv2.namedWindow('diff')
    # cv2.moveWindow('diff',50,50)
    # cv2.imshow('diff', diff)
    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)
    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)
    #
    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)
# n = gkern2d(5,3)
#
# s = n.reshape(1,25)
#
# print(n)
# print(s)
# plt.imshow(n, interpolation='none')
# plt.show()
@etienne87
Copy link
Author

40 ms cpu (3x slower than opencv), 0.7 ms gpu

@QiuJueqin
Copy link

QiuJueqin commented Mar 13, 2020

Brilliant!

@etienne87
Copy link
Author

thanks :) is this code useful for you?

@Karol-G
Copy link

Karol-G commented Nov 17, 2023

I created a slightly updated version of this gist (everything is in torch, no OpenCV, and reduced use of Parameter and Variable). The images look as expected and everything looks good. @etienne87 Do you agree?

import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Return an unnormalised ``l`` x ``l`` 2D Gaussian kernel tensor.

    The sample grid is symmetric around zero for every ``l``: integer offsets
    for odd ``l`` (centre value exactly 1.0) and half-integer offsets for even
    ``l``, so an even-sized kernel is no longer shifted by one sample.

    Args:
        l: side length of the kernel in samples.
        sig: standard deviation of the Gaussian, in samples.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape ``(l, l)`` in the default floating-point dtype.
    """
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2
    sq = ax ** 2
    # Column-plus-row broadcasting builds the squared radius without
    # torch.meshgrid, whose call without ``indexing`` is deprecated.
    return torch.exp(-(sq[:, None] + sq[None, :]) / (2. * sig ** 2))


class Shift(nn.Module):
    """Expand each channel into its ``k * k`` zero-padded translations.

    The output channels are grouped per input channel; inside a group the
    translations follow the window in row-major order, so the middle entry of
    each group is the unshifted channel.

    Args:
        in_planes: number of input channels to expand.
        kernel_size: odd window side length ``k``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size!r}")
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # Half-window padding works for every odd size; the original only
        # defined it for 3, 5 and 7 and failed later for anything else.
        self.pad = kernel_size // 2

    def forward(self, x):
        """Map ``(n, c, h, w)`` to ``(n, in_planes * k * k, h, w)``."""
        height, width = x.shape[-2:]
        p = self.pad
        padded = F.pad(x, (p, p, p, p))
        shifted = []
        for plane in range(self.in_planes):
            # Parse the window in row-major order.
            for row in range(self.kernel_size):
                for col in range(self.kernel_size):
                    shifted.append(
                        padded[:, plane:plane + 1, row:row + height, col:col + width])
        return torch.cat(shifted, 1)


class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    ``g`` is a fixed spatial Gaussian over a ``k x k`` window and ``f`` is a
    Gaussian on the squared intensity difference, summed over all channels so
    that every channel of a pixel shares one set of weights.

    Args:
        channels: number of channels of the images that will be filtered.
        k: window side length (forwarded to ``Shift`` and ``gkern2d``).
        height, width: accepted for backward compatibility only; the spatial
            kernel is broadcast, so it no longer depends on the image size.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity
            units.
        device: device on which the kernel is created and the module placed.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()

        # space gaussian kernel, kept in its (k, k) layout for inspection
        self.gw = gkern2d(k, sigma_space, device=device)

        # Registered as a non-persistent buffer: it follows .to()/.cuda() but
        # adds no key to the state dict.  The (1, 1, k*k, 1, 1) shape
        # broadcasts over batch, channel and pixels instead of being tiled.
        self.register_buffer('g', self.gw.reshape(1, 1, k * k, 1, 1), persistent=False)
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the range Gaussian exponent: 2 * sigma^2.
        self.sigma_color = 2 * sigma_color ** 2

        self.to(device=device)

    def forward(self, I):
        """Filter a ``(n, c, h, w)`` batch.

        Returns a ``(n, h, w)`` tensor for single-channel input (the shape the
        original implementation produced) and ``(n, c, h, w)`` otherwise.
        """
        n, c, h, w = I.shape
        # Neighbour values are detached, as the original did through ``.data``.
        Is = self.shift(I).detach().reshape(n, c, -1, h, w)
        # Squared colour distance to the centre pixel, summed over channels.
        D = ((Is - I.unsqueeze(2)) ** 2).sum(dim=1, keepdim=True)
        Dd = torch.exp(-D / self.sigma_color) * self.g
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If.squeeze(1) if c == 1 else If


if __name__ == '__main__':
    # Demo: filter the grayscale astronaut test image and time one pass.
    from skimage import data as ski_data
    from skimage.color import rgb2gray
    import time

    k = 5
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    img = rgb2gray(ski_data.astronaut())
    img = img[None, None, ...]
    # ``.to(device)`` honours the chosen device; the previous ``if device:``
    # test was true for any non-empty string and always called ``.cuda()``.
    img = torch.tensor(img).to(device)

    bilat = BilateralFilter(img.shape[1], k, img.shape[2], img.shape[3], device=device)

    if device == 'cuda':
        # CUDA kernels run asynchronously; wait before reading the clock.
        torch.cuda.synchronize()
    start_time = time.time()
    img_filtered = bilat(img)
    if device == 'cuda':
        torch.cuda.synchronize()
    print("Duration: ", time.time() - start_time)

    img_out = img_filtered.cpu().numpy().squeeze()

@etienne87
Copy link
Author

etienne87 commented Nov 17, 2023

looks fine :)
coming back to this thing several years later i'm wondering if pytorch's unfold is perhaps simpler than this "Shift" operator

@sjames40
Copy link

how to change it for the rgb image case?

@Karol-G
Copy link

Karol-G commented Nov 21, 2023

You would change the channel parameter to channels=3. However, I noticed that it will then throw an exception in both the original version and my modified version. Sadly, I have no understanding of how a bilateral filter works and am therefore unable to fix this error. I needed it myself only to implement a baseline.

@TCQian
Copy link

TCQian commented Aug 8, 2024

anyone has any idea on changing it for rgb image case?

@jeongukkim
Copy link

maybe it works for rgb image

import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Return an unnormalised ``l`` x ``l`` 2D Gaussian kernel tensor.

    Built as the outer product of a 1D Gaussian with itself.  The 1D offsets
    are symmetric around zero for every ``l`` (half-integers when ``l`` is
    even), so even-sized kernels are centred rather than shifted by a sample.

    Args:
        l: side length of the kernel in samples.
        sig: standard deviation of the Gaussian, in samples.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape ``(l, l)`` in the default floating-point dtype.
    """
    offsets = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2
    profile = torch.exp(-(offsets ** 2) / (2. * sig ** 2))
    # exp(-(x^2 + y^2) / 2s^2) == exp(-x^2 / 2s^2) * exp(-y^2 / 2s^2)
    return profile[:, None] * profile[None, :]


class Shift(nn.Module):
    """Gather the ``k x k`` neighbourhood of every pixel as extra channels.

    Each of the first ``in_planes`` input channels contributes ``k * k``
    consecutive output channels: the channel translated by every window
    offset (row-major), with zeros where the window leaves the image.

    Args:
        in_planes: number of input channels to expand.
        kernel_size: odd window side length ``k``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size!r}")
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # Generic half-window padding replaces the 3/5/7-only lookup, which
        # left ``pad`` undefined for every other size.
        self.pad = (kernel_size - 1) // 2

    def forward(self, x):
        """Map ``(n, c, h, w)`` to a contiguous ``(n, in_planes * k * k, h, w)``."""
        h, w = x.shape[2], x.shape[3]
        x_pad = F.pad(x, (self.pad,) * 4)
        offsets = range(self.kernel_size)
        # Channel-major, then row-major over the window.
        return torch.cat(
            [x_pad[:, i:i + 1, oy:oy + h, ox:ox + w]
             for i in range(self.in_planes)
             for oy in offsets
             for ox in offsets],
            1,
        )


class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    Every channel is filtered independently: the range weight ``f`` of a
    channel depends only on that channel's own intensity differences.

    Args:
        channels: number of channels of the images that will be filtered.
        k: window side length (forwarded to ``Shift`` and ``gkern2d``).
        height, width: accepted for backward compatibility only; the spatial
            kernel is broadcast, so it no longer depends on the image size.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity
            units.
        device: device on which the kernel is created and the module placed.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()

        self.k = k
        # space gaussian kernel, kept in its (k, k) layout for inspection
        self.gw = gkern2d(k, sigma_space, device=device)

        # Non-persistent buffer: follows .to()/.cuda() without adding a key to
        # the state dict.  Its (1, 1, k*k, 1, 1) shape broadcasts over batch,
        # channel and pixels, so nothing of size C*k*k*H*W is materialised.
        self.register_buffer('g', self.gw.reshape(1, 1, k * k, 1, 1), persistent=False)
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the range Gaussian exponent: 2 * sigma^2.
        self.sigma_color = 2 * sigma_color ** 2

        self.to(device=device)

    def forward(self, I):
        """Filter a ``(b, c, h, w)`` batch and return a tensor of that shape."""
        b, c, h, w = I.shape
        # Neighbour values are detached, as the original did through ``.data``.
        Is = self.shift(I).detach().reshape(b, c, -1, h, w)
        # Broadcasting the centre pixel replaces the repeat_interleave copy.
        Iex = I.unsqueeze(2)
        D = (Is - Iex) ** 2  # per channel; no sum over groups of channels
        De = torch.exp(-D / self.sigma_color)
        Dd = De * self.g
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If

@minh-nguyenhoang
Copy link

I've made some changes to the layer:

  • Making the shift operator parallel using convolution sampling (which has proved to be better optimized than unfold)
  • Redesigning the bilateral filter using convolution for optimization (unfortunately this still does not solve the problem of memory inefficiency)
  • Making this layer channel-agnostic (it should work with any kind of image)
  • Removing some really unnecessary inputs such as the image width/height and channel count
import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Return an ``l`` x ``l`` 2D Gaussian kernel normalised to sum to 1.

    The sample grid is symmetric around zero for every ``l``: integer offsets
    for odd ``l`` and half-integer offsets for even ``l``, so an even-sized
    kernel is centred rather than shifted by one sample.

    Args:
        l: side length of the kernel in samples.
        sig: standard deviation of the Gaussian, in samples.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape ``(l, l)`` in the default floating-point dtype.
    """
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2
    r2 = ax.pow(2)
    # Column-plus-row broadcasting builds the squared radius without
    # torch.meshgrid, whose call without ``indexing`` is deprecated.
    kernel = torch.exp(-(r2.unsqueeze(1) + r2.unsqueeze(0)) / (2. * sig ** 2))
    return kernel / kernel.sum()  # Early normalization


class Shift(nn.Module):
    """Gather the ``k x k`` neighbourhood of every pixel with one convolution.

    Each input channel is convolved with ``k * k`` one-hot kernels, one per
    window position, so output channel ``c * k * k + j`` is input channel
    ``c`` translated by the ``j``-th window offset (row-major) with zero
    padding.

    Args:
        kernel_size: window side length; even values are rounded up to the
            next odd number.
    """

    def __init__(self, kernel_size=3):
        super(Shift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        n_taps = kernel_size * kernel_size
        # The one-hot kernels define the operator and must never be trained,
        # so the parameter is frozen; it still moves with .to()/.cuda().
        self.kernels = nn.Parameter(
            torch.eye(n_taps).view(n_taps, 1, kernel_size, kernel_size),
            requires_grad=False,
        )

    def forward(self, x: torch.Tensor):
        """Map ``(B, C, H, W)`` to ``(B, C * k * k, H, W)``."""
        B, C, H, W = x.shape
        # reshape (not view) also accepts non-contiguous input such as
        # channels-last batches.
        planes = x.reshape(B * C, 1, H, W)
        # Match the input dtype so that e.g. float64 images are accepted.
        shifted = F.conv2d(planes, self.kernels.to(dtype=x.dtype), padding=self.pad)
        return shifted.reshape(B, C * self.kernel_size ** 2, H, W)

class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    Every channel is filtered independently, so the layer works for any
    number of channels and any image size.

    Args:
        k: window side length; even values are rounded up to the next odd
            number.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity
            units.
        device: device for the module, or ``None`` to skip the final
            ``.to()`` call.
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        self.k = k
        # space gaussian kernel, laid out as a 1x1 convolution weight over the
        # k*k shifted copies.  Frozen: it is a fixed filter, not a weight to
        # learn (set ``gw.requires_grad_(True)`` to opt in to training it).
        self.gw = nn.Parameter(
            gkern2d(k, sigma_space, device=device).reshape(1, -1, 1, 1),
            requires_grad=False,
        )

        # shift
        self.shift = Shift(k)
        # Denominator of the range Gaussian exponent: 2 * sigma^2.
        self.sigma_color = 2 * sigma_color ** 2

        if device is not None:
            self.to(device=device)

    def forward(self, I: torch.Tensor):
        """Filter a ``(B, C, H, W)`` batch and return a tensor of that shape."""
        B, C, H, W = I.shape
        gw = self.gw.to(dtype=I.dtype)
        # reshape (not view) also accepts non-contiguous input.
        Is = self.shift(I).reshape(B * C, self.k * self.k, H, W)
        Iex = I.reshape(B * C, 1, H, W)
        D = (Is - Iex) ** 2  # per channel; no sum over groups of channels
        De = torch.exp(-D / self.sigma_color)
        # Normalise by the total (range x space) weight of each window.
        De = De / torch.sum(De * gw, 1, keepdim=True)
        # The 1x1 convolution applies the spatial weights and sums the window.
        If = F.conv2d(De * Is, gw)
        return If.reshape(B, C, H, W)


if __name__ == '__main__':
    # Demo: apply the filter repeatedly to the RGB astronaut test image and
    # plot the input next to the result of every iteration.
    from skimage import data as ski_data
    from skimage.color import rgb2gray
    import time
    import matplotlib.pyplot as plt
    import numpy as np
    k = 5
    n_iter = 3  # renamed from ``iter`` so the builtin is not shadowed
    # Fall back to the CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    img_ = ski_data.astronaut()
    # img_ = rgb2gray(img_)
    # img = img_[None, None, ...]
    # img = torch.tensor(img).float()

    # HWC uint8 -> 1xCxHxW float in [0, 1]
    img = torch.tensor(img_).unsqueeze(0).permute(0, 3, 1, 2).float() / 255

    # ``.to(device)`` honours the chosen device; the previous ``if device:``
    # test was true for any non-empty string and always called ``.cuda()``.
    img = img.to(device)

    # One panel for the input plus one per iteration, four panels per row.
    n_rows = 1 + n_iter // 4
    n_cols = min(1 + n_iter, 4)
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    ax[0, 0].imshow(img_, cmap='gray')

    bilat = BilateralFilter(k, sigma_color=0.1, device=device)

    start_time = time.time()
    with torch.no_grad():
        for i in range(n_iter):
            img_filtered = bilat(img)
            img = img_filtered
            if device == 'cuda':
                # CUDA kernels run asynchronously; wait before reading the clock.
                torch.cuda.synchronize()
            # Cumulative time since the first iteration started.
            print("Duration: ", time.time() - start_time)
            img_out = img_filtered[0].cpu().permute(1, 2, 0).numpy().squeeze()
            ax[(1 + i) // 4, (1 + i) % 4].imshow(img_out, cmap='gray')

    plt.show()

@minh-nguyenhoang
Copy link

minh-nguyenhoang commented Sep 11, 2024

Another note is separable bilateral filtering (this is not equivalent to normal bilateral filtering, but it still produces desirable results and can scale to extremely large kernel sizes with much smaller computational overhead)

def gkern1d(l=21, sig=3, device='cpu'):
    """Return a 1D Gaussian kernel, normalised to sum to 1, as an ``(l, 1)`` column.

    The sample offsets are symmetric around zero for every ``l``: integers
    for odd ``l`` and half-integers for even ``l``, so an even-sized kernel
    is centred rather than shifted by one sample.

    Args:
        l: number of samples.
        sig: standard deviation of the Gaussian, in samples.
        device: device on which the tensor is created.

    Returns:
        A tensor of shape ``(l, 1)`` in the default floating-point dtype.
    """
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2
    kernel = torch.exp(-(ax ** 2) / (2. * sig ** 2)).view(l, 1)
    return kernel / kernel.sum()


class SeperableShift(nn.Module):
    """Gather the ``k`` neighbours of every pixel along one image axis.

    Each input channel is convolved with ``k`` one-hot kernels, so output
    channel ``c * k + j`` is input channel ``c`` translated by ``j - k // 2``
    samples along the chosen axis, with zero padding.

    Args:
        kernel_size: window length; even values are rounded up to the next
            odd number.
    """

    def __init__(self, kernel_size=3):
        super(SeperableShift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        # The one-hot column kernels define the operator and must never be
        # trained, so the parameter is frozen; it still moves with .to().
        self.kernels = nn.Parameter(
            torch.eye(kernel_size).view(kernel_size, 1, kernel_size, 1),
            requires_grad=False,
        )

    def forward(self, x: torch.Tensor, transpose = True):
        """Map ``(B, C, H, W)`` to ``(B, C * k, H, W)``.

        ``transpose=True`` uses the ``(k, 1)`` column kernels and shifts along
        the height axis; ``transpose=False`` uses their ``(1, k)`` transposes
        and shifts along the width axis.
        """
        B, C, H, W = x.shape
        # reshape (not view) also accepts non-contiguous input such as
        # channels-last batches.
        planes = x.reshape(B * C, 1, H, W)
        # Match the input dtype so that e.g. float64 images are accepted.
        kernels = self.kernels.to(dtype=x.dtype)
        if transpose:
            shifted = F.conv2d(planes, kernels, padding=(self.pad, 0))
        else:
            shifted = F.conv2d(planes, kernels.transpose(-1, -2), padding=(0, self.pad))
        return shifted.reshape(B, C * self.kernel_size, H, W)


class SeperableBilateralFilter(nn.Module):
    """Separable approximation of the bilateral filter.

    BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))

    Here the 2D window is replaced by two 1D passes: first along the width
    axis, then along the height axis of the intermediate result.  Every
    channel is filtered independently.

    Args:
        k: window length; even values are rounded up to the next odd number.
        sigma_space: standard deviation of the spatial Gaussian, in pixels.
        sigma_color: standard deviation of the range Gaussian, in intensity
            units.
        device: device for the module, or ``None`` to skip the final
            ``.to()`` call.
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(SeperableBilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        self.k = k
        # space gaussian kernel, laid out as a 1x1 convolution weight over the
        # k shifted copies.  Frozen: it is a fixed filter, not a weight to
        # learn (set ``gw.requires_grad_(True)`` to opt in to training it).
        self.gw = nn.Parameter(
            gkern1d(k, sigma_space, device=device).reshape(1, -1, 1, 1),
            requires_grad=False,
        )

        # shift
        self.shift = SeperableShift(k)
        # Denominator of the range Gaussian exponent: 2 * sigma^2.
        self.sigma_color = 2 * sigma_color ** 2

        if device is not None:
            self.to(device=device)

    def _filter_axis(self, I, transpose):
        """Run one 1D bilateral pass; ``transpose`` is forwarded to the shift."""
        B, C, H, W = I.shape
        gw = self.gw.to(dtype=I.dtype)
        # reshape (not view) also accepts non-contiguous input.
        Is = self.shift(I, transpose=transpose).reshape(B * C, self.k, H, W)
        Iex = I.reshape(B * C, 1, H, W)
        D = (Is - Iex) ** 2
        De = torch.exp(-D / self.sigma_color)
        # Normalise by the total (range x space) weight of each window.
        De = De / torch.sum(De * gw, 1, keepdim=True)
        # The 1x1 convolution applies the spatial weights and sums the window.
        return F.conv2d(De * Is, gw).reshape(B, C, H, W)

    def forward(self, I: torch.Tensor):
        """Filter a ``(B, C, H, W)`` batch and return a tensor of that shape."""
        first_pass = self._filter_axis(I, transpose=False)
        return self._filter_axis(first_pass, transpose=True)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment