import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from matplotlib import pyplot as plt
from torchvision.io import read_image

inf = float('inf')

def test_conv2d():
    # conv2d expressed as unfold -> matmul -> (optionally) fold
    inp = torch.randn(1, 3, 10, 12)
    w = torch.randn(2, 3, 4, 5)
    inp_unf = F.unfold(inp, (4, 5))
    out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
    out = F.fold(out_unf, (7, 8), (1, 1))
    # or equivalently (and avoiding a copy),
    out = out_unf.view(1, 2, 7, 8)
    assert (F.conv2d(inp, w) - out).abs().max() < 1e-4

def conv2d(x, w, kernel_size, stride=1, padding=0):
    N, C, H, W = x.shape
    C_OUT, C_IN, K_H, K_W = w.shape
    x_unf = F.unfold(x, kernel_size, stride=stride, padding=padding)
    N_UNF, C_H_W, ACTIVATION_PIXELS = x_unf.shape
    assert N == N_UNF
    assert C_H_W == C_IN * K_H * K_W
    # output spatial size (assumes integer stride and padding)
    H_OUT = (H + 2 * padding - K_H) // stride + 1
    W_OUT = (W + 2 * padding - K_W) // stride + 1
    w = w.view(w.size(0), -1)
    a = torch.matmul(w, x_unf)
    return a.view(N, C_OUT, H_OUT, W_OUT)

def test_conv2d_func():
    x = torch.ones(2, 2, 4, 10)
    w = torch.randn(7, 2, 3, 3)
    a = conv2d(x, w, (3, 3), stride=1, padding=1)
    a_ = F.conv2d(x, w, padding=1)
    assert torch.allclose(a, a_)
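
# A sketch (not in the original gist) exercising the stride/padding handling of
# conv2d above against F.conv2d; the shapes and values are illustrative only.
def test_conv2d_func_strided():
    x = torch.randn(1, 3, 9, 11)
    w = torch.randn(4, 3, 3, 3)
    a = conv2d(x, w, (3, 3), stride=2, padding=1)
    a_ = F.conv2d(x, w, stride=2, padding=1)
    assert a.shape == a_.shape
    assert torch.allclose(a, a_, atol=1e-5)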

class BuildVolume2d(nn.Module):
    def __init__(self, maxdisp):
        super(BuildVolume2d, self).__init__()
        self.maxdisp = maxdisp

    def forward(self, feat_l, feat_r):
        padded_feat_r = F.pad(feat_r, [self.maxdisp - 1, 0, 0, 0])
        cost = torch.zeros((feat_l.size()[0], self.maxdisp, feat_l.size()[2], feat_l.size()[3]), device='cuda')
        for i in range(0, self.maxdisp):
            if i > 0:
                # pdb.set_trace()
                cost[:, i, :, :] = torch.norm(
                    feat_l[:, :, :, :] - padded_feat_r[:, :, :, self.maxdisp - 1 - i:-i:4], 1, 1)
            else:
                # pdb.set_trace()
                cost[:, i, :, :] = torch.norm(feat_l[:, :, :, :] - padded_feat_r[:, :, :, self.maxdisp - 1::4], 1, 1)
        return cost.contiguous()  # B*D*H*W

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

def test_linear():
    input = torch.ones(1, 2)
    weights = torch.ones(1, 2, requires_grad=True) / 2.0
    weights.retain_grad()
    linear = LinearFunction.apply
    output = linear(input, weights)
    output.backward()
    print(output)
    print(weights.grad)
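
# A minimal gradient sanity check (an addition, not part of the original gist):
# torch.autograd.gradcheck compares LinearFunction's analytic backward against
# numerical gradients. Double precision is used because gradcheck is noisy in float32.
def test_linear_gradcheck():
    input = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    weight = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    bias = torch.randn(2, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(LinearFunction.apply, (input, weight, bias))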

def flat_patches_3x3(image):
    # for every pixel, extract the (zero padded) 3x3 patch around it and flatten it into the channel axis
    N, C, H, W = image.shape
    image_unf = F.unfold(image, kernel_size=3, padding=1)
    return image_unf.view(N, -1, H, W)

def unfold_right(right, max_disparity=3):
    """
    :param right: N, C, H, W
    :param max_disparity: int
    :return: N, C, D, H, W  The disparity axis is in reversed order: max disparity is at index 0 and
    0 disparity is at index max_disparity - 1. Fixing the order needs a flip, which would add overhead.

    imagine a scan line  L [0, 1, 2, 3]  R [0, 1, 2, 3]

    the matching matrix (entries are disparities, x = no match) is ...

                       RIGHT
                  [ 0, 1, 2, 3 ]
               0    0  x  x  x
        LEFT   1    1  0  x  x
               2    2  1  0  x
               3    3  2  1  0

    this is equivalent to ...

                       RIGHT
                [ inf, inf, inf, 0, 1, 2, 3 ]
             0  [ inf, inf, inf, 0, x, x, x ]
        LEFT 1  [      inf, inf, 1, 0, x, x ]
             2  [           inf, 2, 1, 0, x ]
             3  [                3, 2, 1, 0 ]

    which can be obtained by left padding RIGHT by kernel size - 1 (= max_disparity - 1),
    then unfolding RIGHT with kernel (1, kernel size)

                       RIGHT
                [ inf, inf, inf, 0, 1, 2, 3 ]
             0  [ inf, inf, inf, 0 ]
        LEFT 1  [ inf, inf,   1, 0 ]
             2  [ inf,   2,   1, 0 ]
             3  [   3,   2,   1, 0 ]

    why do it like this? the unfold kernel size caps the maximum disparity, so only an
    N x (C * max_disparity) x (H * W) volume is materialised rather than a full W x W
    matching matrix per scan line, saving memory
    (note: F.unfold copies its patches; Tensor.unfold would return a view)
    """
    N, C, H, W = right.shape
    right_pad = F.pad(right, [max_disparity - 1, 0], value=float('inf'))
    right_unfold = F.unfold(right_pad, kernel_size=(1, max_disparity))
    right_view = right_unfold.view(N, C, max_disparity, H, W)
    return right_view

def conv2d_view(image, kernel_size):
    N, C, H, W = image.shape
    # unfold with no padding and stride 1 yields an (H - K + 1) x (W - K + 1) grid of locations
    if isinstance(kernel_size, tuple):
        H_unf, W_unf = H - kernel_size[0] + 1, W - kernel_size[1] + 1
    else:
        H_unf, W_unf = H - kernel_size + 1, W - kernel_size + 1
    image_unf = F.unfold(image, kernel_size=kernel_size)
    return image_unf.view(N, -1, H_unf, W_unf)
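
# A quick shape check for conv2d_view (an added sketch): with a 3x3 kernel, no padding
# and stride 1, the grid of valid tile centres is (H - 2) x (W - 2).
def test_conv2d_view_shape():
    image = torch.randn(1, 2, 5, 6)
    patches = conv2d_view(image, kernel_size=3)
    assert patches.shape == (1, 2 * 3 * 3, 3, 4)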

def sad(left, right):
    """
    sum of absolute differences along the channel axis
    :param left: N, C, 1, H, W
    :param right: N, C, M, H, W
    :return: N, M, H, W
    """
    return torch.norm(left - right, p=1, dim=1)

def disparity_map(left, right, tile_size, distance_f, maxdisp):
    left = conv2d_view(left, kernel_size=tile_size)
    right = conv2d_view(right, kernel_size=tile_size)
    left_broadcastable = left.unsqueeze(2)
    right_unf = unfold_right(right, max_disparity=maxdisp)
    volume = distance_f(left_broadcastable, right_unf)
    # flip so disparity 0 sits at index 0 before taking the argmin
    volume = volume.flip(1)
    return torch.argmin(volume, dim=1)

def cost_volume(left_features, right_features, max_disparity, distance_f):
    left_broadcastable = left_features.unsqueeze(2)
    right_unf = unfold_right(right_features, max_disparity=max_disparity)
    return distance_f(left_broadcastable, right_unf)
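
# An added sketch exercising cost_volume, which has no test of its own in this gist:
# with identical left and right features the best match is zero disparity, which lives
# at index max_disparity - 1 because unfold_right keeps the disparity axis reversed.
def test_cost_volume():
    features = flat_patches_3x3(torch.arange(3 * 4, dtype=torch.float32).reshape(1, 1, 3, 4))
    volume = cost_volume(features, features, max_disparity=3, distance_f=sad)
    assert volume.shape == (1, 3, 3, 4)
    assert torch.all(torch.argmin(volume, dim=1) == 2)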

def compute_mask_index(width, max_disparity):
    i = torch.triu_indices(row=max_disparity, col=width, offset=width - max_disparity + 1)
    i[0] = max_disparity - 1 - i[0]
    return i

def render_mask(i, width, max_disparity):
    m = torch.zeros(max_disparity, width)
    m[i[0], i[1]] = float('inf')
    return m
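
# An added sketch showing what compute_mask_index / render_mask produce (they are not
# exercised elsewhere in this gist): for width 4 and max_disparity 4, the inf entries
# form a triangle in the lower-right of the D x W plane.
def test_render_mask():
    W, D = 4, 4
    i = compute_mask_index(W, D)
    m = render_mask(i, W, D)
    expected = torch.tensor([
        [0., 0., 0., 0.],
        [0., 0., 0., inf],
        [0., 0., inf, inf],
        [0., inf, inf, inf],
    ])
    assert torch.equal(m, expected)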

def test_compute_disparity_matrix():
    print("")
    W, D = 4, 4
    left = torch.arange(W, dtype=torch.float32).reshape(1, 1, 1, W)
    right = torch.arange(W, dtype=torch.float32).reshape(1, 1, 1, W)
    """
                   RIGHT
          0  [ inf, inf, inf, 0 ]
    LEFT  1  [ inf, inf,   1, 0 ]
          2  [ inf,   2,   1, 0 ]
          3  [   3,   2,   1, 0 ]
    """
    inf = float('inf')
    expected_right_unf = torch.tensor([
        [inf, inf, inf, 0],
        [inf, inf, 0, 1],
        [inf, 0, 1, 2],
        [0, 1, 2, 3]
    ]).view(1, 1, 4, 1, 4)
    right_unf = unfold_right(right, max_disparity=D)
    assert torch.allclose(expected_right_unf, right_unf)
    disparity = left - right_unf
    print(disparity[0, 0])

def test_flat_patches():
    image = torch.arange(3 * 3, dtype=torch.float32).reshape(1, 1, 3, 3)
    patches = flat_patches_3x3(image)
    N, C, H, W = patches.shape
    assert patches.shape == (N, 9, H, W)
    """
    the padded image is
    [ 0, 0, 0, 0, 0 ]
    [ 0, 0, 1, 2, 0 ]
    [ 0, 3, 4, 5, 0 ]
    [ 0, 6, 7, 8, 0 ]
    [ 0, 0, 0, 0, 0 ]
    so [0, :, 0, 0] should be...
    [ 0, 0, 0 ]
    [ 0, 0, 1 ]
    [ 0, 3, 4 ]
    """
    assert patches[0, :, 0, 0].allclose(torch.tensor([0.0, 0, 0, 0, 0, 1, 0, 3, 4]))
    """
    [0, 1, 2]
    [3, 4, 5]
    [6, 7, 8]
    [0, :, 1, 1] == [0, 1, 2, 3, 4, 5, 6, 7, 8]
    """
    assert patches[0, :, 1, 1].allclose(torch.tensor([0.0, 1, 2, 3, 4, 5, 6, 7, 8]))
    H, W = 2, 2
    image = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)
    patches = flat_patches_3x3(image)
    assert patches.shape == (N, 9, H, W)
    H, W = 5, 4
    image = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)
    patches = flat_patches_3x3(image)
    assert patches.shape == (N, 9, H, W)

def test_unfoldright_2d():
    """
    RIGHT
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8]

    RIGHT TOP UNFOLD (H=0)
    [inf, inf, 0],
    [inf, 0, 1],
    [0, 1, 2],

    RIGHT MIDDLE UNFOLD (H=1)
    [inf, inf, 3],
    [inf, 3, 4],
    [3, 4, 5],

    RIGHT BOTTOM UNFOLD (H=2)
    [inf, inf, 6],
    [inf, 6, 7],
    [6, 7, 8],
    """
    # expected_right_unf = torch.tensor([
    # ]).view(1, 1, 3, 1, 3)
    right = torch.arange(3 * 3, dtype=torch.float32).reshape(1, 1, 3, 3)
    print(right)
    right_unf = unfold_right(right, max_disparity=3)
    assert torch.allclose(right_unf[0, :, :, 0, :],
                          torch.tensor([
                              [inf, inf, 0],
                              [inf, 0, 1],
                              [0, 1, 2],
                          ]))
    assert torch.allclose(right_unf[0, :, :, 1, :],
                          torch.tensor([
                              [inf, inf, 3],
                              [inf, 3, 4],
                              [3, 4, 5],
                          ]))
    assert torch.allclose(right_unf[0, :, :, 2, :],
                          torch.tensor([
                              [inf, inf, 6],
                              [inf, 6, 7],
                              [6, 7, 8],
                          ]))

def test_unfoldright_2d_multi_channel():
    """
    same as single channel, but instead of a single value in each channel,
    each channel position contains a flattened 3x3 image patch

    RIGHT
    [0, 1, 2],
    [3, 4, 5],
    [6, 7, 8]

    RIGHT TOP UNFOLD (H=0)
    [inf, inf, 0],
    [inf, 0, 1],
    [0, 1, 2],

    RIGHT MIDDLE UNFOLD (H=1)
    [inf, inf, 3],
    [inf, 3, 4],
    [3, 4, 5],

    RIGHT BOTTOM UNFOLD (H=2)
    [inf, inf, 6],
    [inf, 6, 7],
    [6, 7, 8],
    """
    right = torch.arange(3 * 3 * 1, dtype=torch.float32).reshape(1, 1, 3, 3)
    right_patch = flat_patches_3x3(right)
    right_unf = unfold_right(right_patch, max_disparity=3)
    pad = torch.tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf])
    # the patch vector at each pixel, named after the centre value in RIGHT above
    _0 = right_patch[0, :, 0, 0]
    _1 = right_patch[0, :, 0, 1]
    _2 = right_patch[0, :, 0, 2]
    _3 = right_patch[0, :, 1, 0]
    _4 = right_patch[0, :, 1, 1]
    _5 = right_patch[0, :, 1, 2]
    _6 = right_patch[0, :, 2, 0]
    _7 = right_patch[0, :, 2, 1]
    _8 = right_patch[0, :, 2, 2]
    assert torch.allclose(right_unf[0, :, :, 0, :], torch.stack([
        torch.stack([pad, pad, _0]),
        torch.stack([pad, _0, _1]),
        torch.stack([_0, _1, _2]),
    ]).permute(2, 0, 1))
    assert torch.allclose(right_unf[0, :, :, 1, :], torch.stack([
        torch.stack([pad, pad, _3]),
        torch.stack([pad, _3, _4]),
        torch.stack([_3, _4, _5]),
    ]).permute(2, 0, 1))
    assert torch.allclose(right_unf[0, :, :, 2, :], torch.stack([
        torch.stack([pad, pad, _6]),
        torch.stack([pad, _6, _7]),
        torch.stack([_6, _7, _8]),
    ]).permute(2, 0, 1))

def test_compute_similarity():
    left = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    right = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    left_patch = flat_patches_3x3(left)
    right_patch = flat_patches_3x3(right)
    right_unf = unfold_right(right_patch, max_disparity=3)
    cost = left_patch.unsqueeze(2) - right_unf
    volume = torch.norm(cost, p=1, dim=1, keepdim=True)
    # identical left and right images: the 0-disparity slice (index max_disparity - 1 = 2) is zero
    assert torch.allclose(cost[0, :, 2, 0, 0], torch.zeros(9))
    assert torch.allclose(cost[0, :, 2, 0, 1], torch.zeros(9))
    assert torch.allclose(cost[0, :, 2, 0, 2], torch.zeros(9))
    print(volume.shape)

def test_disparity_map():
    left = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    right = torch.arange(3 * 4 * 1, dtype=torch.float32).reshape(1, 1, 3, 4)
    dmap = disparity_map(left, right, 3, sad, 3)
    assert torch.allclose(dmap, torch.zeros_like(dmap))

def test_read():
    read_image('~/data/')