Created December 9, 2019 19:34
Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Floating-point port of an integer warp kernel (presumably aWarpSharp2's;
// NOTE(review): confirm against the reference implementation).
// Original notes: SMAGL is 0 or 2; PixelType is uint8_t. This code applies no
// SMAGL scaling: source column h maps 1:1 onto output column h.
//
// For every pixel a displacement is derived from the edge plane. Neighbours
// are replicated at the plane borders:
//   dh = (edge[y][x - 1] - edge[y][x + 1]) * depth / 256
//   dv = (edge[y - 1][x] - edge[y + 1][x]) * depth / 256
// dv is limited so both sampled rows lie inside the plane. src is sampled
// bilinearly at (x + dh, y + dv), and the result is clamped to
// [0, 2^bits_per_sample - 1] and stored in dst.
//
// Strides are in elements (== bytes for uint8_t). With bits_per_sample > 8 the
// clamp lets values above 255 through and the uint8_t store truncates them.
static void warp_c(const uint8_t *srcp8, const uint8_t *edgep8, uint8_t *dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const uint8_t *srcp = srcp8;
    const uint8_t *edgep = edgep8;
    uint8_t *dstp = dstp8;

    const int pixel_max = (1 << bits_per_sample) - 1;
    const int x_limit_min = 0;
    const int x_limit_max = width - 1;
    const float scale = (depth << 8) / 65536.0f;

    // Fractional part in [0, 1), also for negative values.
    auto fraction = [](float value) -> float {
        const float r = static_cast<float>(std::fmod(static_cast<double>(value), 1.0));
        return r < 0.0f ? static_cast<float>(1.0 + r) : r;
    };

    for (int y = 0; y < height; y++) {
        // Vertical offsets (relative to row y) that keep rows y+v and y+v+1
        // inside the plane. For height == 1 the bounds cross (min > max).
        const float y_limit_min = static_cast<float>(-y);
        const float y_limit_max = (height - y - 1) - 1e-2f;
        const int last_row = height - 1 - y;  // offset of the bottom row

        for (int x = 0; x < width; x++) {
            // Edge neighbours, replicated at the borders.
            const int above = (y == 0) ? edgep[x] : edgep[x - edge_stride];
            const int below = (y == height - 1) ? edgep[x] : edgep[x + edge_stride];
            const int left = (x == 0) ? edgep[x] : edgep[x - 1];
            const int right = (x == width - 1) ? edgep[x] : edgep[x + 1];

            const float dh = (left - right) * scale;
            // min-then-max instead of std::clamp: std::clamp requires lo <= hi,
            // which does not hold when height == 1. For lo <= hi the result
            // is the same as std::clamp.
            const float dv = std::max(std::min((above - below) * scale, y_limit_max), y_limit_min);

            int h = static_cast<int>(std::floor(dh)) + x;
            const int v = static_cast<int>(std::floor(dv));

            // No horizontal blend once the sample sits on or outside the border.
            const float rem_h = (h < x_limit_max && h >= x_limit_min) ? fraction(dh) : 0.0f;
            const float rem_v = fraction(dv);
            h = std::min(std::max(h, x_limit_min), x_limit_max);

            // Clamp the "+1" neighbours so every read stays inside the plane.
            // Clamping only kicks in where the matching weight is zero
            // (h == x_limit_max, or height == 1), so the output is unaffected.
            const int h1 = std::min(h + 1, x_limit_max);
            const int v1 = std::min(v + 1, last_row);

            const uint8_t *row0 = srcp + v * src_stride;
            const uint8_t *row1 = srcp + v1 * src_stride;
            const int s00 = row0[h];
            const int s01 = row0[h1];
            const int s10 = row1[h];
            const int s11 = row1[h1];

            // Round half up (floor(x + 0.5)) after each interpolation stage.
            // Adding 0.5 at every stage without flooring in between, then
            // rounding again, would bias every output pixel by +1.
            const double s0 = std::floor(s00 * (1.0 - rem_h) + s01 * rem_h + 0.5);
            const double s1 = std::floor(s10 * (1.0 - rem_h) + s11 * rem_h + 0.5);
            const int val = static_cast<int>(std::floor(s0 * (1.0 - rem_v) + s1 * rem_v + 0.5));
            dstp[x] = static_cast<uint8_t>(std::min(std::max(val, 0), pixel_max));
        }
        srcp += src_stride;
        edgep += edge_stride;
        dstp += dst_stride;
    }
}
WolframRhodium commented on Dec 10, 2019:
def awarp(src, mask, depth=8):
    """Warp ``src`` along the gradient of ``mask`` with bilinear sampling.

    Tensor port of the ``warp_c`` kernel in this gist, using an explicit
    bilinear blend.

    Args:
        src: (N, C, H, W) tensor. The output is clamped to [0, 1].
        mask: 4-D edge tensor with the same H and W as ``src``. Its batch and
            channel dims must broadcast against ``src`` (e.g. (N, 1, H, W)).
            NOTE(review): confirm the value range callers feed in. With
            depth=8, a neighbour difference of 32 moves the sample by one pixel.
        depth: warp strength. A neighbour difference ``d`` moves the sample by
            ``d * depth / 256`` pixels. Defaults to 8, the former constant.

    Returns:
        (N, C, H, W) tensor.
    """
    N, C, H, W = src.shape
    scale = (depth * 256) / 65536
    pixel_max = 1.0

    y = torch.arange(H, device=mask.device).view(1, 1, -1, 1).float()
    x = torch.arange(W, device=mask.device).view(1, 1, 1, -1).float()

    # Same limits as warp_c: columns in [0, W - 1]. The vertical offset keeps
    # row v + 1 inside the image, with the same 1e-2 slack as the C code.
    x_limit_min = 0
    x_limit_max = W - 1
    y_limit_min = -y
    y_limit_max = (H - 1 - y) - 1e-2

    # Neighbours of every mask pixel, replicated at the borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale
    # Clamp to [y_limit_min, y_limit_max]. The lower bound is applied last so
    # it wins when the bounds cross (H == 1).
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)
    _v = torch.where(_v < y_limit_min, y_limit_min, _v)

    remainder_h = _h % 1
    remainder_v = _v % 1
    h = torch.floor(_h) + x
    v = torch.floor(_v) + y

    # No horizontal blend once the sample sits on or outside the border.
    outside = (h >= x_limit_max) | (h < x_limit_min)
    remainder_h = torch.where(outside, torch.zeros_like(remainder_h), remainder_h)

    h = torch.clamp(h, x_limit_min, x_limit_max).long()
    v = v.long()
    # "+1" neighbours are clamped so every lookup stays in bounds. The clamp
    # only applies where the corresponding weight is zero.
    h1 = torch.clamp(h + 1, max=x_limit_max)
    v1 = torch.clamp(v + 1, max=H - 1)

    flat = torch.flatten(src, start_dim=2)  # shape: (N, C, H * W)

    def _sample(rows, cols):
        # src[n, c, rows, cols] for every output pixel, shape (N, C, H, W).
        index = (rows * W + cols).flatten(start_dim=2).expand_as(flat)
        return torch.gather(flat, dim=2, index=index).reshape(-1, C, H, W)

    s00 = _sample(v, h)
    s01 = _sample(v, h1)
    s10 = _sample(v1, h)
    s11 = _sample(v1, h1)

    # Plain bilinear blend of float samples. Integer-style +0.5 rounding terms
    # would shift the result upward and, after the clamp, saturate it.
    s0 = s00 * (1 - remainder_h) + s01 * remainder_h
    s1 = s10 * (1 - remainder_h) + s11 * remainder_h
    s = s0 * (1 - remainder_v) + s1 * remainder_v
    s = torch.clamp(s, 0.0, pixel_max)
    return s.view(N, C, H, W)
Deprecated. See https://gitlab.com/sr11/gist/blob/master/awarp/awarp.py
def awarp(src, mask, depth=8):
    """Warp ``src`` along the gradient of ``mask`` with bilinear sampling.

    Tensor port of the ``warp_c`` kernel in this gist, built on
    ``torch.gather`` and ``torch.lerp``.

    Args:
        src: (N, C, H, W) tensor. The output is clamped to [0, 1].
        mask: 4-D edge tensor with the same H and W as ``src``. Its batch and
            channel dims must broadcast against ``src`` (e.g. (N, 1, H, W)).
            NOTE(review): confirm the value range callers feed in. With
            depth=8, a neighbour difference of 32 moves the sample by one pixel.
        depth: warp strength. A neighbour difference ``d`` moves the sample by
            ``d * depth / 256`` pixels. Defaults to 8, the former constant.

    Returns:
        (N, C, H, W) tensor.
    """
    N, C, H, W = src.shape
    scale = (depth * 256) / 65536
    pixel_max = 1.0

    y = torch.arange(H, device=mask.device).view(1, 1, -1, 1).float()
    x = torch.arange(W, device=mask.device).view(1, 1, 1, -1).float()

    # Same limits as warp_c: columns in [0, W - 1]. The vertical offset keeps
    # row v + 1 inside the image, with the same 1e-2 slack as the C code.
    x_limit_min = 0
    x_limit_max = W - 1
    y_limit_min = -y
    y_limit_max = (H - 1 - y) - 1e-2

    # Neighbours of every mask pixel, replicated at the borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale
    # Clamp to [y_limit_min, y_limit_max]. The lower bound is applied last so
    # it wins when the bounds cross (H == 1).
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)
    _v = torch.where(_v < y_limit_min, y_limit_min, _v)

    remainder_h = _h % 1
    remainder_v = _v % 1
    h = torch.floor(_h) + x
    v = torch.floor(_v) + y

    # No horizontal blend once the sample sits on or outside the border.
    outside = (h >= x_limit_max) | (h < x_limit_min)
    remainder_h = torch.where(outside, torch.zeros_like(remainder_h), remainder_h)

    h = torch.clamp(h, x_limit_min, x_limit_max).long()
    v = v.long()
    # "+1" neighbours are clamped so every lookup stays in bounds. The clamp
    # only applies where the corresponding weight is zero.
    h1 = torch.clamp(h + 1, max=x_limit_max)
    v1 = torch.clamp(v + 1, max=H - 1)

    # s00 = src[n, c, v, h],      s01 = src[n, c, v, h + 1]
    # s10 = src[n, c, v + 1, h],  s11 = src[n, c, v + 1, h + 1]
    flat = torch.flatten(src, start_dim=2)  # shape: (N, C, H * W)

    def _sample(rows, cols):
        # src[n, c, rows, cols] for every output pixel, shape (N, C, H, W).
        # expand_as broadcasts a shared (e.g. single-channel) index.
        index = (rows * W + cols).flatten(start_dim=2).expand_as(flat)
        return torch.gather(flat, dim=2, index=index).reshape(-1, C, H, W)

    s00 = _sample(v, h)
    s01 = _sample(v, h1)
    s10 = _sample(v1, h)
    s11 = _sample(v1, h1)

    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)
    s = torch.clamp(s, 0.0, pixel_max)
    return s.view(N, C, H, W)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.