Created
March 22, 2018 07:51
-
-
Save etienne87/9f903b2b16389f9fe98a18fade6df74b to your computer and use it in GitHub Desktop.
how to make a bilateral filter using torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# torch_bilateral: bi/trilateral filtering in torch | |
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
import torch.nn.functional as F | |
from torch.nn import Parameter | |
import numpy as np | |
import pdb | |
import time | |
def gkern2d(l=21, sig=3):
    """Return an unnormalised ``l`` x ``l`` 2D Gaussian kernel as a NumPy array.

    Args:
        l: Side length of the square kernel, in pixels.
        sig: Standard deviation of the Gaussian, in pixels.

    Returns:
        A float64 array of shape ``(l, l)`` holding
        ``exp(-(x**2 + y**2) / (2 * sig**2))``. The values are NOT scaled to
        sum to one.
    """
    # Offsets measured from the kernel midpoint. For odd ``l`` these are the
    # integers -(l // 2) .. l // 2; for even ``l`` they are half-integers, so
    # the kernel stays symmetric instead of being skewed by one pixel.
    ax = np.arange(l, dtype=np.float64) - (l - 1) / 2.0
    sq = ax ** 2
    # Broadcasting a row against a column gives x**2 + y**2 for every cell.
    return np.exp(-(sq[np.newaxis, :] + sq[:, np.newaxis]) / (2.0 * sig ** 2))
class Shift(nn.Module):
    """Stack every shifted view of the input seen through a k x k window.

    For each input channel, and for each of the ``kernel_size ** 2`` window
    offsets visited in row-major order, the zero-padded input is cropped back
    to the original spatial size and appended along the channel axis.

    Args:
        in_planes: Number of leading input channels that ``forward`` reads.
        kernel_size: Side length of the window; must be a positive odd
            integer so the window has a centre pixel.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got %r" % (kernel_size,)
            )
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # Half-window padding keeps the centre tap aligned with the input for
        # any odd size (3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4, ...).
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return the shifted stack for a 4-D input.

        Args:
            x: Tensor unpacked as ``(n, c, h, w)``.

        Returns:
            Tensor with ``in_planes * kernel_size ** 2`` channels (when ``x``
            has at least ``in_planes`` channels) and the same ``h`` and ``w``
            as ``x``. Channels are ordered channel-major, then window row,
            then window column.
        """
        _, _, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        taps = [
            x_pad[:, ch:ch + 1, dy:dy + h, dx:dx + w]
            for ch in range(self.in_planes)
            for dy in range(self.kernel_size)  # row-major: rows outer ...
            for dx in range(self.kernel_size)  # ... columns inner
        ]
        return torch.cat(taps, 1)
class BilateralFilter(nn.Module):
    r"""Bilateral filter.

    Computes, for every pixel x,

        If(x) = 1/W * Sum_{xi in Omega}( I(xi) * f(||I(xi)-I(x)||) * g(||xi-x||) )

    where Omega is the k x k window around x, g is a spatial Gaussian, f is a
    Gaussian on the intensity difference and W is the sum of the weights.

    For multi-channel input the squared intensity difference is summed over
    the channels, so all channels of a pixel share one weight.

    Args:
        channels: Number of channels of the images passed to ``forward``.
        k: Window side length, forwarded to ``Shift`` and ``gkern2d``.
        height: Unused; kept for backward compatibility. The spatial weights
            are broadcast, so any image size is accepted.
        width: Unused; kept for backward compatibility.
        sigma_space: Standard deviation of the spatial Gaussian, in pixels.
        sigma_color: Standard deviation of the intensity Gaussian.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        self.channels = channels
        self.k = k
        # Spatial Gaussian as a (k, k) NumPy array, kept for inspection/plotting.
        self.gw = gkern2d(k, sigma_space)
        # One weight per window tap, shaped to broadcast against
        # (batch, channel, tap, height, width). The kernel is a constant of
        # the filter, so it is not trainable.
        g = torch.from_numpy(np.asarray(self.gw)).float().reshape(1, 1, k * k, 1, 1)
        self.g = Parameter(g, requires_grad=False)
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the intensity Gaussian exponent: 2 * sigma_color**2.
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        """Filter a batch of images.

        Args:
            I: Tensor unpacked as ``(n, channels, h, w)``.

        Returns:
            The filtered images. For ``channels == 1`` the channel axis is
            dropped and the result is ``(n, h, w)``, as in the original
            single-channel implementation; otherwise it is
            ``(n, channels, h, w)``.

        Raises:
            ValueError: If the channel count of ``I`` differs from ``channels``.
        """
        n, c, h, w = I.size()
        if c != self.channels:
            raise ValueError(
                "expected %d input channels, got %d" % (self.channels, c)
            )
        # Neighbour values are detached (the original used ``.data``), so
        # gradients only flow through the centre-pixel term below.
        Is = self.shift(I).detach().view(n, c, self.k * self.k, h, w)
        center = I.unsqueeze(2)
        # Squared intensity distance, summed over channels.
        D = ((Is - center) ** 2).sum(dim=1, keepdim=True)
        weights = torch.exp(-D / self.sigma_color) * self.g
        W_denom = weights.sum(dim=2)
        If = (weights * Is).sum(dim=2) / W_denom
        return If.squeeze(1) if c == 1 else If
if __name__ == '__main__':
    # Demo: filter a grayscale image, time one pass and show the results.
    import matplotlib.pyplot as plt
    import cv2

    # Integer division: cv2.resize, reshape and np.tile all need int sizes,
    # and ``/`` yields floats on Python 3.
    c, h, w = 1, 480 // 2, 640 // 2
    k = 5
    cuda = False
    image_path = '/home/eperot/Pictures/Lena.png'

    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    im = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: ' + image_path)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    im_in = im.reshape(1, 1, h, w)
    img = torch.from_numpy(im_in).float() / 255.0
    img_in = img.cuda() if cuda else img

    # First call is a warm-up so the timed call excludes one-off setup cost.
    y = bilat(img_in)
    start = time.time()
    y = bilat(img_in)
    print(time.time() - start)

    img_out = y.cpu().numpy()[0]  # not counting the return transfer in timing!
    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))

    # diff = np.abs(img_out - img[0,0])
    # diff = (diff - diff.min()) / (diff.max() - diff.min())
    # cv2.namedWindow('diff')
    # cv2.moveWindow('diff',50,50)
    # cv2.imshow('diff', diff)

    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)

    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)

    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)

    # n = gkern2d(5,3)
    #
    # s = n.reshape(1,25)
    #
    # print(n)
    # print(s)
    # plt.imshow(n, interpolation='none')
    # plt.show()
Does anyone have an idea how to change this for the RGB image case?
Maybe this version works for RGB images:
import torch
from torch import nn
import torch.nn.functional as F
def gkern2d(l=21, sig=3, device='cpu'):
    """Return an unnormalised ``l`` x ``l`` 2D Gaussian kernel tensor.

    Args:
        l: Side length of the square kernel, in pixels.
        sig: Standard deviation of the Gaussian, in pixels.
        device: Device on which the kernel is created.

    Returns:
        Tensor of shape ``(l, l)`` holding
        ``exp(-(x**2 + y**2) / (2 * sig**2))``; not scaled to sum to one.
    """
    # Offsets from the kernel midpoint; half-integers for even ``l`` so the
    # kernel stays symmetric rather than being shifted by one pixel.
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2.0
    sq = ax ** 2
    # Column + row broadcasting replaces torch.meshgrid, which warns on
    # recent PyTorch when called without an explicit ``indexing`` argument.
    return torch.exp(-(sq[:, None] + sq[None, :]) / (2.0 * sig ** 2))
class Shift(nn.Module):
    """Collect the k x k neighbourhood of every pixel as extra channels.

    The input is zero-padded by half the window size; for every input channel
    and every window offset (rows first, then columns) a crop of the original
    spatial size is taken, and all crops are concatenated on the channel axis.

    Args:
        in_planes: Number of leading input channels that ``forward`` reads.
        kernel_size: Window side length; must be a positive odd integer.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                "kernel_size must be a positive odd integer, got %r" % (kernel_size,)
            )
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        # Works for every odd size, not only 3, 5 and 7.
        self.pad = self.kernel_size // 2

    def forward(self, x):
        """Return the neighbourhood stack for a 4-D input.

        Args:
            x: Tensor unpacked as ``(n, c, h, w)``.

        Returns:
            Tensor whose channel axis holds, for each of the first
            ``in_planes`` channels in turn, its ``kernel_size ** 2`` shifted
            crops in row-major window order; ``h`` and ``w`` are unchanged.
        """
        _, _, h, w = x.size()
        padded = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        crops = []
        for ch in range(self.in_planes):
            # Parse the window in row-major order.
            for dy in range(self.kernel_size):
                for dx in range(self.kernel_size):
                    crops.append(padded[:, ch:ch + 1, dy:dy + h, dx:dx + w])
        return torch.cat(crops, 1)
class BilateralFilter(nn.Module):
    """Per-channel bilateral filter.

    Computes, for every pixel x of every channel,

        If(x) = 1/W * Sum_{xi in Omega}( I(xi) * f(|I(xi)-I(x)|) * g(||xi-x||) )

    where Omega is the k x k window around x, g is a spatial Gaussian, f is a
    Gaussian on the intensity difference and W is the sum of the weights.
    Each channel is filtered on its own; the intensity distance is not summed
    across channels.

    Args:
        channels: Number of channels of the images passed to ``forward``.
        k: Window side length, forwarded to ``Shift`` and ``gkern2d``.
        height: Unused; kept for backward compatibility. The spatial weights
            are broadcast, so any image size is accepted.
        width: Unused; kept for backward compatibility.
        sigma_space: Standard deviation of the spatial Gaussian, in pixels.
        sigma_color: Standard deviation of the intensity Gaussian.
        device: Device on which the kernel is created and the module placed.
    """

    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        self.k = k
        # space gaussian kernel, shape (k, k)
        self.gw = gkern2d(k, sigma_space, device=device)
        # One weight per window tap, shaped to broadcast against
        # (batch, channel, tap, height, width). Registered as a buffer so
        # .to()/.cuda() move it with the module, and left un-tiled so it does
        # not cost channels * k*k * height * width floats.
        self.register_buffer('g', self.gw.reshape(1, 1, k * k, 1, 1).clone())
        # shift
        self.shift = Shift(channels, k)
        # Denominator of the intensity Gaussian exponent: 2 * sigma_color**2.
        self.sigma_color = 2 * sigma_color ** 2
        self.to(device=device)

    def forward(self, I):
        """Filter a batch of images.

        Args:
            I: Tensor of shape ``(b, c, h, w)``.

        Returns:
            Tensor of shape ``(b, c, h, w)`` with each channel filtered.
        """
        b, c, h, w = I.shape
        # Neighbour values are detached (the original used ``.data``), so
        # gradients only flow through the centre-pixel term.
        Is = self.shift(I).detach().view(b, c, -1, h, w)
        # Broadcasting the centre pixel over the tap axis replaces a
        # k*k-fold repeat_interleave copy of the input.
        Iex = I.unsqueeze(2)
        D = (Is - Iex) ** 2
        Dd = torch.exp(-D / self.sigma_color) * self.g
        W_denom = torch.sum(Dd, dim=2)
        return torch.sum(Dd * Is, dim=2) / W_denom
I've made some changes to the layer:
- Making the shift operator parallel by sampling with a convolution (which has proven to be better optimized than unfold)
- Redesign the Bilateral Filter using convolution for optimization (Unfortunately this still does not solve the problem of memory inefficiency)
- Making this layer channel-agnostic (it should work with any kind of image)
- Removing some really unnecessary inputs such as the image width/height and channel count
import torch
from torch import nn
import torch.nn.functional as F
def gkern2d(l=21, sig=3, device='cpu'):
    """Return a normalised ``l`` x ``l`` 2D Gaussian kernel tensor.

    Args:
        l: Side length of the square kernel, in pixels.
        sig: Standard deviation of the Gaussian, in pixels.
        device: Device on which the kernel is created.

    Returns:
        Tensor of shape ``(l, l)`` proportional to
        ``exp(-(x**2 + y**2) / (2 * sig**2))`` and scaled to sum to one.
    """
    # Offsets from the kernel midpoint; half-integers for even ``l`` so the
    # kernel stays symmetric rather than being shifted by one pixel.
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2.0
    sq = ax ** 2
    # Column + row broadcasting replaces torch.meshgrid, which warns on
    # recent PyTorch when called without an explicit ``indexing`` argument.
    kernel = torch.exp(-(sq[:, None] + sq[None, :]) / (2.0 * sig ** 2))
    return kernel / kernel.sum()  # Early normalization
class Shift(nn.Module):
    """Gather the k x k neighbourhood of every pixel with one convolution.

    The module holds ``k * k`` one-hot ``k x k`` kernels; convolving a single
    channel plane with them yields one shifted copy of the plane per window
    offset, in row-major order. Borders are zero-padded.

    Args:
        kernel_size: Window side length. Even values are rounded up to the
            next odd integer.
    """

    def __init__(self, kernel_size=3):
        super(Shift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1  # force an odd window
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        taps = kernel_size * kernel_size
        # Row j of the identity, reshaped to k x k, is the kernel selecting
        # window offset j. These are constants, so they are registered as a
        # buffer: they follow .to()/.cuda() and state_dict, but are never
        # handed to an optimizer.
        one_hot = torch.eye(taps).view(taps, 1, kernel_size, kernel_size)
        self.register_buffer('kernels', one_hot)

    def forward(self, x: torch.Tensor):
        """Return the neighbourhood stack of ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, C * kernel_size**2, H, W)``; the
            ``kernel_size**2`` shifted copies of each channel are adjacent.
        """
        B, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. a permuted
        # channels-last image batch, are accepted.
        planes = x.reshape(B * C, 1, H, W)
        shifted = F.conv2d(planes, self.kernels, padding=self.pad)
        return shifted.reshape(B, -1, H, W)
class BilateralFilter(nn.Module):
    """Channel-agnostic bilateral filter.

    Computes, for every pixel x of every channel,

        If(x) = 1/W * Sum_{xi in Omega}( I(xi) * f(|I(xi)-I(x)|) * g(||xi-x||) )

    where Omega is the k x k window around x, g is a spatial Gaussian, f is a
    Gaussian on the intensity difference and W is the sum of the weights.
    Each channel is filtered on its own; the intensity distance is not summed
    across channels.

    Args:
        k: Window side length. Even values are rounded up to the next odd
            integer.
        sigma_space: Standard deviation of the spatial Gaussian, in pixels.
        sigma_color: Standard deviation of the intensity Gaussian.
        device: Device for the kernel and the module; ``None`` leaves the
            module where it was created.
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        k = k // 2 * 2 + 1  # force an odd window
        # Spatial Gaussian, one weight per window tap, shape (1, k*k, 1, 1).
        # It is a constant of the filter, so it is a buffer rather than a
        # trainable parameter.
        self.register_buffer('gw', gkern2d(k, sigma_space, device=device).reshape(1, -1, 1, 1))
        self.k = k
        # shift
        self.shift = Shift(k)
        # Denominator of the intensity Gaussian exponent: 2 * sigma_color**2.
        self.sigma_color = 2 * sigma_color ** 2
        if device is not None:
            self.to(device=device)

    def forward(self, I: torch.Tensor):
        """Filter a batch of images.

        Args:
            I: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, C, H, W)`` with each channel filtered.
        """
        B, C, H, W = I.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        Is = self.shift(I).reshape(-1, self.k * self.k, H, W)
        Iex = I.reshape(-1, 1, H, W)
        D = (Is - Iex) ** 2
        # Combined intensity * spatial weight of every tap.
        weights = torch.exp(-D / self.sigma_color) * self.gw
        If = (weights * Is).sum(1, keepdim=True) / weights.sum(1, keepdim=True)
        return If.reshape(B, C, H, W)
if __name__ == '__main__':
    # Demo: apply the filter repeatedly to a sample image and plot each pass.
    from skimage import data as ski_data
    from skimage.color import rgb2gray
    import time
    import matplotlib.pyplot as plt
    import numpy as np

    k = 5
    n_iter = 3  # number of successive filter passes
    # Fall back to the CPU when no GPU is present instead of crashing.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    img_ = ski_data.astronaut()
    # img_ = rgb2gray(img_)
    # img = img_[None, None, ...]
    # img = torch.tensor(img).float()
    img = torch.tensor(img_).unsqueeze(0).permute(0, 3, 1, 2).float() / 255
    # .to(device) honours the chosen device; the previous ``if device:
    # img.cuda()`` moved the image to CUDA even for device == 'cpu'.
    img = img.to(device)

    # One panel for the input plus one per pass, four panels per row.
    n_rows = 1 + n_iter // 4
    n_cols = min(1 + n_iter, 4)
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    ax[0, 0].imshow(img_, cmap='gray')

    bilat = BilateralFilter(k, sigma_color=0.1, device=device)
    start_time = time.time()
    with torch.no_grad():
        for i in range(n_iter):
            img_filtered = bilat(img)
            img = img_filtered
            if device == 'cuda':
                # CUDA kernels run asynchronously; wait so the elapsed time
                # covers the actual computation.
                torch.cuda.synchronize()
            print("Duration: ", time.time() - start_time)
            img_out = img_filtered[0].cpu().permute(1, 2, 0).numpy().squeeze()
            ax[(1 + i) // 4, (1 + i) % 4].imshow(img_out, cmap='gray')
    plt.show()
Another note: separable bilateral filtering (this is not equivalent to normal bilateral filtering, but it still produces desirable results and can scale to extremely large kernel sizes with much smaller computational overhead)
def gkern1d(l=21, sig=3, device='cpu'):
    """Return a normalised 1D Gaussian kernel as a column tensor.

    Args:
        l: Number of taps in the kernel.
        sig: Standard deviation of the Gaussian, in pixels.
        device: Device on which the kernel is created.

    Returns:
        Tensor of shape ``(l, 1)`` proportional to
        ``exp(-x**2 / (2 * sig**2))`` and scaled to sum to one.
    """
    # Offsets from the kernel midpoint; half-integers for even ``l`` so the
    # kernel stays symmetric rather than being shifted by one tap.
    ax = torch.arange(l, dtype=torch.get_default_dtype(), device=device) - (l - 1) / 2.0
    kernel = torch.exp(-(ax ** 2) / (2.0 * sig ** 2)).view(l, 1)
    return kernel / kernel.sum()
class SeperableShift(nn.Module):
    """Gather the k neighbours of every pixel along one image axis.

    The module holds ``k`` one-hot ``k x 1`` kernels. Convolving a channel
    plane with them gives one vertically shifted copy per window offset;
    with the kernels transposed to ``1 x k`` the shifts are horizontal.
    Borders are zero-padded.

    Args:
        kernel_size: Window length. Even values are rounded up to the next
            odd integer.
    """

    def __init__(self, kernel_size=3):
        super(SeperableShift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1  # force an odd window
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        # Row j of the identity, as a k x 1 column, selects window offset j.
        # The kernels are constants, so they are a buffer: they follow
        # .to()/.cuda() and state_dict, but are never handed to an optimizer.
        one_hot = torch.eye(kernel_size).view(kernel_size, 1, kernel_size, 1)
        self.register_buffer('kernels', one_hot)

    def forward(self, x: torch.Tensor, transpose=True):
        """Return the 1D neighbourhood stack of ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.
            transpose: If true, shift along the height axis (``k x 1``
                kernels); otherwise shift along the width axis.

        Returns:
            Tensor of shape ``(B, C * kernel_size, H, W)``; the
            ``kernel_size`` shifted copies of each channel are adjacent.
        """
        B, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        planes = x.reshape(B * C, 1, H, W)
        if transpose:
            shifted = F.conv2d(planes, self.kernels, padding=(self.pad, 0))
        else:
            shifted = F.conv2d(planes, self.kernels.transpose(-1, -2), padding=(0, self.pad))
        return shifted.reshape(B, -1, H, W)
class SeperableBilateralFilter(nn.Module):
    """Separable approximation of the bilateral filter.

    A 1D bilateral filter

        If(x) = 1/W * Sum_{xi in Omega}( I(xi) * f(|I(xi)-I(x)|) * g(|xi-x|) )

    is applied along the width axis and then, on that result, along the
    height axis. Omega is the k-tap window around x, g is a spatial Gaussian,
    f is a Gaussian on the intensity difference and W is the sum of the
    weights. Each channel is filtered on its own.

    Args:
        k: Window length. Even values are rounded up to the next odd integer.
        sigma_space: Standard deviation of the spatial Gaussian, in pixels.
        sigma_color: Standard deviation of the intensity Gaussian.
        device: Device for the kernel and the module; ``None`` leaves the
            module where it was created.
    """

    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(SeperableBilateralFilter, self).__init__()
        k = k // 2 * 2 + 1  # force an odd window
        # Spatial Gaussian, one weight per tap, shape (1, k, 1, 1). It is a
        # constant of the filter, so it is a buffer rather than a trainable
        # parameter.
        self.register_buffer('gw', gkern1d(k, sigma_space, device=device).reshape(1, -1, 1, 1))
        self.k = k
        # shift
        self.shift = SeperableShift(k)
        # Denominator of the intensity Gaussian exponent: 2 * sigma_color**2.
        self.sigma_color = 2 * sigma_color ** 2
        if device is not None:
            self.to(device=device)

    def _filter_1d(self, I: torch.Tensor, transpose: bool):
        """Run one 1D bilateral pass; ``transpose`` is passed on to the shift."""
        B, C, H, W = I.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        Is = self.shift(I, transpose=transpose).reshape(-1, self.k, H, W)
        Iex = I.reshape(-1, 1, H, W)
        D = (Is - Iex) ** 2
        De = torch.exp(-D / self.sigma_color)
        # Normalise so the spatially weighted intensity weights sum to one.
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        # The 1x1 convolution with gw is the weighted sum over the k taps.
        return F.conv2d(De * Is, self.gw).reshape(B, C, H, W)

    def forward(self, I: torch.Tensor):
        """Filter a batch of images.

        Args:
            I: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.
        """
        along_width = self._filter_1d(I, transpose=False)
        return self._filter_1d(along_width, transpose=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You would change the channel parameter to channels=3. However, I noticed that it will then throw an exception in both the original version and my modified version. Sadly, I have no understanding of how a bilateral filter works and am therefore unable to fix this error. I needed it myself only to implement a baseline.