Last active
February 12, 2022 09:41
-
-
Save QiuJueqin/a7d6691ab231f4ab473e054b8e80cffe to your computer and use it in GitHub Desktop.
Comparison of PyTorch built-in trilinear interpolation and customized one with recursive bilinear interpolations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class Slicing(nn.Module):
    """PyTorch built-in (guided) tri-linear interpolation, a.k.a. slicing in HDRNet.

    Precomputes a normalized (x, y) sampling grid for the full-resolution image;
    the per-pixel guidance map supplies the third (depth) coordinate at runtime.

    :param full_res: (height, width) of the full resolution image
    :param align_corners: (bool) forwarded to F.grid_sample; controls whether
        -1/1 map to corner-pixel centers (True) or corner-pixel edges (False)
    """

    def __init__(self, full_res, align_corners):
        super().__init__()
        height, width = full_res
        self.align_corners = align_corners
        # indexing='ij' matches the historical default behavior and silences the
        # deprecation warning raised by PyTorch >= 1.10.
        y_grid, x_grid = torch.meshgrid(
            torch.arange(0, height),
            torch.arange(0, width),
            indexing='ij'
        )
        # Both torch.Tensor(1, 1, H, W), normalized to the [-1, 1] range
        # expected by grid_sample.
        x_grid = (x_grid.float() / (width - 1) * 2 - 1).view(1, 1, height, width)
        y_grid = (y_grid.float() / (height - 1) * 2 - 1).view(1, 1, height, width)
        # Buffers move with the module across .to()/.cuda() but are excluded
        # from the state dict (persistent=False).
        self.register_buffer('x_grid', x_grid, persistent=False)
        self.register_buffer('y_grid', y_grid, persistent=False)

    def forward(self, guidance, bilateral_grid):
        """
        :param guidance: torch.Tensor(N, 1, H, W) in [-1, 1] range
        :param bilateral_grid: torch.Tensor(N, C, grid_depth, grid_height, grid_width)
        :return: tuple of two torch.Tensor(N, C, H, W): the built-in
            grid_sample slicing result and the customized interp3d result
        """
        batch_size = guidance.shape[0]
        x_grid = self.x_grid.expand(batch_size, -1, -1, -1)
        y_grid = self.y_grid.expand(batch_size, -1, -1, -1)
        # (N, 1, H, W, 3): last dim is (x, y, guidance intensity).
        position = torch.stack([x_grid, y_grid, guidance], dim=4)
        builtin_output = F.grid_sample(
            bilateral_grid,
            position,
            mode='bilinear',
            align_corners=self.align_corners
        ).squeeze(2)  # (N, C, H, W)
        customized_output = interp3d(
            bilateral_grid,
            position,
            align_corners=self.align_corners
        )
        return builtin_output, customized_output
def interp3d(x, position, align_corners):
    """
    Tri-linear interpolation, with recursive customized bi-linear interpolations.

    Each depth slice of `x` is first sampled spatially with `interp2d`, then the
    per-slice results are linearly blended along the depth axis using the third
    coordinate of `position` (the guidance intensity).

    :param x: (torch.Tensor): Input tensor, shape (N, C, D, H, W)
    :param position: (torch.Tensor): Point coordinates, shape (N, 1, Hg, Wg, 3), where suffix `g`
        for guidance (in full resolution)
    :param align_corners: (bool): If set to True, the extrema (-1 and 1) are considered as
        referring to the center points of the input's corner pixels. If set to False, they are
        instead considered as referring to the corner points of the input's corner pixels,
        making the sampling more resolution agnostic.
    :return: torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    _, channel, depth, _, _ = x.shape
    spatial_grid = position[:, 0, ..., :2]  # (N, Hg, Wg, 2): x/y sampling locations
    intensity_grid = position[:, 0, ..., 2:]  # (N, Hg, Wg, 1): depth (guidance) coordinate
    # Bi-linearly sample every depth slice, then stack them along a new last axis.
    stack = torch.stack([
        interp2d(x[:, :, i, ...], spatial_grid, align_corners)
        for i in range(depth)
    ], dim=4)  # (N, C, Hg, Wg, D)
    # Zero-pad the depth axis by one on each side so that clamped out-of-range
    # indices read zeros, emulating grid_sample's zero padding mode.
    padded_stack = F.pad(stack, pad=[1, 1], mode='constant', value=0)  # (N, C, Hg, Wg, D+2)
    # Un-normalize the [-1, 1] depth coordinate into padded-array index space.
    if align_corners:
        z = ((intensity_grid + 1) / 2) * (depth - 1) + 1  # +1 for padding
    else:
        z = ((intensity_grid + 1) * depth - 1) / 2 + 1
    z0 = torch.floor(z)  # lower neighboring slice index
    z1 = z0 + 1  # upper neighboring slice index
    # Linear weights are computed from the *unclamped* coordinate, so samples
    # that fall outside the grid blend toward the zero padding.
    w0 = (z1 - z).unsqueeze(1)
    w1 = (z - z0).unsqueeze(1)
    # Clip coordinates to padded image size
    z0.clamp_(0, depth + 1)
    z1.clamp_(0, depth + 1)
    z0 = z0.unsqueeze(1).expand(-1, channel, -1, -1, -1).long()  # (N, C, Hg, Wg, 1)
    z1 = z1.unsqueeze(1).expand(-1, channel, -1, -1, -1).long()  # (N, C, Hg, Wg, 1)
    # Gather the two neighboring slices along depth and blend them.
    output = (
        w0 * torch.gather(padded_stack, dim=4, index=z0) +
        w1 * torch.gather(padded_stack, dim=4, index=z1)
    ).squeeze(-1)
    return output
def interp2d(x, position, align_corners):
    """
    Bi-linear sampling of an input feature map at normalized point coordinates,
    emulating F.grid_sample(..., mode='bilinear', padding_mode='zeros').

    :param x: (torch.Tensor): Input feature map, shape (N, C, H, W)
    :param position: (torch.Tensor): Point coordinates in [-1, 1], shape (N, Hg, Wg, 2)
    :param align_corners: (bool): If True, -1 and 1 refer to the center points of
        the corner pixels; if False, to their outer edges (resolution agnostic).
    :return: torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = x.shape
    _, out_h, out_w, _ = position.shape
    # Zero-pad one pixel on every side so clamped out-of-range taps read zeros
    # (grid_sample's default zero padding), then flatten for gather.
    padded = F.pad(x, pad=[1, 1, 1, 1], mode='constant', value=0).view(n, c, -1)
    u = position[..., 0:1]  # x coordinate, (N, Hg, Wg, 1)
    v = position[..., 1:2]  # y coordinate, (N, Hg, Wg, 1)
    # Un-normalize [-1, 1] into padded-image pixel space (+1 for the padding).
    if align_corners:
        u = ((u + 1) / 2) * (w - 1) + 1
        v = ((v + 1) / 2) * (h - 1) + 1
    else:
        u = ((u + 1) * w - 1) / 2 + 1
        v = ((v + 1) * h - 1) / 2 + 1
    u = u.view(n, -1)
    v = v.view(n, -1)
    left = torch.floor(u)
    top = torch.floor(v)
    right = left + 1
    bottom = top + 1
    # Corner weights come from the *unclamped* coordinates so that samples
    # outside the image blend toward the zero padding.
    w_tl = ((right - u) * (bottom - v)).unsqueeze(1)  # top-left
    w_bl = ((right - u) * (v - top)).unsqueeze(1)     # bottom-left
    w_tr = ((u - left) * (bottom - v)).unsqueeze(1)   # top-right
    w_br = ((u - left) * (v - top)).unsqueeze(1)      # bottom-right
    # Clip corner indices into the padded image bounds.
    left = left.clamp(0, w + 1)
    right = right.clamp(0, w + 1)
    top = top.clamp(0, h + 1)
    bottom = bottom.clamp(0, h + 1)

    stride = w + 2  # row stride of the padded image

    def tap(col, row):
        # Gather padded pixel values at (row, col) for every channel.
        idx = (col + row * stride).unsqueeze(1).expand(-1, c, -1).long()
        return torch.gather(padded, 2, idx)

    out = (
        w_tl * tap(left, top) +
        w_bl * tap(left, bottom) +
        w_tr * tap(right, top) +
        w_br * tap(right, bottom)
    )
    return out.view(n, c, out_h, out_w)
if __name__ == '__main__':
    # Compare built-in grid_sample slicing against the hand-rolled
    # recursive bilinear implementation on random inputs.
    num_images = 4
    image_size = (512, 768)
    grid_hw = (8, 12)
    grid_depth = 8
    num_coeffs = 5  # arbitrary int, the number of elements in a voxel

    slicing = Slicing(full_res=image_size, align_corners=True)
    # tanh squashes the random guidance into the required [-1, 1] range.
    guide = torch.tanh(torch.randn(size=(num_images, 1, *image_size)))
    grid = torch.randn(size=(num_images, num_coeffs, grid_depth, *grid_hw))

    out_builtin, out_custom = slicing(guide, grid)
    print('all close with atol=1E-6: ', torch.allclose(out_builtin, out_custom, atol=1e-6))
    print('all close with atol=1E-5: ', torch.allclose(out_builtin, out_custom, atol=1e-5))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Running result: