Skip to content

Instantly share code, notes, and snippets.

@DuaneNielsen
Created October 2, 2021 17:47
Show Gist options
  • Save DuaneNielsen/407faec83dc9509385c4b8c1a6350149 to your computer and use it in GitHub Desktop.
Save DuaneNielsen/407faec83dc9509385c4b8c1a6350149 to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from matplotlib import pyplot as plt
from torchvision.io import read_image
# Module-level shorthand for positive infinity, used when building expected tensors in the tests.
inf = float("inf")
def test_conv2d():
    """Check that unfold -> matmul -> fold/view reproduces ``F.conv2d``."""
    inp = torch.randn(1, 3, 10, 12)
    w = torch.randn(2, 3, 4, 5)
    # (1, 3*4*5, 7*8): one column per output pixel of a valid 4x5 convolution.
    inp_unf = F.unfold(inp, (4, 5))
    # Contract the patch dimension against the flattened kernels -> (1, 2, 7*8).
    out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
    expected = F.conv2d(inp, w)

    # Folding with a 1x1 kernel just rearranges the columns back into a (7, 8) grid ...
    out_fold = F.fold(out_unf, (7, 8), (1, 1))
    # ... or equivalently (and avoiding a copy), view the result directly.
    out_view = out_unf.view(1, 2, 7, 8)

    # matmul and the conv kernel accumulate in different orders: allow float32 round-off.
    assert torch.allclose(out_fold, expected, atol=1e-4, rtol=1e-4)
    assert torch.allclose(out_view, expected, atol=1e-4, rtol=1e-4)
def conv2d(x, w, kernel_size, stride=1, padding=0):
    """2-D cross-correlation with ``F.conv2d`` semantics, built from unfold + matmul.

    :param x: input, shape (N, C_IN, H, W)
    :param w: weights, shape (C_OUT, C_IN, K_H, K_W)
    :param kernel_size: int or (K_H, K_W); must agree with ``w``'s spatial shape
    :param stride: int or (S_H, S_W)
    :param padding: int or (P_H, P_W) zero padding
    :return: output of shape (N, C_OUT, H_OUT, W_OUT), where
        H_OUT = (H + 2*P_H - K_H) // S_H + 1 and likewise for W_OUT
    :raises ValueError: if the channel count or kernel size disagree with ``w``
    """
    def _pair(value):
        # Accept either a scalar or a 2-sequence, like F.unfold / F.conv2d do.
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    N, C, H, W = x.shape
    C_OUT, C_IN, K_H, K_W = w.shape
    if C != C_IN:
        raise ValueError(f"input has {C} channels but weight expects {C_IN}")
    if _pair(kernel_size) != (K_H, K_W):
        raise ValueError(f"kernel_size {kernel_size} does not match weight kernel {(K_H, K_W)}")
    S_H, S_W = _pair(stride)
    P_H, P_W = _pair(padding)

    # (N, C_IN*K_H*K_W, H_OUT*W_OUT): each column is one flattened receptive field.
    x_unf = F.unfold(x, (K_H, K_W), stride=(S_H, S_W), padding=(P_H, P_W))
    # The output grid depends on stride and padding; it equals (H, W) only for
    # stride 1 with "same" padding.
    H_OUT = (H + 2 * P_H - K_H) // S_H + 1
    W_OUT = (W + 2 * P_W - K_W) // S_W + 1
    # (C_OUT, C_IN*K_H*K_W) @ (N, C_IN*K_H*K_W, L) broadcasts to (N, C_OUT, L).
    a = torch.matmul(w.reshape(C_OUT, -1), x_unf)
    return a.view(N, C_OUT, H_OUT, W_OUT)
def test_conv2d_func():
    """``conv2d`` must agree with ``F.conv2d`` for a same-padded 3x3 kernel."""
    # Private generator: deterministic weights without touching the global RNG.
    gen = torch.Generator().manual_seed(0)
    x = torch.ones(2, 2, 4, 10)
    w = torch.randn(7, 2, 3, 3, generator=gen)
    a = conv2d(x, w, (3, 3), stride=1, padding=1)
    a_ = F.conv2d(x, w, padding=1)
    # matmul and the conv kernel accumulate in different orders, so allow a small
    # absolute float32 error instead of allclose's default atol=1e-8.
    assert torch.allclose(a, a_, atol=1e-5)
class BuildVolume2d(nn.Module):
    """L1 cost volume between left features and strided right features.

    ``forward`` samples every 4th column of the right features, so it presumably
    expects ``feat_r`` to be about 4x wider than ``feat_l`` (e.g. a finer feature
    map) -- NOTE(review): confirm against the caller.
    """

    def __init__(self, maxdisp):
        super().__init__()
        self.maxdisp = maxdisp

    def forward(self, feat_l, feat_r):
        """
        :param feat_l: (B, C, H, W) left features
        :param feat_r: (B, C, H, W_r) right features; ceil(W_r / 4) must equal W
        :return: (B, maxdisp, H, W) tensor on ``feat_l``'s device/dtype.
            ``cost[:, i]`` is the channel-wise L1 distance between ``feat_l`` and
            every 4th column of the zero-left-padded ``feat_r``, with the sampling
            start moved ``i`` padded columns to the left.
        """
        B, _, H, W = feat_l.shape
        # Zero-pad only the left edge so every shift can be taken as a plain slice.
        padded_feat_r = F.pad(feat_r, [self.maxdisp - 1, 0, 0, 0])
        # Allocate where the inputs live instead of hard-coding CUDA.
        cost = feat_l.new_zeros((B, self.maxdisp, H, W))
        for i in range(self.maxdisp):
            start = self.maxdisp - 1 - i
            stop = -i if i > 0 else None  # a stop of -0 would select nothing
            cost[:, i, :, :] = torch.norm(feat_l - padded_feat_r[:, :, :, start:stop:4], 1, 1)
        return cost.contiguous()  # B*D*H*W
# Custom autograd op: subclasses torch.autograd.Function.
class LinearFunction(Function):
    """Fully connected layer ``y = input @ weight.t() (+ bias)`` with a hand-written backward.

    Use through ``LinearFunction.apply(input, weight, bias=None)``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute ``input.mm(weight.t())`` and add ``bias`` to every row when given."""
        # bias may be None; backward uses that to decide whether a bias grad exists.
        ctx.save_for_backward(input, weight, bias)
        result = input.mm(weight.t())
        if bias is None:
            return result
        result += bias.unsqueeze(0).expand_as(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return (d_input, d_weight, d_bias); entries autograd does not need are None.

        When forward was called without ``bias`` the trailing None is ignored by autograd.
        """
        saved_input, saved_weight, saved_bias = ctx.saved_tensors
        needs = ctx.needs_input_grad
        d_input = grad_output.mm(saved_weight) if needs[0] else None
        d_weight = grad_output.t().mm(saved_input) if needs[1] else None
        # needs[2] only exists when a bias was passed, so test saved_bias first.
        d_bias = grad_output.sum(0) if (saved_bias is not None and needs[2]) else None
        return d_input, d_weight, d_bias
def test_linear():
    """``LinearFunction`` forward value and weight gradient for a 1x2 example."""
    x = torch.ones(1, 2)
    # Create the weights as a leaf tensor so .grad is populated without retain_grad().
    weights = torch.full((1, 2), 0.5, requires_grad=True)
    linear = LinearFunction.apply
    output = linear(x, weights)
    output.backward()
    # 1*0.5 + 1*0.5 == 1, and d(output)/d(weights) == x.
    assert torch.allclose(output, torch.tensor([[1.0]]))
    assert torch.allclose(weights.grad, torch.ones(1, 2))
def flat_patches_3x3(image):
    """Gather the zero-padded 3x3 neighbourhood of every pixel into the channel axis.

    :param image: (N, C, H, W)
    :return: (N, C*9, H, W); channel ``c*9 + k`` holds element ``k`` (row-major)
        of the 3x3 patch of input channel ``c`` centred on each pixel
    """
    batch, _, height, width = image.shape
    # padding=1 yields exactly one 3x3 window per input pixel (H*W columns).
    columns = F.unfold(image, kernel_size=3, padding=1)
    return columns.view(batch, -1, height, width)
def unfold_right(right, max_disparity=3):
    """Stack, for every pixel, the right-image values at each candidate disparity.

    :param right: N, C, H, W
    :param max_disparity: int >= 1, the number D of candidate disparities (0 .. D-1)
    :return: N, C, D, H, W with
        ``out[n, c, d, h, w] == right[n, c, h, w - (D - 1 - d)]``,
        and +inf wherever that column index is negative.
    :raises ValueError: if ``max_disparity < 1``

    The disparity axis is in reversed order: index 0 holds the largest disparity
    (D - 1) and index D - 1 holds disparity 0. Callers that want ascending
    disparity must flip that axis.

    Example: a scan line LEFT = RIGHT = [0, 1, 2, 3] with D = 4. Matching left
    pixel x against right pixel x - disparity gives (entry = disparity,
    x = impossible)

                  RIGHT
                  [ 0, 1, 2, 3 ]
            0       0  x  x  x
      LEFT  1       1  0  x  x
            2       2  1  0  x
            3       3  2  1  0

    Left-padding RIGHT with D - 1 infs,

        [ inf, inf, inf, 0, 1, 2, 3 ]

    and sliding a (1, D) window over it gives, for left pixel w, the window that
    starts at padded column w:

      w = 0  [ inf, inf, inf, 0 ]
      w = 1  [ inf, inf,   0, 1 ]
      w = 2  [ inf,   0,   1, 2 ]
      w = 3  [   0,   1,   2, 3 ]

    Window position d holds RIGHT[w - (D - 1 - d)]. Right pixels that do not exist
    are +inf, so they get an infinite cost under an absolute-difference metric
    such as ``sad``.

    Note: ``F.unfold`` (im2col) materialises a new (N, C*D, H*W) tensor, so the
    result is not a view of ``right``. Its memory grows linearly with D, so keep
    max_disparity as small as the scene allows.
    """
    if max_disparity < 1:
        # A negative pad would silently crop columns instead of padding.
        raise ValueError(f"max_disparity must be >= 1, got {max_disparity}")
    N, C, H, W = right.shape
    # Pad only the left edge of the width axis, with +inf.
    right_pad = F.pad(right, [max_disparity - 1, 0], value=float('inf'))
    # (N, C*D, H*W); channel c*D + d holds right_pad[n, c, h, w + d].
    right_unfold = F.unfold(right_pad, kernel_size=(1, max_disparity))
    return right_unfold.view(N, C, max_disparity, H, W)
def conv2d_view(image, kernel_size):
    """Unfold every valid (unpadded) kernel-sized patch into the channel axis.

    :param image: (N, C, H, W)
    :param kernel_size: int or (K_H, K_W)
    :return: (N, C*K_H*K_W, H - K_H + 1, W - K_W + 1); channel ``c*K_H*K_W + k``
        holds element ``k`` (row-major) of the channel-``c`` patch whose top-left
        corner is at each output position
    """
    N, C, H, W = image.shape
    if isinstance(kernel_size, (tuple, list)):
        k_h, k_w = kernel_size
    else:
        k_h = k_w = kernel_size
    # Without padding, F.unfold produces (H - K_H + 1) * (W - K_W + 1) patches.
    # Using any other grid size would make view() silently scramble the data.
    H_unf, W_unf = H - k_h + 1, W - k_w + 1
    image_unf = F.unfold(image, kernel_size=kernel_size)
    return image_unf.view(N, -1, H_unf, W_unf)
def sad(left, right):
    """Sum of absolute differences, reduced over the channel axis (dim 1).

    :param left: e.g. (N, C, 1, H, W); must broadcast against ``right``
    :param right: e.g. (N, C, D, H, W)
    :return: the broadcast shape with dim 1 removed, e.g. (N, D, H, W)
    """
    difference = left - right
    return difference.norm(p=1, dim=1)
def disparity_map(left, right, tile_size, distance_f, maxdisp):
    """Winner-take-all disparity map from tile matching.

    :param left: (N, C, H, W) left image
    :param right: (N, C, H, W) right image
    :param tile_size: int or (K_H, K_W), the tile compared at each position
    :param distance_f: cost callable such as ``sad``; it receives a left tensor with a
        singleton disparity axis at dim 2 and the unfolded right tensor, and is
        expected to reduce dim 1 so the disparity axis ends up at dim 1
    :param maxdisp: number of candidate disparities (0 .. maxdisp-1)
    :return: (N, H', W') tensor of disparity indices over the tile grid produced
        by ``conv2d_view``; ties resolve to the smallest disparity
    """
    left_tiles = conv2d_view(left, kernel_size=tile_size)
    right_tiles = conv2d_view(right, kernel_size=tile_size)
    # Reuse the shared cost-volume builder rather than re-deriving it here.
    volume = cost_volume(left_tiles, right_tiles, maxdisp, distance_f)
    # unfold_right stores the largest disparity first; flip so index == disparity.
    return torch.argmin(volume.flip(1), dim=1)
def cost_volume(left_features, right_features, max_disparity, distance_f):
    """Matching cost of each left feature against ``max_disparity`` shifted right features.

    :param left_features: (N, C, H, W)
    :param right_features: (N, C, H, W)
    :param max_disparity: number of candidate disparities D
    :param distance_f: callable taking (left (N, C, 1, H, W), right (N, C, D, H, W));
        its result is returned unchanged (``sad`` gives (N, D, H, W))
    :return: ``distance_f``'s output; the disparity axis keeps ``unfold_right``'s
        order (index 0 = largest disparity)
    """
    shifted_right = unfold_right(right_features, max_disparity=max_disparity)
    return distance_f(left_features.unsqueeze(2), shifted_right)
def compute_mask_index(width, max_disparity):
    """Index pairs of a (max_disparity, width) mask, as a (2, K) int64 tensor.

    Selects every (row, col) with ``col >= width - row``. The entries are built
    from ``torch.triu_indices`` with the row axis mirrored, in the
    ``(rows, cols)`` layout that ``render_mask`` indexes with.
    """
    rows, cols = torch.triu_indices(row=max_disparity, col=width,
                                    offset=width - max_disparity + 1)
    return torch.stack((max_disparity - 1 - rows, cols))
def render_mask(i, width, max_disparity):
    """Rasterise (row, col) index pairs into a (max_disparity, width) float mask.

    :param i: pair of index sequences ``(rows, cols)``, e.g. from ``compute_mask_index``
    :param width: number of columns
    :param max_disparity: number of rows
    :return: tensor that is +inf at the listed positions and 0 elsewhere
    """
    rows, cols = i[0], i[1]
    mask = torch.zeros(max_disparity, width)
    mask[rows, cols] = float('inf')
    return mask
def test_compute_disparity_matrix():
    """``unfold_right`` on one scan line, plus the left-minus-right differences."""
    W, D = 4, 4
    left = torch.arange(W, dtype=torch.float32).reshape(1, 1, 1, W)
    right = torch.arange(W, dtype=torch.float32).reshape(1, 1, 1, W)
    inf = float('inf')
    # Row d holds right[w - (D - 1 - d)] for each column w, +inf where that index < 0.
    expected_right_unf = torch.tensor([
        [inf, inf, inf, 0],
        [inf, inf, 0, 1],
        [inf, 0, 1, 2],
        [0, 1, 2, 3],
    ]).view(1, 1, D, 1, W)
    right_unf = unfold_right(right, max_disparity=D)
    assert torch.allclose(expected_right_unf, right_unf)

    # left == right == arange, so wherever a match exists the difference equals the
    # disparity of that row (D - 1 - d); padded positions give -inf.
    disparity = left - right_unf
    expected_disparity = torch.tensor([
        [-inf, -inf, -inf, 3],
        [-inf, -inf, 2, 2],
        [-inf, 1, 1, 1],
        [0, 0, 0, 0],
    ]).view(D, 1, W)
    assert torch.equal(disparity[0, 0], expected_disparity)
def test_flat_patches():
    """``flat_patches_3x3`` shape and patch contents on small images."""
    N = 1
    image = torch.arange(3 * 3, dtype=torch.float32).reshape(N, 1, 3, 3)
    patches = flat_patches_3x3(image)
    # Compare against a literal tuple: `assert (shape == N, 9, H, W)` would assert a
    # non-empty tuple and could never fail.
    assert patches.shape == (N, 9, 3, 3)

    # Zero-padded image:
    #   [ 0, 0, 0, 0, 0 ]
    #   [ 0, 0, 1, 2, 0 ]
    #   [ 0, 3, 4, 5, 0 ]
    #   [ 0, 6, 7, 8, 0 ]
    #   [ 0, 0, 0, 0, 0 ]
    # so the patch centred on (0, 0) is
    #   [ 0, 0, 0 ]
    #   [ 0, 0, 1 ]
    #   [ 0, 3, 4 ]
    assert patches[0, :, 0, 0].allclose(torch.tensor([0.0, 0, 0, 0, 0, 1, 0, 3, 4]))
    # The patch centred on (1, 1) is the whole image, row-major.
    assert patches[0, :, 1, 1].allclose(torch.tensor([0.0, 1, 2, 3, 4, 5, 6, 7, 8]))

    # One patch per pixel for other (including non-square) image sizes.
    for H, W in [(2, 2), (5, 4)]:
        image = torch.arange(H * W, dtype=torch.float32).reshape(N, 1, H, W)
        patches = flat_patches_3x3(image)
        assert patches.shape == (N, 9, H, W)
def test_unfoldright_2d():
    """``unfold_right`` on a 3x3 single-channel image, checked one image row at a time.

    RIGHT
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]
    For image row h = [a, b, c], the (disparity x width) slice must be
        [inf, inf, a],
        [inf, a,   b],
        [a,   b,   c]
    """
    right = torch.arange(3 * 3, dtype=torch.float32).reshape(1, 1, 3, 3)
    print(right)
    right_unf = unfold_right(right, max_disparity=3)
    for h in range(3):
        a, b, c = 3 * h, 3 * h + 1, 3 * h + 2
        expected = torch.tensor([
            [inf, inf, a],
            [inf, a, b],
            [a, b, c],
        ])
        assert torch.allclose(right_unf[0, :, :, h, :], expected)
def test_unfoldright_2d_multi_channel():
    """Same layout as the single-channel case, but each pixel is its 9-value 3x3 patch.

    RIGHT
        [0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]
    For image row h with patch vectors [a, b, c], the (disparity x width) slice
    must be
        [pad, pad, a],
        [pad, a,   b],
        [a,   b,   c]
    where pad is a vector of nine +inf values.
    """
    right = torch.arange(3 * 3 * 1, dtype=torch.float32).reshape(1, 1, 3, 3)
    right_patch = flat_patches_3x3(right)
    right_unf = unfold_right(right_patch, max_disparity=3)
    pad = torch.tensor([inf] * 9)
    for h in range(3):
        # Patch vectors of the three pixels in image row h.
        a, b, c = (right_patch[0, :, h, w] for w in range(3))
        expected = torch.stack([
            torch.stack([pad, pad, a]),
            torch.stack([pad, a, b]),
            torch.stack([a, b, c]),
        ]).permute(2, 0, 1)  # (D, W, C) -> (C, D, W)
        assert torch.allclose(right_unf[0, :, :, h, :], expected)
def test_compute_similarity():
    """Identical left/right images must have zero patch difference at disparity 0."""
    left = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    right = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    left_patch = flat_patches_3x3(left)
    right_patch = flat_patches_3x3(right)
    right_unf = unfold_right(right_patch, max_disparity=3)
    cost = left_patch.unsqueeze(2) - right_unf  # (1, 9, 3, 3, 4)
    # Named `volume` so it does not shadow the module-level cost_volume function.
    volume = torch.norm(cost, p=1, dim=1, keepdim=True)
    assert volume.shape == (1, 1, 3, 3, 4)
    # Disparity index 2 is disparity 0 (unfold_right stores it last): zero everywhere.
    assert torch.allclose(cost[0, :, 2], torch.zeros(9, 3, 4))
    assert torch.allclose(volume[0, 0, 2], torch.zeros(3, 4))
def test_disparity_map():
    """Matching an image against an identical copy must give zero disparity everywhere."""
    image = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    left, right = image, image.clone()
    dmap = disparity_map(left, right, 3, sad, 3)
    assert torch.allclose(dmap, torch.zeros_like(dmap))
def test_read():
    """Smoke test for ``torchvision.io.read_image``.

    NOTE(review): '~/data/' names a directory, while read_image needs the path of
    an image file -- TODO point this at the intended file, otherwise the call
    keeps raising.
    """
    import os

    # The path is handed to read_image as-is, so resolve '~' here.
    read_image(os.path.expanduser('~/data/'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment