Created August 3, 2022 13:43
-
-
Save torridgristle/ed572d416c9acc9d1495d8bb25fb715d to your computer and use it in GitHub Desktop.
Max Pool 2d Unpooling
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Perform a 2D max pool on a tensor, keeping the indices of the selected maxima.
# NOTE(review): `input_tensor` is not defined in this snippet; it is presumably
# an (N, C, H, W) tensor supplied by the caller — confirm.
max_size = 8
max_output, max_indices = F.max_pool2d_with_indices(input_tensor, max_size)
# Unpool it to get a tensor of the original size with zeros in all non-max areas.
# output_size must be passed explicitly: by default max_unpool2d produces
# floor(H / max_size) * max_size rows/cols, which is smaller than the input
# whenever H or W is not a multiple of max_size, and the pooling indices
# (flattened over the ORIGINAL spatial size) would then not map back to the
# positions they were taken from.
max_unpool = F.max_unpool2d(
    max_output, max_indices, max_size, max_size, output_size=input_tensor.shape[-2:]
)
# Unpool a tensor of ones with the same indices to get ones where the tensor
# was sampled and zeros elsewhere; this is the "known sample" mask.
max_mask = F.max_unpool2d(
    torch.ones_like(max_output),
    max_indices,
    max_size,
    max_size,
    output_size=input_tensor.shape[-2:],
)
# Radial "cone" kernel: weight decreases linearly with distance from the centre.
def DistanceKernel(size=9):
    """Return a normalized radial-falloff kernel of shape (1, 1, size, size).

    Each tap holds ``max(0, 1 - r)``, where ``r`` is the Euclidean distance of
    the tap from the kernel centre, measured on a grid that spans the open
    interval (-1, 1) along both axes. The taps are then rescaled so that they
    sum to 1. The result is round, peaks at the centre and decreases with
    distance (clipped at 0).

    Args:
        size: side length of the square kernel.

    Returns:
        Float tensor of shape (1, 1, size, size) whose entries sum to 1.
    """
    # size + 2 evenly spaced points on [-1, 1]; dropping both endpoints keeps
    # every tap strictly inside the unit square.
    axis = torch.linspace(-1, 1, size + 2)[1:-1]
    rows, cols = torch.meshgrid(axis, axis)
    coords = torch.stack((rows, cols)).unsqueeze(0)  # (1, 2, size, size)
    radius = coords.norm(p=2, dim=1, keepdim=True)  # (1, 1, size, size)
    cone = torch.relu(1 - radius).reshape(1, 1, size, size)
    return cone / cone.sum()
def CustomUnpooling(x, mask, width, pow=1):
    """Densify a sparse tensor by distance-weighted averaging of its known samples.

    This is normalized convolution. Both ``x`` and ``mask`` are convolved
    depthwise with ``DistanceKernel(width) ** pow``, and the first result is
    divided by the second. Every output pixel therefore becomes a weighted
    average of the known samples inside a ``width`` x ``width`` window, with
    nearer samples weighted more heavily. Raising ``pow`` sharpens the falloff.

    Args:
        x: tensor laid out for ``F.conv2d``, presumably (N, C, H, W); dim 1 is
            used as the depthwise group count. Expected to be zero wherever
            ``mask`` is zero, as produced by ``F.max_unpool2d`` (TODO confirm
            for other callers).
        mask: tensor with 1 where ``x`` holds a known sample and 0 elsewhere,
            with the same layout as ``x`` (a single channel also broadcasts).
            It is cast to ``x``'s dtype/device before convolving.
        width: side length of the weighting kernel.
        pow: exponent applied to the kernel; larger values sharpen the result.

    Returns:
        Tensor with the same spatial size as ``x``, on ``x``'s device and in
        ``x``'s dtype. Pixels whose total weight is zero (no known sample
        inside the kernel's nonzero footprint) are set to 0 rather than the
        NaN a plain ``0 / 0`` would produce.
    """
    # "Same" padding for a width x width kernel (F.pad order: left, right,
    # top, bottom). For even widths the extra pixel goes on the right/bottom.
    half = (width - 1) / 2
    pad = [math.floor(half), math.ceil(half)] * 2
    # DistanceKernel builds a CPU float32 tensor; move it to x's device and
    # dtype so conv2d also works for CUDA / float64 / half inputs.
    kernel = (DistanceKernel(width) ** pow).to(x)

    def weighted_sum(t):
        # Depthwise convolution: one copy of the kernel per channel of t.
        channels = t.shape[1]
        padded = F.pad(t.to(kernel), pad, 'constant', 0.0)
        return F.conv2d(
            padded, kernel.expand(channels, -1, -1, -1), None, 1, groups=channels
        )

    x_weighted = weighted_sum(x)
    mask_weighted = weighted_sum(mask)
    # Where no known sample contributes, the denominator is exactly 0.
    # Dividing by a dummy 1 there (and discarding the result) keeps NaN/Inf
    # out of both the output and its gradient.
    covered = mask_weighted != 0
    denominator = torch.where(covered, mask_weighted, torch.ones_like(mask_weighted))
    return torch.where(covered, x_weighted / denominator, torch.zeros_like(x_weighted))
# The pow argument makes the result sharper as it goes up; 8 seems to be a
# reasonable upper limit. With a low pow the result can get blurry as width
# increases and neighbouring masked areas overlap.
# Depending on the sparsity of some areas you might need a width that's almost
# double the max pool's kernel size, so the width is derived from max_size
# (16 for max_size = 8) instead of being hard-coded.
smooth_unpool = CustomUnpooling(max_unpool, max_mask, width=2 * max_size, pow=8)
# Example outputs: https://imgur.com/a/qGFpSUO. It attempts to spread out known
# values until they meet another known value, appearing almost like Voronoi
# cells. In the last example the samples are equally spaced to show that it
# creates square shapes if the input is equally spaced, since it depends on the
# locations of the sampled values, not on the values themselves.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.