Created
December 9, 2019 19:34
-
-
Save kice/ffaae8c68949a3be221d1f1a8b0f7a8d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// SMAGL is 0 or 2
// PixelType is uint8_t
// NOTE(review): neither SMAGL nor PixelType appears in this function;
// presumably they name template parameters of the code this was ported from.
//
// CPU reference for the warp step: for every output pixel, derive a
// sub-pixel displacement from the edge mask and bilinearly sample the source
// image at the displaced position.
//
//   srcp8 / edgep8 / dstp8   source image, edge mask, destination image
//   *_stride                 distance between consecutive rows, in bytes
//   width, height            image size in pixels
//   depth                    warp strength; displacement = mask difference * depth / 256
//   bits_per_sample          output is clamped to [0, (1 << bits_per_sample) - 1]
//
// All reads stay inside the width x height rectangle: the right/bottom
// bilinear neighbours are clamped to the last column/row. Those neighbours
// always carry zero weight (or, for height == 1, are the same row as the
// primary sample), so the clamping does not change any result.
static void warp_c(const uint8_t *srcp8, const uint8_t *edgep8, uint8_t *dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const uint8_t *srcp = srcp8;
    const uint8_t *edgep = edgep8;
    uint8_t *dstp = dstp8;

    // Convert byte strides to element strides (a no-op for uint8_t).
    src_stride /= sizeof(uint8_t);
    edge_stride /= sizeof(uint8_t);
    dst_stride /= sizeof(uint8_t);

    const int pixel_max = (1 << bits_per_sample) - 1;

    const int x_limit_min = 0;
    const int x_limit_max = width - 1;

    const float scale = (depth << 8) / 65536.0f;

    for (int y = 0; y < height; y++) {
        // Vertical displacement limits, relative to the current row. The
        // small epsilon keeps floor(disp_v) + 1 on or above the last row.
        const float y_limit_min = -y;
        const float y_limit_max = (height - y - 1) - 1e-2f;

        for (int x = 0; x < width; x++) {
            // Mask neighbours, replicated at the image borders.
            const int above = (y == 0) ? edgep[x] : edgep[-edge_stride + x];
            const int below = (y == height - 1) ? edgep[x] : edgep[edge_stride + x];
            const int left = (x == 0) ? edgep[x] : edgep[x - 1];
            const int right = (x == width - 1) ? edgep[x] : edgep[x + 1];

            // calculate displacement
            const float disp_h = (left - right) * scale;
            float disp_v = (above - below) * scale;

            // max-then-min instead of std::clamp: std::clamp is undefined
            // when the upper bound is below the lower one (height == 1).
            disp_v = std::min(std::max(disp_v, y_limit_min), y_limit_max);

            // Fractional parts in [0, 1].
            float remainder_h = std::fmod(disp_h, 1.0);
            if (remainder_h < 0.0)
                remainder_h = 1.0 + remainder_h;

            float remainder_v = std::fmod(disp_v, 1.0);
            if (remainder_v < 0.0)
                remainder_v = 1.0 + remainder_v;

            int h = static_cast<int>(std::floor(disp_h)) + x;
            int v = static_cast<int>(std::floor(disp_v));

            // No horizontal interpolation once the sample column has hit
            // (or left) the image edge.
            const bool remainder_needed = (x_limit_max > h) && !(x_limit_min > h);
            if (!remainder_needed)
                remainder_h = 0;

            h = std::min(h, x_limit_max);
            h = std::max(h, x_limit_min);

            // Clamp the bilinear neighbours into the image (see header).
            const int h1 = std::min(h + 1, x_limit_max);
            v = std::max(v, -y);
            const int v1 = std::min(v + 1, height - 1 - y);

            // h and v contain the displacement now.
            const int s00 = srcp[v * src_stride + h];
            const int s01 = srcp[v * src_stride + h1];
            const int s10 = srcp[v1 * src_stride + h];
            const int s11 = srcp[v1 * src_stride + h1];

            // NOTE(review): each stage adds 0.5 without truncating, so the
            // rounded result ends up one above the plain bilinear value.
            // Left as written; confirm against the integer original.
            float s0 = s00 * (1.0 - remainder_h) + s01 * remainder_h + 0.5;
            float s1 = s10 * (1.0 - remainder_h) + s11 * remainder_h + 0.5;
            float s = s0 * (1.0 - remainder_v) + s1 * remainder_v + 0.5;

            int val = std::nearbyint(s);
            dstp[x] = std::min(std::max(val, 0), pixel_max);
        }

        srcp += src_stride;
        edgep += edge_stride;
        dstp += dst_stride;
    }
}
CALL
// One thread per output pixel: 32x32 threads per block, and enough blocks
// (ceil-div) to cover the whole width x height image.
dim3 threadsPerBlock(32, 32);
assert(threadsPerBlock.z == 1 && threadsPerBlock.x * threadsPerBlock.y <= 1024);
dim3 numBlocks(
    (width + threadsPerBlock.x - 1) / threadsPerBlock.x,
    (height + threadsPerBlock.y - 1) / threadsPerBlock.y
);
// The execution configuration is <<<grid, block>>>: the grid (number of
// blocks) comes first, the block (threads per block) second.
warp_cuda<<<numBlocks, threadsPerBlock>>>(...);
code
// CUDA version of the warp step: each thread computes one output pixel by
// deriving a sub-pixel displacement from the edge mask and bilinearly
// sampling the source image at the displaced position.
//
// Launch layout: 2D grid of 2D blocks with one thread per pixel
// (x -> column, y -> row); threads outside width x height exit immediately.
// No shared memory is used.
//
//   srcp8 / edgep8 / dstp8   source image, edge mask, destination image
//                            (declared __restrict__, so they must not alias)
//   *_stride                 distance between consecutive rows, in bytes
//   width, height            image size in pixels
//   depth                    warp strength; displacement = mask difference * depth / 256
//   bits_per_sample          output is clamped to [0, (1 << bits_per_sample) - 1]
//
// All reads stay inside the width x height rectangle: the right/bottom
// bilinear neighbours are clamped to the last column/row. Those neighbours
// always carry zero weight (or, for height == 1, are the same row as the
// primary sample), so the clamping does not change any result.
__global__
static void warp_cuda(const uint8_t * __restrict__ srcp8, const uint8_t * __restrict__ edgep8, uint8_t * __restrict__ dstp8, int src_stride, int edge_stride, int dst_stride, int width, int height, int depth, int bits_per_sample)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;

    // The grid rarely divides the image evenly; drop the surplus threads.
    if (x >= width || y >= height)
        return;

    // Convert byte strides to element strides (a no-op for uint8_t).
    src_stride /= sizeof(uint8_t);
    edge_stride /= sizeof(uint8_t);
    dst_stride /= sizeof(uint8_t);

    // Row pointers for this thread's row.
    const uint8_t *srcp = srcp8 + src_stride * y;
    const uint8_t *edgep = edgep8 + edge_stride * y;
    uint8_t *dstp = dstp8 + dst_stride * y;

    const int pixel_max = (1 << bits_per_sample) - 1;

    const int x_limit_min = 0;
    const int x_limit_max = width - 1;

    const float scale = (depth << 8) / 65536.0f;

    // Vertical displacement limits, relative to the current row. The small
    // epsilon keeps floor(disp_v) + 1 on or above the last row.
    const float y_limit_min = -y;
    const float y_limit_max = (height - y - 1) - 1e-2f;

    // Mask neighbours, replicated at the image borders.
    const int above = (y == 0) ? edgep[x] : edgep[-edge_stride + x];
    const int below = (y == height - 1) ? edgep[x] : edgep[edge_stride + x];
    const int left = (x == 0) ? edgep[x] : edgep[x - 1];
    const int right = (x == width - 1) ? edgep[x] : edgep[x + 1];

    // calculate displacement
    const float disp_h = (left - right) * scale;
    float disp_v = (above - below) * scale;
    disp_v = min(max(disp_v, y_limit_min), y_limit_max);

    // Fractional parts in [0, 1]: subtract the nearest integer, then wrap
    // negative results up by one.
    float remainder_h = disp_h - nearbyintf(disp_h);
    remainder_h = remainder_h < 0.0f ? 1.0f + remainder_h : remainder_h;

    float remainder_v = disp_v - nearbyintf(disp_v);
    remainder_v = remainder_v < 0.0f ? 1.0f + remainder_v : remainder_v;

    // Round towards negative infinity (floor).
    int h = __float2int_rd(disp_h) + x;
    int v = __float2int_rd(disp_v);

    // No horizontal interpolation once the sample column has hit (or left)
    // the image edge.
    remainder_h = (x_limit_max > h) && !(x_limit_min > h) ? remainder_h : 0.0f;

    h = min(h, x_limit_max);
    h = max(h, x_limit_min);

    // Clamp the bilinear neighbours into the image (see header).
    const int h1 = min(h + 1, x_limit_max);
    v = max(v, -y);
    const int v1 = min(v + 1, height - 1 - y);

    // h and v contain the displacement now.
    const int s00 = srcp[v * src_stride + h];
    const int s01 = srcp[v * src_stride + h1];
    const int s10 = srcp[v1 * src_stride + h];
    const int s11 = srcp[v1 * src_stride + h1];

    // NOTE(review): each stage adds 0.5 without truncating, so the rounded
    // result ends up one above the plain bilinear value. Left as written to
    // stay in step with the CPU float reference; confirm against the
    // integer original.
    const float s0 = s00 * (1.0f - remainder_h) + s01 * remainder_h + 0.5f;
    const float s1 = s10 * (1.0f - remainder_h) + s11 * remainder_h + 0.5f;
    const float s = s0 * (1.0f - remainder_v) + s1 * remainder_v + 0.5f;

    // Round to nearest.
    const int val = __float2int_rn(s);
    dstp[x] = min(max(val, 0), pixel_max);
}
import torch
def rgb_warp(src, mask):
    """Warp ``src`` by a sub-pixel displacement derived from ``mask``.

    Vectorised version of the per-pixel loop (whose loop headers survive as
    the ``for n / c / y / x`` comments of the draft): for every pixel the
    horizontal/vertical differences of ``mask`` give a displacement, and
    ``src`` is bilinearly sampled at the displaced position.

    Args:
        src: tensor of shape (N, C, H, W). Presumably normalised to
            [0, 1], since the result is clamped to that range — confirm
            with callers.
        mask: edge mask whose last two dims are (H, W) and whose first two
            dims are broadcastable to (N, C). Assumed to be a floating
            tensor on the same device and of the same dtype as ``src``.

    Returns:
        Tensor of shape (N, C, H, W), clamped to [0, 1].

    Requires H >= 2 and W >= 2: the sample window is two pixels wide/tall
    and is clamped so that it always fits inside the image.
    """
    N, C, H, W = src.shape

    scale = (8 * 256) / 65536
    pixel_max = 1.0

    x_limit_min = 0
    x_limit_max = W - 2

    # Neighbouring mask values, replicated at the borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale

    # Pixel coordinates, created on the data's device and in its dtype so the
    # comparisons and arithmetic below need no implicit conversion.
    y = torch.arange(H, device=_v.device, dtype=_v.dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=_h.device, dtype=_h.dtype).view(1, 1, 1, -1)

    # Clamp the vertical displacement so rows v and v + 1 stay in the image.
    y_limit_min = -y
    y_limit_max = (H - 2) - y
    _v = torch.min(torch.max(_v, y_limit_min), y_limit_max)

    remainder_h = _h % 1
    remainder_v = _v % 1

    # Absolute integer sample position (column h, row v).
    h = torch.floor(_h) + x
    v = torch.floor(_v + y)

    # No horizontal interpolation once the column had to be clamped.
    in_range = (x_limit_max > h) & (x_limit_min <= h)
    remainder_h = torch.where(in_range, remainder_h, torch.zeros_like(remainder_h))

    h = torch.clamp(h, x_limit_min, x_limit_max).long()
    v = v.long()

    flat = torch.flatten(src, start_dim=2)  # shape: (N, C, H * W)

    def sample(rows, cols):
        # gather (rather than index_select) so every (n, c) plane is read
        # with its own index map.
        index = (rows * W + cols).flatten(start_dim=2).expand(N, C, -1)
        return torch.gather(flat, dim=2, index=index).view(N, C, H, W)

    s00 = sample(v, h)          # src[n, c, v, h]
    s01 = sample(v, h + 1)      # src[n, c, v, h + 1]
    s10 = sample(v + 1, h)      # src[n, c, v + 1, h]
    s11 = sample(v + 1, h + 1)  # src[n, c, v + 1, h + 1]

    # Bilinear blend. The "+ 0.5" terms of the integer code are rounding
    # offsets for 8-bit data; on [0, 1] data they would push every result to
    # the upper clamp, so they are omitted.
    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    s = torch.clamp(s, 0.0, pixel_max)

    dst = s.view(N, C, H, W)
    return dst
def awarp(src, mask):
    """Warp ``src`` by a sub-pixel displacement derived from ``mask``.

    The horizontal/vertical differences of ``mask`` give a per-pixel
    displacement; ``src`` is bilinearly sampled at the displaced position.

    Args:
        src: tensor of shape (N, C, H, W). Presumably normalised to
            [0, 1], since the result is clamped to [0, pixel_max] with
            pixel_max = 1.0 — confirm with callers.
        mask: edge mask whose last two dims are (H, W) and whose first two
            dims are broadcastable to (N, C). Assumed to be a floating
            tensor on the same device and of the same dtype as ``src``.

    Returns:
        Tensor of shape (N, C, H, W), clamped to [0, 1].

    Requires H >= 2 and W >= 2: the sample window is two pixels wide/tall
    and is clamped so that it always fits inside the image.
    """
    N, C, H, W = src.shape

    scale = (8 * 256) / 65536
    pixel_max = 1.0

    x_limit_min = 0
    x_limit_max = W - 2

    # Neighbouring mask values, replicated at the borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale

    # Pixel coordinates, created on the data's device and in its dtype.
    y = torch.arange(H, device=_v.device, dtype=_v.dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=_h.device, dtype=_h.dtype).view(1, 1, 1, -1)

    # Clamp the vertical displacement so rows v and v + 1 stay in the image.
    y_limit_min = -y
    y_limit_max = (H - 2) - y
    _v = torch.min(torch.max(_v, y_limit_min), y_limit_max)

    remainder_h = _h % 1
    remainder_v = _v % 1

    # Absolute integer sample position (column h, row v).
    h = torch.floor(_h) + x
    v = torch.floor(_v + y)

    # Where the column has to be clamped the horizontal weight is forced to
    # 1e-4. NOTE(review): the reason for 1e-4 rather than 0 is not evident
    # from this file; kept as written.
    clamped = (h >= x_limit_max) | (h < x_limit_min)
    remainder_h = remainder_h.masked_fill(clamped, 1e-4)
    h = torch.clamp(h, x_limit_min, x_limit_max)

    h = h.long()
    v = v.long()

    flat = torch.flatten(src, start_dim=2)  # shape: (N, C, H * W)

    def sample(rows, cols):
        # gather (rather than index_select) so every (n, c) plane is read
        # with its own index map and the result keeps the (N, C, H, W) shape.
        index = (rows * W + cols).flatten(start_dim=2).expand(N, C, -1)
        return torch.gather(flat, dim=2, index=index).view(N, C, H, W)

    s00 = sample(v, h)          # src[n, c, v, h]
    s01 = sample(v, h + 1)      # src[n, c, v, h + 1]
    s10 = sample(v + 1, h)      # src[n, c, v + 1, h]
    s11 = sample(v + 1, h + 1)  # src[n, c, v + 1, h + 1]

    # Bilinear blend. The "+ 0.5" terms of the integer code are rounding
    # offsets for 8-bit data; on [0, 1] data they would push every result to
    # the upper clamp, so they are omitted.
    s0 = s00 * (1 - remainder_h) + s01 * remainder_h
    s1 = s10 * (1 - remainder_h) + s11 * remainder_h
    s = s0 * (1 - remainder_v) + s1 * remainder_v

    s = torch.clamp(s, 0.0, pixel_max)

    dst = s.view(N, C, H, W)
    return dst
Deprecated
https://gitlab.com/sr11/gist/blob/master/awarp/awarp.py
def awarp(src, mask):
    """Warp ``src`` by a sub-pixel displacement derived from ``mask``.

    The horizontal/vertical differences of ``mask`` give a per-pixel
    displacement; ``src`` is bilinearly sampled (via ``torch.gather`` and
    ``torch.lerp``) at the displaced position.

    Args:
        src: tensor of shape (N, C, H, W). Presumably normalised to
            [0, 1], since the result is clamped to [0, pixel_max] with
            pixel_max = 1.0 — confirm with callers.
        mask: edge mask whose last two dims are (H, W) and whose first two
            dims are broadcastable to (N, C). Assumed to be a floating
            tensor on the same device and of the same dtype as ``src``.

    Returns:
        Tensor of shape (N, C, H, W), clamped to [0, 1].

    Requires H >= 2 and W >= 2: the sample window is two pixels wide/tall
    and is clamped so that it always fits inside the image.
    """
    N, C, H, W = src.shape

    scale = (8 * 256) / 65536
    pixel_max = 1.0

    x_limit_min = 0
    x_limit_max = W - 2

    # Neighbouring mask values, replicated at the borders.
    above = torch.cat([mask[:, :, :1, :], mask[:, :, :-1, :]], dim=2)
    below = torch.cat([mask[:, :, 1:, :], mask[:, :, -1:, :]], dim=2)
    left = torch.cat([mask[:, :, :, :1], mask[:, :, :, :-1]], dim=3)
    right = torch.cat([mask[:, :, :, 1:], mask[:, :, :, -1:]], dim=3)

    _h = (left - right) * scale
    _v = (above - below) * scale

    # Pixel coordinates, created on the data's device and in its dtype.
    y = torch.arange(H, device=_v.device, dtype=_v.dtype).view(1, 1, -1, 1)
    x = torch.arange(W, device=_h.device, dtype=_h.dtype).view(1, 1, 1, -1)

    # Clamp the vertical displacement so rows v and v + 1 stay in the image.
    y_limit_min = -y
    y_limit_max = (H - 2) - y
    _v = torch.min(torch.max(_v, y_limit_min), y_limit_max)

    remainder_h = _h % 1
    remainder_v = _v % 1

    # Absolute integer sample position (column h, row v).
    h = torch.floor(_h) + x
    v = torch.floor(_v + y)

    # Where the column has to be clamped the horizontal weight is forced to
    # 1e-4. NOTE(review): the reason for 1e-4 rather than 0 is not evident
    # from this file; kept as written.
    clamped = (h >= x_limit_max) | (h < x_limit_min)
    remainder_h = remainder_h.masked_fill(clamped, 1e-4)
    h = torch.clamp(h, x_limit_min, x_limit_max)

    h = h.long()
    v = v.long()

    src = torch.flatten(src, start_dim=2)  # shape: (N, C, H * W)

    def sample(rows, cols):
        # please also try replacing ".expand_as()" with ".repeat()"
        index = (rows * W + cols).flatten(start_dim=2).expand_as(src)  # shape: (N, C, H * W)
        return torch.gather(src, dim=2, index=index).reshape(-1, C, H, W)

    s00 = sample(v, h)          # src[n, c, v, h]
    s01 = sample(v, h + 1)      # src[n, c, v, h + 1]
    s10 = sample(v + 1, h)      # src[n, c, v + 1, h]
    s11 = sample(v + 1, h + 1)  # src[n, c, v + 1, h + 1]

    # Bilinear blend; the integer code's "+ 0.5" rounding offsets are
    # deliberately left out here.
    s0 = torch.lerp(s00, s01, remainder_h)
    s1 = torch.lerp(s10, s11, remainder_h)
    s = torch.lerp(s0, s1, remainder_v)

    s = torch.clamp(s, 0.0, pixel_max)

    dst = s.view(N, C, H, W)
    return dst
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
From: https://github.com/dubhater/vapoursynth-awarpsharp2/blob/886d4b73ff1406e3be171bc141ff02c68addd774/src/awarpsharp2.cpp#L445