Skip to content

Instantly share code, notes, and snippets.

@kice
Created December 9, 2019 19:34
Show Gist options
  • Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
// Bilinear "warp" of one 8-bit plane along the gradient of an edge plane.
//
// Specialisation of a template whose parameters were SMAGL (0 or 2) and
// PixelType (uint8_t). This body is fixed to uint8_t samples and applies no
// SMAGL sub-pixel scaling, i.e. it behaves like the SMAGL == 0 case.
//
// For each pixel (x, y):
//   * the edge-plane differences (left - right) and (above - below) are
//     taken, replicating the centre sample at the plane borders;
//   * both are multiplied by depth / 256 to get a fractional displacement;
//   * the source plane is sampled at the displaced position with bilinear
//     interpolation, rounded, and clamped to [0, 2^bits_per_sample - 1].
//
// Displacements are limited so that only samples inside the
// width x height plane are read (including width == 1 / height == 1).
//
// Parameters:
//   srcp8, edgep8, dstp8   top-left sample of the source, edge and output planes
//   src/edge/dst_stride    row pitch in bytes (== samples for uint8_t)
//   width, height          plane size in samples
//   depth                  warp strength (displacement = gradient * depth / 256)
//   bits_per_sample        output bit depth, used for the upper clamp
static void warp_c(const uint8_t *srcp8, const uint8_t *edgep8, uint8_t *dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const uint8_t *srcp = srcp8;
    const uint8_t *edgep = edgep8;
    uint8_t *dstp = dstp8;
    // Strides are byte pitches; for uint8_t samples they are already element
    // counts, so no conversion is needed.

    const int pixel_max = (1 << bits_per_sample) - 1;
    const int x_limit_min = 0;
    const int x_limit_max = width - 1;
    const float scale = (depth << 8) / 65536.0f; // == depth / 256

    for (int y = 0; y < height; y++) {
        // Vertical displacement limits: keep rows y + v and y + v + 1 inside
        // the plane. The 1e-2 margin makes floor() stop one row above the
        // last row. std::max keeps lo <= hi when height == 1 (std::clamp with
        // hi < lo is undefined behaviour).
        const float y_limit_min = static_cast<float>(-y);
        const float y_limit_max = std::max(y_limit_min, (height - y - 1) - 1e-2f);
        // Largest row offset (relative to y) that is still inside the plane.
        const int v_last = height - y - 1;

        for (int x = 0; x < width; x++) {
            // Edge-plane neighbours; outside the plane the centre sample is used.
            const int above = (y == 0) ? edgep[x] : edgep[x - edge_stride];
            const int below = (y == height - 1) ? edgep[x] : edgep[x + edge_stride];
            const int left = (x == 0) ? edgep[x] : edgep[x - 1];
            const int right = (x == width - 1) ? edgep[x] : edgep[x + 1];

            const float disp_h = (left - right) * scale;
            const float disp_v = std::clamp((above - below) * scale, y_limit_min, y_limit_max);

            // Integer part and fraction in [0, 1) of each displacement.
            const float floor_h = std::floor(disp_h);
            const float floor_v = std::floor(disp_v);
            float remainder_h = disp_h - floor_h;
            const float remainder_v = disp_v - floor_v;

            int h = static_cast<int>(floor_h) + x;
            const int v = static_cast<int>(floor_v);

            // Interpolate horizontally only when both taps h and h + 1 are in
            // the row; otherwise take the clamped border column alone.
            if (h < x_limit_min || h >= x_limit_max)
                remainder_h = 0.0f;
            h = std::clamp(h, x_limit_min, x_limit_max);

            // Second taps, clamped to the plane. The clamp only changes them
            // where their interpolation weight is already zero, so it merely
            // prevents reading past the end of the row / plane.
            const int h1 = std::min(h + 1, x_limit_max);
            const int v1 = std::min(v + 1, v_last);

            const uint8_t *row0 = srcp + v * src_stride;
            const uint8_t *row1 = srcp + v1 * src_stride;
            const int s00 = row0[h];
            const int s01 = row0[h1];
            const int s10 = row1[h];
            const int s11 = row1[h1];

            // Round (floor(x + 0.5)) after each interpolation stage. Adding
            // 0.5 per stage without flooring in between would accumulate to
            // +1.0 and shift every output sample up by one.
            const float s0 = std::floor(s00 * (1.0f - remainder_h) + s01 * remainder_h + 0.5f);
            const float s1 = std::floor(s10 * (1.0f - remainder_h) + s11 * remainder_h + 0.5f);
            const int val = static_cast<int>(std::floor(s0 * (1.0f - remainder_v) + s1 * remainder_v + 0.5f));
            dstp[x] = static_cast<uint8_t>(std::clamp(val, 0, pixel_max));
        }
        srcp += src_stride;
        edgep += edge_stride;
        dstp += dst_stride;
    }
}
@kice
Copy link
Author

kice commented Dec 10, 2019

def awarp(src, mask, depth=8, pixel_max=1.0):
    """Warp ``src`` along the gradient of ``mask`` (PyTorch version of ``warp_c``).

    For every pixel the horizontal (left - right) and vertical (above - below)
    differences of ``mask`` -- with replicated borders -- are multiplied by
    ``depth / 256`` and used as a fractional displacement. ``src`` is sampled
    at the displaced position with bilinear interpolation. Displacements are
    limited so that only pixels inside the image are sampled.

    Args:
        src: tensor of shape (N, C, H, W).
        mask: edge map of shape (N or 1, C or 1, H, W). Displacements are in
            mask units; presumably it uses the same value range as warp_c's
            edge plane (0..255) -- TODO confirm against callers.
        depth: warp strength; displacement = gradient * depth / 256.
            Defaults to 8, the value that used to be hard-coded.
        pixel_max: upper clamp of the output (lower clamp is 0).

    Returns:
        Tensor of shape (N, C, H, W), clamped to [0, pixel_max].
    """
    N, C, H, W = src.shape

    # depth / 256, identical to warp_c's (depth << 8) / 65536.
    scale = (depth * 256) / 65536

    # Float math: an integer mask would wrap on ``left - right``.
    mask = mask.float()
    # Coordinates are created on the mask's device so CUDA inputs work.
    y = torch.arange(H, device=mask.device, dtype=mask.dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=mask.device, dtype=mask.dtype).view(1, 1, 1, -1)

    x_limit_min = 0
    x_limit_max = W - 1

    # Keep rows v and v + 1 inside the image; the 1e-2 margin makes floor()
    # stop one row above the last row (as in warp_c). For H == 1 the upper
    # limit would fall below the lower one, so it is raised to meet it.
    y_limit_min = -y
    y_limit_max = (H - 1 - 1e-2) - y
    y_limit_max = torch.where(y_limit_max < y_limit_min, y_limit_min, y_limit_max)

    # Neighbours with replicated borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale

    _v = torch.where(_v < y_limit_min, y_limit_min, _v)
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)

    # Fractional parts in [0, 1) (Python-style modulo, also for negatives).
    remainder_h = _h % 1
    remainder_v = _v % 1

    h = torch.floor(_h) + x
    v = torch.floor(_v) + y

    # Interpolate horizontally only while both taps h and h + 1 lie in the
    # row; otherwise use the clamped border column alone.
    outside = (h < x_limit_min) | (h >= x_limit_max)
    remainder_h = remainder_h.masked_fill(outside, 0.0)

    h = torch.clamp(h, x_limit_min, x_limit_max).long()
    v = v.long()
    # Second taps clamped to the image; the clamp only bites where their
    # interpolation weight is already zero.
    h1 = torch.clamp(h + 1, max=x_limit_max)
    v1 = torch.clamp(v + 1, max=H - 1)

    src_flat = torch.flatten(src, start_dim=2)  # (N, C, H * W)

    def sample(rows, cols):
        # src[n, c, rows, cols]; a batch/channel dim of size 1 in the index
        # is broadcast over src.
        index = (rows * W + cols).flatten(start_dim=2).expand(N, C, H * W)
        return torch.gather(src_flat, dim=2, index=index).view(N, C, H, W)

    s00 = sample(v, h)
    s01 = sample(v, h1)
    s10 = sample(v1, h)
    s11 = sample(v1, h1)

    # Plain bilinear blend. No +0.5 rounding terms: the output stays float,
    # and adding them would shift every pixel by +1.0.
    s0 = s00 * (1 - remainder_h) + s01 * remainder_h
    s1 = s10 * (1 - remainder_h) + s11 * remainder_h
    s = s0 * (1 - remainder_v) + s1 * remainder_v

    return torch.clamp(s, 0.0, pixel_max)

@WolframRhodium
Copy link

WolframRhodium commented Dec 11, 2019

Deprecated

https://gitlab.com/sr11/gist/blob/master/awarp/awarp.py

def awarp(src, mask, depth=8, pixel_max=1.0):
    """Warp ``src`` along the gradient of ``mask`` (gather + lerp version of ``warp_c``).

    For every pixel the horizontal (left - right) and vertical (above - below)
    differences of ``mask`` -- with replicated borders -- are multiplied by
    ``depth / 256`` and used as a fractional displacement. ``src`` is sampled
    at the displaced position with bilinear interpolation (``torch.lerp``).
    Displacements are limited so that only pixels inside the image are read.

    Args:
        src: floating-point tensor of shape (N, C, H, W) (``torch.lerp``
            needs a floating dtype).
        mask: edge map of shape (N or 1, C or 1, H, W). Displacements are in
            mask units; presumably it uses the same value range as warp_c's
            edge plane (0..255) -- TODO confirm against callers.
        depth: warp strength; displacement = gradient * depth / 256.
            Defaults to 8, the value that used to be hard-coded.
        pixel_max: upper clamp of the output (lower clamp is 0).

    Returns:
        Tensor of shape (N, C, H, W), clamped to [0, pixel_max].
    """
    N, C, H, W = src.shape
    scale = (depth * 256) / 65536  # == depth / 256, as in warp_c

    # Float math: an integer mask would wrap on ``left - right``.
    mask = mask.float()
    # Coordinates are created on the mask's device so CUDA inputs work.
    ys = torch.arange(H, device=mask.device, dtype=mask.dtype).view(1, 1, -1, 1)
    xs = torch.arange(W, device=mask.device, dtype=mask.dtype).view(1, 1, 1, -1)

    x_limit_min = 0
    x_limit_max = W - 1

    # Keep rows v and v + 1 inside the image; the 1e-2 margin makes floor()
    # stop one row above the last row (as in warp_c). For H == 1 the upper
    # limit would fall below the lower one, so it is raised to meet it.
    y_limit_min = -ys
    y_limit_max = (H - 1 - 1e-2) - ys
    y_limit_max = torch.where(y_limit_max < y_limit_min, y_limit_min, y_limit_max)

    # Neighbours with replicated borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale

    _v = torch.where(_v < y_limit_min, y_limit_min, _v)
    _v = torch.where(_v > y_limit_max, y_limit_max, _v)

    # Fractional parts in [0, 1) (Python-style modulo, also for negatives).
    remainder_h = _h % 1
    remainder_v = _v % 1

    h = torch.floor(_h) + xs
    v = torch.floor(_v) + ys

    # Interpolate horizontally only while both taps h and h + 1 lie in the
    # row; otherwise use the clamped border column alone.
    remainder_h = remainder_h.masked_fill((h < x_limit_min) | (h >= x_limit_max), 0.0)

    h = torch.clamp(h, x_limit_min, x_limit_max).long()
    v = v.long()
    # Second taps clamped to the image; the clamp only bites where their
    # interpolation weight is already zero.
    h1 = torch.clamp(h + 1, max=x_limit_max)
    v1 = torch.clamp(v + 1, max=H - 1)

    # torch.lerp needs the weights in the same dtype as the samples.
    remainder_h = remainder_h.to(src.dtype)
    remainder_v = remainder_v.to(src.dtype)

    flat = torch.flatten(src, start_dim=2)  # (N, C, H * W)

    def gather_at(rows, cols):
        # src[n, c, rows, cols]; .expand() broadcasts a batch/channel dim of
        # size 1 without copying (.repeat() would materialise copies).
        index = (rows * W + cols).flatten(start_dim=2).expand(N, C, H * W)
        return torch.gather(flat, dim=2, index=index).view(N, C, H, W)

    # Plain bilinear blend; no +0.5 rounding terms since the output stays float.
    s0 = torch.lerp(gather_at(v, h), gather_at(v, h1), remainder_h)
    s1 = torch.lerp(gather_at(v1, h), gather_at(v1, h1), remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    return torch.clamp(s, 0.0, pixel_max)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment