Created
August 24, 2022 18:58
-
-
Save torridgristle/24c8c672e285668d53b3b0efec7cf4db to your computer and use it in GitHub Desktop.
Blur an image with a depth map in PyTorch. Splits the map into ranges of values, multiplies the image by a mask for each range, blurs the masked images and the masks, sums the blurred images and the blurred masks separately, and divides the blurred image sum by the blurred mask sum.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# In the map, 1 is the end (s_end blur) and 0 is the start (s_start blur).
def map_blur(img, map, s_start=0.375, s_end=8, steps=8):
    """Blur ``img`` by a spatially varying amount read from ``map``.

    The map's value range [0, 1] is split into ``steps`` equal-width slices.
    For each slice, a binary mask of the pixels whose map value falls in it is
    built. The image is multiplied by that mask. Both the masked image and the
    mask are blurred with ``GaussianBlur_Sigma`` (zero padding). The sigma is
    interpolated linearly from ``s_start`` (first slice) to ``s_end`` (last
    slice). The blurred images and blurred masks are summed separately, and
    the output is their ratio.

    Args:
        img: image tensor; assumed (N, C, H, W) with ``map`` broadcastable to
            it, e.g. (N, 1, H, W) -- TODO confirm for other layouts.
        map: blur-amount tensor (the name shadows the builtin but is kept for
            caller compatibility). It is clamped to [0, 1] before slicing.
            Without the clamp, out-of-range values (such as softplus output
            slightly above 1) would belong to no slice. Large out-of-range
            regions would then divide 0 by 0 and yield NaN.
        s_start: sigma used for the lowest slice.
        s_end: sigma used for the highest slice.
        steps: number of slices; must be >= 1. With one slice, ``s_start`` is
            used.

    Returns:
        ``img_slices / map_slices``, shaped like ``img`` when ``map``
        broadcasts to it.

    Raises:
        ValueError: if ``steps`` < 1 (previously this silently returned NaN).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    level = map.clamp(0.0, 1.0)
    # zeros_like rather than `x * 0` so inf/NaN in the inputs do not poison
    # the accumulators.
    img_slices = torch.zeros_like(img)
    map_slices = torch.zeros_like(level)
    for s in range(steps):
        # Interpolation factor for sigma; guarded so steps == 1 does not
        # divide by zero.
        t = s / (steps - 1) if steps > 1 else 0.0
        sigma = t * (s_end - s_start) + s_start
        slice_start = s / steps
        slice_end = (s + 1) / steps
        # Slices are half-open [start, end), except the last one, which is
        # closed so that values exactly equal to 1.0 are included.
        if s < steps - 1:
            upper = torch.less(level, slice_end)
        else:
            upper = torch.less_equal(level, slice_end)
        map_slice = torch.logical_and(
            torch.greater_equal(level, slice_start),
            upper,
        ).float()
        map_slices += GaussianBlur_Sigma(map_slice, sigma, False, 'zeros')
        img_slices += GaussianBlur_Sigma(img * map_slice, sigma, False, 'zeros')
    return img_slices / map_slices
### Simple gradients to use as maps for testing.
# Linear gradient
# test_map = torch.linspace(0,1,256).reshape(1,1,1,256) * torch.ones([1,1,256,256])
# Radial gradient: stack the (row, col) coordinate grids over [-1, 1] into a
# (1, 2, 256, 256) tensor. indexing='ij' is the historical default; passing it
# explicitly silences PyTorch's meshgrid deprecation warning.
test_map = torch.stack(
    torch.meshgrid(torch.linspace(-1, 1, 256), torch.linspace(-1, 1, 256), indexing='ij')
).unsqueeze(0)
# Random center: shift both coordinates by a uniform offset in [-1, 1).
test_map += torch.rand([1, 2, 1, 1]).mul(2).sub(1)
# Turn into a radial gradient. Divide by 2**0.5 because that is the value at
# the corners when the center is still (0, 0).
test_map = test_map.norm(2, 1, True) / (2 ** 0.5)
# Invert it so that the center is 1 and the edges are 0 or lower.
test_map = 1 - test_map
# Softplus smoothly fades toward zero instead of going below zero, since the
# map is meant to be used in the 0-1 range. Note: with beta=8, the center maps
# to slightly above 1 (log(1 + e**8) / 8, about 1.00004).
test_map = F.softplus(test_map, 8)
### Simple checker pattern to use as image for testing: 16-pixel cells, 256x256.
test_img = torch.ones([1, 1, 16, 16])
test_img = torch.cat([test_img, test_img * -1 + 1], -1)
test_img = torch.cat([test_img * -1 + 1, test_img], -2)
test_img = test_img.repeat(1, 1, 8, 8)
### My Gaussian blurring function takes only sigma as input. It picks a kernel size such that the smallest kernel value is 1/255 or lower (fit valid up to sigma 20).
# With pad_mode='zeros' and norm_edges=True, edges do not fade to zero: a tensor of ones is padded with zeros and blurred the same way, and the blurred image is divided by it.
def GaussianBlur_Sigma(x, sigma=0.375, allow_even=False, pad_mode='reflect', norm_edges=True):
    """Separable Gaussian blur whose kernel width is derived from ``sigma``.

    Args:
        x: tensor passed to ``F.conv2d``. Assumed (N, C, H, W), since
            ``x.shape[1]`` is used as the group/channel count -- TODO confirm
            there are no 3-D callers. The kernel is built on ``x.device`` with
            ``x.dtype``.
        sigma: Gaussian standard deviation in pixels. If the predicted
            half-width is <= 1 (including sigma == 0), ``x`` is returned
            unchanged. The width fit is only intended for sigma up to ~20.
        allow_even: if True, the width is ``ceil(2*w + 1)`` and may be even;
            otherwise it is ``2*ceil(w) + 1`` and always odd.
        pad_mode: 'zeros' for zero padding; otherwise passed to ``F.pad``
            (e.g. 'reflect', 'replicate').
        norm_edges: only used with pad_mode='zeros'. Divides by an
            identically blurred ones-mask so borders are not darkened.

    Returns:
        Blurred tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``sigma`` is negative (previously an accidental
            TypeError from comparing a complex number).
    """
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma}")
    # Prediction of the input that yields 1/255 when put through
    # exp(-0.5*(x/sigma)**2), normalized by the sum over points -1024..1024.
    # Fit valid up to sigma 20.
    width = ((sigma ** 1.2) * -1.5) + sigma * 4.54
    if width <= 1:
        return x
    if allow_even:
        width = math.ceil(width * 2 + 1)
    else:
        width = math.ceil(width) * 2 + 1
    # Normalized 1-D Gaussian kernel, using the same formula as torchvision's
    # private _get_gaussian_kernel1d. That helper lived in
    # torchvision.transforms.functional_tensor, which newer torchvision
    # releases removed.
    half = (width - 1) * 0.5
    taps = torch.linspace(-half, half, width, dtype=x.dtype, device=x.device)
    kernel = torch.exp(-0.5 * (taps / sigma) ** 2)
    kernel = (kernel / kernel.sum()).reshape(1, 1, 1, width)

    # Split width - 1 padding pixels (left/right, top/bottom) so the output
    # keeps x's spatial size even for even kernel widths.
    pad = (width - 1) * 0.5
    pad = [math.floor(pad), math.ceil(pad), math.floor(pad), math.ceil(pad)]
    normalize = pad_mode == 'zeros' and norm_edges
    if pad_mode == 'zeros':
        if normalize:
            ones = torch.ones([1, 1, x.shape[-2], x.shape[-1]], dtype=x.dtype, device=x.device)
            mask = F.pad(ones, pad, 'constant', value=0.0)
        x = F.pad(x, pad, 'constant', value=0.0)
    else:
        # Build the autograd graph with constant padding, then overwrite the
        # values with `pad_mode` padding outside autograd. As a result,
        # gradients flow as for constant padding: the padded taps send no
        # gradient back into x. Presumably intentional -- TODO confirm.
        x_new = F.pad(x, pad, 'constant', value=0.0)
        with torch.no_grad():
            x_new.data = F.pad(x, pad, pad_mode)
        x = x_new
    channels = x.shape[1]
    # Horizontal pass, then vertical pass; depthwise via groups=channels.
    x = F.conv2d(x, kernel.expand(channels, 1, -1, -1), stride=1, groups=channels)
    x = F.conv2d(x, kernel.transpose(2, 3).expand(channels, 1, -1, -1), stride=1, groups=channels)
    if normalize:
        mask = F.conv2d(mask, kernel, stride=1, groups=1)
        mask = F.conv2d(mask, kernel.transpose(2, 3), stride=1, groups=1)
        x = x / mask.add(1e-8)
    return x
### Example usage: blur the checker image with the radial map, using sigma from 0.375 (map == 0) to 8 (map == 1) over 8 slices.
blur_out = map_blur(
    test_img,
    test_map,
    s_start=0.375,
    s_end=8,
    steps=8,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment