Skip to content

Instantly share code, notes, and snippets.

@QiuJueqin
Last active February 12, 2022 09:41
Show Gist options
  • Save QiuJueqin/a7d6691ab231f4ab473e054b8e80cffe to your computer and use it in GitHub Desktop.
Save QiuJueqin/a7d6691ab231f4ab473e054b8e80cffe to your computer and use it in GitHub Desktop.
Comparison of PyTorch's built-in trilinear interpolation (`F.grid_sample`) with a custom implementation built from recursive bilinear interpolations
import torch
import torch.nn as nn
import torch.nn.functional as F
class Slicing(nn.Module):
    """
    PyTorch built-in (guided) tri-linear interpolation, a.k.a. slicing in HDRNet.

    Each full-resolution pixel samples the bilateral grid at its normalized (x, y)
    location and at the depth given by the guidance map. The result is computed
    twice -- once with ``F.grid_sample`` and once with the custom ``interp3d`` --
    so the two implementations can be compared.

    :param full_res: (height, width) of the full resolution image
    :param align_corners: forwarded unchanged to ``F.grid_sample`` and ``interp3d``
        (controls whether -1/+1 refer to the centres or the outer edges of the
        bilateral grid's corner cells)
    """

    def __init__(self, full_res, align_corners):
        super().__init__()
        height, width = full_res
        self.align_corners = align_corners

        # Pixel indices mapped linearly onto [-1, 1]: the first pixel lands on -1
        # and the last on +1 (this mapping does not depend on align_corners).
        xs = self._normalized_coords(width).view(1, 1, 1, width)
        ys = self._normalized_coords(height).view(1, 1, height, 1)
        # Broadcast to dense (1, 1, H, W) maps. Calling .contiguous() makes the
        # buffers own their memory instead of being expanded views.
        x_grid = xs.expand(1, 1, height, width).contiguous()
        y_grid = ys.expand(1, 1, height, width).contiguous()
        self.register_buffer('x_grid', x_grid, persistent=False)
        self.register_buffer('y_grid', y_grid, persistent=False)

    @staticmethod
    def _normalized_coords(size):
        """
        Return ``size`` evenly spaced float32 coordinates spanning [-1, 1].

        A single-element axis maps to 0 (its centre). The plain formula
        ``i / (size - 1)`` would divide by zero there and yield NaN.
        """
        if size == 1:
            return torch.zeros(1, dtype=torch.float32)
        return torch.arange(size, dtype=torch.float32) / (size - 1) * 2 - 1

    def forward(self, guidance, bilateral_grid):
        """
        :param guidance: torch.Tensor(N, 1, H, W) in [-1, 1] range, with (H, W) == full_res
        :param bilateral_grid: torch.Tensor(N, C, grid_depth, grid_height, grid_width)
        :return: tuple of two torch.Tensor(N, C, H, W): the slicing result computed
            by ``F.grid_sample`` and by ``interp3d``, in that order
        """
        batch_size = guidance.shape[0]
        x_grid = self.x_grid.expand(batch_size, -1, -1, -1)
        y_grid = self.y_grid.expand(batch_size, -1, -1, -1)
        # grid_sample reads the last dim as (x, y, z) = (width, height, depth).
        position = torch.stack([x_grid, y_grid, guidance], dim=4)  # (N, 1, H, W, 3)

        builtin_output = F.grid_sample(
            bilateral_grid,
            position,
            mode='bilinear',
            align_corners=self.align_corners
        ).squeeze(2)  # (N, C, H, W)

        customized_output = interp3d(
            bilateral_grid,
            position,
            align_corners=self.align_corners
        )
        return builtin_output, customized_output
def interp3d(x, position, align_corners):
    """
    Tri-linear interpolation, with recursive customized bi-linear interpolations.

    Every depth slice of ``x`` is bi-linearly sampled at the (x, y) locations of
    ``position``. The resulting per-pixel depth profiles are then linearly
    interpolated at the z coordinate of ``position``. Neighbours outside the grid
    contribute zeros, as with ``F.grid_sample(..., padding_mode='zeros')``.

    :param x: (torch.Tensor): Input tensor, shape (N, C, D, H, W)
    :param position: (torch.Tensor): Point coordinates, shape (N, 1, Hg, Wg, 3), where suffix `g`
        for guidance (in full resolution); last dim ordered (x, y, z), normalized to [-1, 1]
    :param align_corners: (bool): If set to True, the extrema (-1 and 1) are considered as
        referring to the center points of the input's corner pixels. If set to False, they are
        instead considered as referring to the corner points of the input's corner pixels,
        making the sampling more resolution agnostic.
    :return: torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    :raises ValueError: if ``position.shape[1] != 1`` (only a single output depth is supported)
    """
    n, channel, depth, height, width = x.shape
    if position.shape[1] != 1:
        raise ValueError(
            f'position must have shape (N, 1, Hg, Wg, 3), got {tuple(position.shape)}'
        )
    spatial_grid = position[:, 0, ..., :2]  # (N, Hg, Wg, 2)
    intensity_grid = position[:, 0, ..., 2:]  # (N, Hg, Wg, 1)

    # All depth slices share the same (x, y) sample locations. Folding depth into
    # the channel axis interpolates every slice in a single interp2d call instead
    # of recomputing identical indices/weights once per slice.
    slices = interp2d(
        x.reshape(n, channel * depth, height, width), spatial_grid, align_corners
    )  # (N, C*D, Hg, Wg)
    grid_h, grid_w = slices.shape[-2:]
    stack = slices.view(n, channel, depth, grid_h, grid_w).permute(0, 1, 3, 4, 2)  # (N, C, Hg, Wg, D)
    # One zero slot at each end of the depth axis acts as zero padding.
    padded_stack = F.pad(stack, pad=[1, 1], mode='constant', value=0)  # (N, C, Hg, Wg, D+2)

    if align_corners:
        z = ((intensity_grid + 1) / 2) * (depth - 1) + 1  # +1 for padding
    else:
        z = ((intensity_grid + 1) * depth - 1) / 2 + 1
    z0 = torch.floor(z)
    z1 = z0 + 1
    w0 = (z1 - z).unsqueeze(1)  # (N, 1, Hg, Wg, 1), weight of the lower neighbour
    w1 = (z - z0).unsqueeze(1)  # weight of the upper neighbour

    # Out-of-range neighbours collapse onto the zero padding slots (0 or D+1).
    z0 = z0.clamp(0, depth + 1).unsqueeze(1).expand(-1, channel, -1, -1, -1).long()  # (N, C, Hg, Wg, 1)
    z1 = z1.clamp(0, depth + 1).unsqueeze(1).expand(-1, channel, -1, -1, -1).long()  # (N, C, Hg, Wg, 1)

    output = (
        w0 * torch.gather(padded_stack, dim=4, index=z0) +
        w1 * torch.gather(padded_stack, dim=4, index=z1)
    ).squeeze(-1)
    return output
