Created
May 28, 2017 20:44
-
-
Save jcjohnson/876ca05163ad23ab06c2f98cf3bcd6bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def bilinear_sample(feats, X, Y, idx):
    """
    Perform bilinear sampling on the features in feats using the sampling grid
    given by X and Y.

    Inputs:
    - feats: Tensor (or Variable) holding input feature map, of shape (N, C, H, W)
    - X, Y: Tensors (or Variables) holding x and y coordinates of the sampling
      grids; both have shape (B, HH, WW). Elements in X should be in the
      range [0, W - 1] and elements in Y should be in the range [0, H - 1].
      Both endpoints are valid: a coordinate exactly on the last row/column
      yields the feature value at that row/column.
    - idx: LongTensor (or Variable) of shape (B,) mapping elements of the sampling
      grids to elements in feats. In particular idx[i] = j means that X[i], Y[i]
      is a sampling grid for feats[j]. Every entry should lie in [0, N - 1].

    Returns:
    - out: Tensor (or Variable) of shape (B, C, HH, WW) where out[i] is computed
      by sampling from feats[idx[i]] using the sampling grid (X[i], Y[i]).
      If B == 0 an empty tensor of shape (0, C, HH, WW) is returned.
    """
    N, C, H, W = feats.size()
    assert X.size() == Y.size()
    B, HH, WW = X.size()

    # With no sampling grids there is nothing to concatenate below.
    if B == 0:
        return feats.new_zeros((0, C, HH, WW))

    outs, mask_idxs = [], []
    for i in range(N):
        # Figure out which elements of X and Y correspond to element i of feats.
        # We need a bit of special-case logic for Tensors vs Variables.
        mask = idx.eq(i)
        if torch.is_tensor(idx):
            BB = int(mask.sum())
        else:
            assert isinstance(mask, torch.autograd.Variable)
            BB = int(mask.data.sum())
        if BB == 0:
            continue
        if torch.is_tensor(idx):
            mask_idx = mask.nonzero()[:, 0]
        else:
            mask_idx = torch.autograd.Variable(mask.data.nonzero()[:, 0])

        x = X.index_select(0, mask_idx)
        y = Y.index_select(0, mask_idx)
        mask_idxs.append(mask_idx)

        # Integer-valued coordinates of the four neighbouring samples, clamped
        # so that they always address a valid pixel.
        x0 = x.floor().clamp(min=0, max=W - 1)
        x1 = (x0 + 1).clamp(min=0, max=W - 1)
        y0 = y.floor().clamp(min=0, max=H - 1)
        y1 = (y0 + 1).clamp(min=0, max=H - 1)

        # PyTorch's gather only indexes along one dimension at a time, so we
        # collapse the spatial dimensions of feats[i] into a single H * W axis
        # and address it with linear indices W * y + x.
        feats_i_flat = feats[i].contiguous().view(1, C, H * W).expand(BB, C, H * W)
        v1 = _gather_corner(feats_i_flat, y0, x0, W)
        v2 = _gather_corner(feats_i_flat, y1, x0, W)
        v3 = _gather_corner(feats_i_flat, y0, x1, W)
        v4 = _gather_corner(feats_i_flat, y1, x1, W)

        # Fractional offsets from the low corner. The weights are derived from
        # these rather than from the clamped x1 / y1: at x == W - 1 the clamp
        # makes x1 == x0, so (x1 - x) and (x - x0) would both be zero and the
        # output would wrongly collapse to 0 on the last row/column.
        dx = x - x0
        dy = y - y0
        wx0, wx1 = 1 - dx, dx
        wy0, wy1 = 1 - dy, dy

        # Compute the weights for the four samples
        w1 = (wx0 * wy0).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w2 = (wx0 * wy1).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w3 = (wx1 * wy0).view(BB, 1, HH, WW).expand(BB, C, HH, WW)
        w4 = (wx1 * wy1).view(BB, 1, HH, WW).expand(BB, C, HH, WW)

        # Multiply the samples by the weights to give our interpolated results.
        cur_out = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
        outs.append(cur_out)

    # The outputs were produced grouped by feature-map index; sorting the
    # original positions gives the permutation that restores input order.
    mask_idxs = torch.cat(mask_idxs, 0)
    _, sidx = mask_idxs.sort()
    return torch.cat(outs, 0).index_select(0, sidx)


def _gather_corner(feats_flat, y_idx, x_idx, W):
    """
    Gather one corner sample for every grid location.

    Inputs:
    - feats_flat: features of shape (BB, C, H * W)
    - y_idx, x_idx: integer-valued coordinates of shape (BB, HH, WW)
    - W: width of the un-flattened feature map

    Returns:
    - Tensor of shape (BB, C, HH, WW) with feats_flat[b, c, W * y + x].
    """
    BB, C, _ = feats_flat.size()
    _, HH, WW = y_idx.size()
    # Build the linear index in integer arithmetic so large maps are addressed exactly.
    flat_idx = W * y_idx.long() + x_idx.long()
    flat_idx = flat_idx.view(BB, 1, HH * WW).expand(BB, C, HH * WW)
    return feats_flat.gather(2, flat_idx).view(BB, C, HH, WW)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Do you have a usage example for this? It's not clear to me what the idx tensor should be.