#!/usr/bin/python
# torch_bilateral: bi/trilateral filtering in torch
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import numpy as np
import pdb
import time


def gkern2d(l=21, sig=3):
    """Returns a 2D Gaussian kernel array."""
    ax = np.arange(-l // 2 + 1., l // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
    return kernel


class Shift(nn.Module):
    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        if self.kernel_size == 3:
            self.pad = 1
        elif self.kernel_size == 5:
            self.pad = 2
        elif self.kernel_size == 7:
            self.pad = 3

    def forward(self, x):
        n, c, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        # Alias for convenience
        cpg = self.channels_per_group
        cat_layers = []
        for i in range(self.in_planes):
            # Parse in row-major
            for y in range(0, self.kernel_size):
                y2 = y + h
                for x in range(0, self.kernel_size):
                    x2 = x + w
                    xx = x_pad[:, i:i+1, y:y2, x:x2]
                    cat_layers += [xx]
        return torch.cat(cat_layers, 1)


class BilateralFilter(nn.Module):
    r"""BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    """
    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1):
        super(BilateralFilter, self).__init__()
        # space gaussian kernel
        # FIXME: do everything in torch
        self.g = Parameter(torch.Tensor(channels, k * k))
        self.gw = gkern2d(k, sigma_space)
        gw = np.tile(self.gw.reshape(channels, k * k, 1, 1), (1, 1, height, width))
        self.g.data = torch.from_numpy(gw).float()
        # shift
        self.shift = Shift(channels, k)
        self.sigma_color = 2 * sigma_color ** 2

    def forward(self, I):
        Is = self.shift(I).data
        Iex = I.expand(*Is.size())
        D = (Is - Iex) ** 2  # here we are actually missing some sum over groups of channels
        De = torch.exp(-D / self.sigma_color)
        Dd = De * self.g.data
        W_denom = torch.sum(Dd, dim=1)
        If = torch.sum(Dd * Is, dim=1) / W_denom
        return If


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    import cv2

    c, h, w = 1, 480 // 2, 640 // 2  # integer division so reshape/resize get ints in Python 3
    k = 5
    cuda = False

    bilat = BilateralFilter(c, k, h, w)
    if cuda:
        bilat.cuda()

    im = cv2.imread('/home/eperot/Pictures/Lena.png', cv2.IMREAD_GRAYSCALE)
    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    im_in = im.reshape(1, 1, h, w)
    img = torch.from_numpy(im_in).float() / 255.0

    if cuda:
        img_in = img.cuda()
    else:
        img_in = img

    y = bilat(img_in)
    start = time.time()
    y = bilat(img_in)
    print(time.time() - start)
    img_out = y.cpu().numpy()[0]  # not counting the return transfer in timing!

    show_out = cv2.resize(img_out, (640, 480))
    show_in = cv2.resize(img[0, 0].numpy(), (640, 480))

    # diff = np.abs(img_out - img[0,0])
    # diff = (diff - diff.min()) / (diff.max() - diff.min())
    # cv2.namedWindow('diff')
    # cv2.moveWindow('diff',50,50)
    # cv2.imshow('diff', diff)

    cv2.namedWindow('kernel')
    cv2.imshow('kernel', bilat.gw)

    cv2.namedWindow('img_out')
    cv2.moveWindow('img_out', 200, 200)
    cv2.imshow('img_out', show_out)
    #
    cv2.namedWindow('img_in')
    cv2.imshow('img_in', show_in)
    cv2.waitKey(0)

    # n = gkern2d(5,3)
    #
    # s = n.reshape(1,25)
    #
    # print(n)
    # print(s)
    # plt.imshow(n, interpolation='none')
    # plt.show()
looks fine :)
Coming back to this several years later, I'm wondering if PyTorch's unfold is perhaps simpler than this "Shift" operator.
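For reference, a minimal sketch of that idea (my own assumption, not part of the original gist): F.unfold gathers every k*k neighborhood in one call, and its output ordering (channel-major, kernel positions in row-major) matches what the Shift loops produce, so it should yield the same (B, C*k*k, H, W) layout.

import torch
import torch.nn.functional as F

def shift_via_unfold(x, kernel_size=3):
    # Hypothetical replacement for Shift: extract all k*k neighborhoods with a
    # single unfold call instead of looping over spatial offsets.
    b, c, h, w = x.shape
    pad = kernel_size // 2
    patches = F.unfold(x, kernel_size, padding=pad)  # (B, C*k*k, H*W)
    return patches.view(b, c * kernel_size * kernel_size, h, w)

Note that it still materializes the same k*k-times-larger tensor, so memory use should be unchanged.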
How would you change it for the RGB image case?
You would change the channel parameter to channels=3. However, I noticed that it will then throw an exception in both the original version and my modified version. Sadly, I have no understanding of how a bilateral filter works and am therefore unable to fix this error. I needed it myself only to implement a baseline.
Does anyone have an idea how to change it for the RGB image case?
Maybe this works for the RGB image case:
import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Returns a 2D Gaussian kernel array."""
    ax = torch.arange(-l // 2 + 1., l // 2 + 1., device=device)
    xx, yy = torch.meshgrid(ax, ax)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
    return kernel


class Shift(nn.Module):
    def __init__(self, in_planes, kernel_size=3):
        super(Shift, self).__init__()
        self.in_planes = in_planes
        self.kernel_size = kernel_size
        self.channels_per_group = self.in_planes // (self.kernel_size ** 2)
        if self.kernel_size == 3:
            self.pad = 1
        elif self.kernel_size == 5:
            self.pad = 2
        elif self.kernel_size == 7:
            self.pad = 3

    def forward(self, x):
        n, c, h, w = x.size()
        x_pad = F.pad(x, (self.pad, self.pad, self.pad, self.pad))
        # Alias for convenience
        cpg = self.channels_per_group
        cat_layers = []
        for i in range(self.in_planes):
            # Parse in row-major
            for y in range(0, self.kernel_size):
                y2 = y + h
                for x in range(0, self.kernel_size):
                    x2 = x + w
                    xx = x_pad[:, i:i+1, y:y2, x:x2]
                    cat_layers += [xx]
        return torch.cat(cat_layers, 1)


class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    """
    def __init__(self, channels=3, k=7, height=480, width=640, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        self.k = k
        # space gaussian kernel
        self.gw = gkern2d(k, sigma_space, device=device)
        self.g = torch.tile(self.gw.reshape(1, k * k, 1, 1), (channels, 1, height, width))[None]
        # shift
        self.shift = Shift(channels, k)
        self.sigma_color = 2 * sigma_color ** 2
        self.to(device=device)

    def forward(self, I):
        b, c, h, w = I.shape
        Is = self.shift(I).data
        Is = Is.view(b, c, -1, h, w)
        Iex = torch.repeat_interleave(I, repeats=self.k * self.k, dim=1).view(b, c, -1, h, w)
        D = (Is - Iex) ** 2  # here we are actually missing some sum over groups of channels
        De = torch.exp(-D / self.sigma_color)
        Dd = De * self.g
        W_denom = torch.sum(Dd, dim=2)
        If = torch.sum(Dd * Is, dim=2) / W_denom
        return If
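A minimal usage sketch of this version (my own example, not from the comment above), assuming a (1, 3, H, W) float tensor in [0, 1] whose spatial size matches the height/width passed to the constructor, since the spatial kernel is pre-tiled to that resolution:

# Hypothetical example input: a random 3-channel image; any (1, 3, H, W)
# tensor in [0, 1] with matching height/width should behave the same way.
img = torch.rand(1, 3, 256, 256)
bilat = BilateralFilter(channels=3, k=5, height=256, width=256,
                        sigma_space=5, sigma_color=0.1, device='cpu')
with torch.no_grad():
    out = bilat(img)  # (1, 3, 256, 256)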
I've made some changes to the layer:
- Making the shift operator parallel using convolution sampling (which turns out to be better optimized than unfold)
- Redesigning the bilateral filter around convolutions for speed (unfortunately this still does not solve the memory inefficiency)
- Making the layer channel-agnostic (it should work with any kind of image)
- Removing some unnecessary inputs like image width/height and channel count
import torch
from torch import nn
import torch.nn.functional as F


def gkern2d(l=21, sig=3, device='cpu'):
    """Returns a 2D Gaussian kernel array."""
    ax = torch.arange(-l // 2 + 1., l // 2 + 1., device=device)
    xx, yy = torch.meshgrid(ax, ax)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sig ** 2))
    return kernel / kernel.sum()  # Early normalization


class Shift(nn.Module):
    def __init__(self, kernel_size=3):
        super(Shift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.kernels = nn.Parameter(
            F.one_hot(torch.arange(0, kernel_size * kernel_size))
            .view(kernel_size * kernel_size, 1, kernel_size, kernel_size)
            .float())

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        x_ = x.view(-1, 1, H, W)
        x_ = F.conv2d(x_, self.kernels, padding=self.pad)
        x = x_.view(B, -1, H, W)
        return x


class BilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    """
    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(BilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        # space gaussian kernel
        self.gw = nn.Parameter(gkern2d(k, sigma_space, device=device).view(1, -1, 1, 1))
        self.k = k
        # shift
        self.shift = Shift(k)
        self.sigma_color = 2 * sigma_color ** 2
        if device is not None:
            self.to(device=device)

    def forward(self, I: torch.Tensor):
        B, C, H, W = I.shape
        Is = self.shift(I).view(-1, self.k * self.k, H, W)
        Iex = I.view(-1, 1, H, W)
        D = (Is - Iex) ** 2  # here we are actually missing some sum over groups of channels
        De = torch.exp(-D / self.sigma_color)
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        If = F.conv2d(De * Is, self.gw)
        return If.view(B, C, H, W)
if __name__ == '__main__':
    from skimage import data as ski_data
    from skimage.color import rgb2gray
    import time
    import matplotlib.pyplot as plt
    import numpy as np

    k = 5
    iter = 3
    device = 'cuda'

    img_ = ski_data.astronaut()
    # img_ = rgb2gray(img_)
    # img = img_[None, None, ...]
    # img = torch.tensor(img).float()
    img = torch.tensor(img_).unsqueeze(0).permute(0, 3, 1, 2).float() / 255
    if device:
        img = img.cuda()

    fig, ax = plt.subplots(1 + (iter) // 4, min(1 + iter, 4),
                           figsize=(5 * min(1 + iter, 4), 5 * (1 + (iter) // 4)),
                           squeeze=False)
    ax[0, 0].imshow(img_, cmap='gray')

    bilat = BilateralFilter(k, sigma_color=0.1, device=device)
    start_time = time.time()
    with torch.no_grad():
        for i in range(iter):
            img_filtered = bilat(img)
            img = img_filtered
    print("Duration: ", time.time() - start_time)

    img_out = img_filtered[0].cpu().permute(1, 2, 0).numpy().squeeze()
    ax[(1 + i) // 4, (1 + i) % 4].imshow(img_out, cmap='gray')
    plt.show()
Another note: separable bilateral filtering (this is not equivalent to normal bilateral filtering, but it still produces desirable results and scales to extremely large kernel sizes with much smaller computational overhead):
def gkern1d(l=21, sig=3, device='cpu'):
    """Returns a 1D Gaussian kernel array."""
    ax = torch.arange(-l // 2 + 1., l // 2 + 1., device=device)
    kernel = torch.exp(-(ax ** 2) / (2. * sig ** 2)).view(l, 1)
    return kernel / kernel.sum()


class SeperableShift(nn.Module):
    def __init__(self, kernel_size=3):
        super(SeperableShift, self).__init__()
        kernel_size = kernel_size // 2 * 2 + 1
        self.kernel_size = kernel_size
        self.pad = kernel_size // 2
        self.kernels = nn.Parameter(
            F.one_hot(torch.arange(0, kernel_size))
            .view(kernel_size, 1, kernel_size, 1)
            .float())

    def forward(self, x: torch.Tensor, transpose=True):
        B, C, H, W = x.shape
        x_ = x.view(-1, 1, H, W)
        if transpose:
            x_ = F.conv2d(x_, self.kernels, padding=(self.pad, 0))
        else:
            x_ = F.conv2d(x_, self.kernels.transpose(-1, -2), padding=(0, self.pad))
        x = x_.view(B, -1, H, W)
        return x


class SeperableBilateralFilter(nn.Module):
    """BilateralFilter computes:
        If = 1/W * Sum_{xi C Omega}(I * f(||I(xi)-I(x)||) * g(||xi-x||))
    """
    def __init__(self, k=7, sigma_space=5, sigma_color=0.1, device='cpu'):
        super(SeperableBilateralFilter, self).__init__()
        k = k // 2 * 2 + 1
        # space gaussian kernel
        self.gw = nn.Parameter(gkern1d(k, sigma_space, device=device).view(1, -1, 1, 1))
        self.k = k
        # shift
        self.shift = SeperableShift(k)
        self.sigma_color = 2 * sigma_color ** 2
        if device is not None:
            self.to(device=device)

    def forward(self, I: torch.Tensor):
        B, C, H, W = I.shape
        # horizontal pass
        Is = self.shift(I, transpose=False).view(-1, self.k, H, W)
        Iex = I.view(-1, 1, H, W)
        D = (Is - Iex) ** 2
        De = torch.exp(-D / self.sigma_color)
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        If = F.conv2d(De * Is, self.gw).view(B, C, H, W)
        I = If
        # vertical pass
        Is = self.shift(I, transpose=True).view(-1, self.k, H, W)
        Iex = I.view(-1, 1, H, W)
        D = (Is - Iex) ** 2
        De = torch.exp(-D / self.sigma_color)
        De = De / torch.sum(De * self.gw, 1, keepdim=True)
        If = F.conv2d(De * Is, self.gw).view(B, C, H, W)
        return If
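A minimal usage sketch for the separable variant (my own example, not from the comment above); the constructor takes the same arguments as the convolution-based BilateralFilter, and since each pass only compares a pixel against k neighbours instead of k*k, large kernels stay cheap:

# Hypothetical example: a large spatial kernel that would be expensive for the
# full filter but remains affordable for the separable approximation.
img = torch.rand(1, 3, 256, 256)
sep_bilat = SeperableBilateralFilter(k=31, sigma_space=10, sigma_color=0.1, device='cpu')
with torch.no_grad():
    out = sep_bilat(img)  # (1, 3, 256, 256)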
I created a slightly updated version of this gist (everything in torch, no OpenCV, and reduced use of Parameter and Variable). The images look as expected and everything looks good. @etienne87, do you agree?