def interp2d(x, position, align_corners):
    """
    Given an input and a flow-field grid, computes the output using input values and pixel
    locations from grid. Supported only bi-linear interpolation method to sample the input pixels.

    Intended to match ``F.grid_sample(x, position, mode='bilinear', padding_mode='zeros',
    align_corners=align_corners)``.

    :param x: (torch.Tensor): Input feature map, shape (N, C, H, W)
    :param position: (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2), last dim ordered
        (x, y) and normalized to [-1, 1]
    :param align_corners: (bool): If set to True, the extrema (-1 and 1) are considered as
        referring to the center points of the input's corner pixels. If set to False, they are
        instead considered as referring to the corner points of the input's corner pixels,
        making the sampling more resolution agnostic.
    :return: torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = x.shape
    _, gh, gw, _ = position.shape
    # A one-pixel zero border: neighbours outside the image are clamped onto it,
    # which reproduces grid_sample's default zero padding.
    padded_image = F.pad(x, pad=[1, 1, 1, 1], mode='constant', value=0).view(n, c, -1)
    row_stride = w + 2  # width of the padded image

    px, py = torch.chunk(position, 2, dim=-1)  # both (N, Hg, Wg, 1)
    if align_corners:
        px = ((px + 1) / 2) * (w - 1) + 1  # +1 for padding
        py = ((py + 1) / 2) * (h - 1) + 1
    else:
        px = ((px + 1) * w - 1) / 2 + 1
        py = ((py + 1) * h - 1) / 2 + 1
    px = px.reshape(n, -1)  # (N, Hg*Wg)
    py = py.reshape(n, -1)

    x0 = torch.floor(px)
    y0 = torch.floor(py)
    x1 = x0 + 1
    y1 = y0 + 1
    wa = ((x1 - px) * (y1 - py)).unsqueeze(1)  # weight of (x0, y0)
    wb = ((x1 - px) * (py - y0)).unsqueeze(1)  # weight of (x0, y1)
    wc = ((px - x0) * (y1 - py)).unsqueeze(1)  # weight of (x1, y0)
    wd = ((px - x0) * (py - y0)).unsqueeze(1)  # weight of (x1, y1)

    # Clip to the padded image and convert to int64 *before* building flat indices.
    # Doing `x + y * row_stride` in the coordinates' float dtype stops being exact
    # once (h + 2) * (w + 2) exceeds the mantissa (2**24 for float32, 2**11 for
    # float16) and would gather the wrong pixels.
    x0 = x0.clamp(0, w + 1).long()
    x1 = x1.clamp(0, w + 1).long()
    y0 = y0.clamp(0, h + 1).long()
    y1 = y1.clamp(0, h + 1).long()

    def _gather(xi, yi):
        """Fetch padded-image values at integer (xi, yi) for every channel."""
        index = (xi + yi * row_stride).unsqueeze(1).expand(-1, c, -1)
        return torch.gather(padded_image, 2, index)

    result = (
        wa * _gather(x0, y0) +
        wb * _gather(x0, y1) +
        wc * _gather(x1, y0) +
        wd * _gather(x1, y1)
    ).view(n, c, gh, gw)
    return result
if __name__ == '__main__':
    # Fixed seed so the reported comparison is reproducible run to run.
    torch.manual_seed(0)

    batch_size = 4
    full_res = (512, 768)
    bilateral_grid_size = (8, 12)
    bilateral_grid_depth = 8
    C = 5  # arbitrary int, the number of elements in a voxel

    guidance = torch.randn(size=(batch_size, 1, *full_res))
    guidance = torch.tanh(guidance)  # [-1, 1] range
    bilateral_grid = torch.randn(size=(batch_size, C, bilateral_grid_depth, *bilateral_grid_size))

    # interp2d/interp3d have separate coordinate formulas per align_corners
    # setting, so compare against the built-in for both.
    for align_corners in (True, False):
        slicing = Slicing(full_res=full_res, align_corners=align_corners)
        with torch.no_grad():
            output1, output2 = slicing(guidance, bilateral_grid)
        print(f'align_corners={align_corners}')
        print('all close with atol=1E-6: ', torch.allclose(output1, output2, atol=1e-6))
        print('all close with atol=1E-5: ', torch.allclose(output1, output2, atol=1e-5))
        print('max abs difference: ', (output1 - output2).abs().max().item())
@QiuJueqin
Copy link
Author

Running result:

all close with atol=1E-6:  False
all close with atol=1E-5:  True

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment