import shutil

# Remove cache if it exists
try:
    shutil.rmtree("__pycache__")
except FileNotFoundError:
    pass

try:
    with open("/proc/cpuinfo") as f:
        print("CPU info:")
        print("-" * 60)
        print(f.read().strip().split("\n\n")[-1])
        print("-" * 60)
        print("\n\n")
except FileNotFoundError as e:
    print(e)

import time


class Timer:
    def __init__(self):
        self.start()

    def start(self):
        self.t = time.perf_counter()

    def stop(self, msg):
        elapsed_time = time.perf_counter() - self.t
        print(f"{elapsed_time:.3f} sec - {msg}")
        self.t = time.perf_counter()


t = Timer()

import numpy as np

t.stop("import numpy")

from numba import njit, prange, pndindex

t.stop("import numba")
@njit("i8(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], i8[:], i8)", cache=True, nogil=True) | |
def _make_tree( | |
i0_inds, | |
i1_inds, | |
less_inds, | |
more_inds, | |
split_dims, | |
bounds, | |
split_values, | |
points, | |
indices, | |
min_leaf_size, | |
): | |
dimension = points.shape[1] | |
# Expect log2(len(points) / min_leaf_size) depth, 1000 should be plenty | |
stack = np.empty(1000, np.int64) | |
stack_size = 0 | |
n_nodes = 0 | |
# min_leaf_size <= leaf_node_size < max_leaf_size | |
max_leaf_size = 2 * min_leaf_size | |
# Push i0, i1, i_node | |
stack[stack_size] = 0 | |
stack_size += 1 | |
stack[stack_size] = points.shape[0] | |
stack_size += 1 | |
stack[stack_size] = n_nodes | |
n_nodes += 1 | |
stack_size += 1 | |
# While there are more tree nodes to process recursively | |
while stack_size > 0: | |
stack_size -= 1 | |
i_node = stack[stack_size] | |
stack_size -= 1 | |
i1 = stack[stack_size] | |
stack_size -= 1 | |
i0 = stack[stack_size] | |
lo = bounds[i_node, 0] | |
hi = bounds[i_node, 1] | |
for d in range(dimension): | |
lo[d] = points[i0, d] | |
hi[d] = points[i0, d] | |
# Find lower and upper bounds of points for each dimension | |
for i in range(i0 + 1, i1): | |
for d in range(dimension): | |
lo[d] = min(lo[d], points[i, d]) | |
hi[d] = max(hi[d], points[i, d]) | |
# Done if node is small | |
if i1 - i0 <= max_leaf_size: | |
i0_inds[i_node] = i0 | |
i1_inds[i_node] = i1 | |
less_inds[i_node] = -1 | |
more_inds[i_node] = -1 | |
split_dims[i_node] = -1 | |
split_values[i_node] = 0.0 | |
else: | |
# Split on largest dimension | |
lengths = hi - lo | |
split_dim = np.argmax(lengths) | |
split_value = lo[split_dim] + 0.5 * lengths[split_dim] | |
# Partition i0:i1 range into points where points[i, split_dim] < split_value | |
i = i0 | |
j = i1 - 1 | |
while i < j: | |
while i < j and points[i, split_dim] < split_value: | |
i += 1 | |
while i < j and points[j, split_dim] >= split_value: | |
j -= 1 | |
# Swap points | |
if i < j: | |
for d in range(dimension): | |
temp = points[i, d] | |
points[i, d] = points[j, d] | |
points[j, d] = temp | |
temp_i_node = indices[i] | |
indices[i] = indices[j] | |
indices[j] = temp_i_node | |
if points[i, split_dim] < split_value: | |
i += 1 | |
i_split = i | |
# Now it holds that: | |
# for i in range(i0, i_split): assert(points[i, split_dim] < split_value) | |
# for i in range(i_split, i1): assert(points[i, split_dim] >= split_value) | |
# Ensure that each node has at least min_leaf_size children | |
i_split = max(i0 + min_leaf_size, min(i1 - min_leaf_size, i_split)) | |
less = n_nodes | |
n_nodes += 1 | |
more = n_nodes | |
n_nodes += 1 | |
# push i0, i_split, less | |
stack[stack_size] = i0 | |
stack_size += 1 | |
stack[stack_size] = i_split | |
stack_size += 1 | |
stack[stack_size] = less | |
stack_size += 1 | |
# push i_split, i1, more | |
stack[stack_size] = i_split | |
stack_size += 1 | |
stack[stack_size] = i1 | |
stack_size += 1 | |
stack[stack_size] = more | |
stack_size += 1 | |
i0_inds[i_node] = i0 | |
i1_inds[i_node] = i1 | |
less_inds[i_node] = less | |
more_inds[i_node] = more | |
split_dims[i_node] = split_dim | |
split_values[i_node] = split_value | |
return n_nodes | |
t.stop("_make_tree") | |
@njit("void(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], f4[:, :], i8[:, :], f4[:, :], i8)", cache=True, nogil=True, parallel=True) | |
def _find_knn( | |
i0_inds, | |
i1_inds, | |
less_inds, | |
more_inds, | |
split_dims, | |
bounds, | |
split_values, | |
points, | |
query_points, | |
out_indices, | |
out_distances, | |
k, | |
): | |
dimension = points.shape[1] | |
# For each query point | |
for i_query in prange(query_points.shape[0]): | |
query_point = query_points[i_query] | |
distances = out_distances[i_query] | |
indices = out_indices[i_query] | |
# Expect log2(len(points) / min_leaf_size) depth, 1000 should be plenty | |
stack = np.empty(1000, np.int64) | |
n_neighbors = 0 | |
stack[0] = 0 | |
stack_size = 1 | |
# While there are nodes to visit | |
while stack_size > 0: | |
stack_size -= 1 | |
i_node = stack[stack_size] | |
# If we found more neighbors than we need | |
if n_neighbors >= k: | |
# Calculate distance to bounding box of node | |
dist = 0.0 | |
for d in range(dimension): | |
p = query_point[d] | |
dp = p - max(bounds[i_node, 0, d], min(bounds[i_node, 1, d], p)) | |
dist += dp * dp | |
# Do nothing with this node if all points we have found so far | |
# are closer than the bounding box of the node. | |
if dist > distances[n_neighbors - 1]: | |
continue | |
# If leaf node | |
if split_dims[i_node] == -1: | |
# For each point in leaf node | |
for i in range(i0_inds[i_node], i1_inds[i_node]): | |
# Calculate distance between query point and point in node | |
distance = 0.0 | |
for d in range(dimension): | |
dd = query_point[d] - points[i, d] | |
distance += dd * dd | |
# Find insert position | |
insert_pos = n_neighbors | |
for j in range(n_neighbors - 1, -1, -1): | |
if distances[j] > distance: | |
insert_pos = j | |
# Insert found point in a sorted order | |
if insert_pos < k: | |
# Move [insert_pos:k-1] one to the right to make space | |
for j in range(min(n_neighbors, k - 1), insert_pos, -1): | |
indices[j] = indices[j - 1] | |
distances[j] = distances[j - 1] | |
# Insert new neighbor | |
indices[insert_pos] = i | |
distances[insert_pos] = distance | |
n_neighbors = min(n_neighbors + 1, k) | |
else: | |
# Descent to child nodes | |
less = less_inds[i_node] | |
more = more_inds[i_node] | |
split_dim = split_dims[i_node] | |
# First, visit child in same bounding box as query point | |
if query_point[split_dim] < split_values[i_node]: | |
stack[stack_size] = more | |
stack_size += 1 | |
stack[stack_size] = less | |
stack_size += 1 | |
else: | |
# Next, visit other child | |
stack[stack_size] = less | |
stack_size += 1 | |
stack[stack_size] = more | |
stack_size += 1 | |
# Workaround for https://github.com/numba/numba/issues/5156 | |
stack_size += 0 | |
t.stop("_find_knn") | |
@njit("(f8[:],)", cache=True, nogil=True) | |
def _propagate_1d_first_pass(d): | |
n = len(d) | |
for i1 in range(1, n): | |
i2 = i1 - 1 | |
d[i1] = min(d[i1], d[i2] + 1) | |
for i1 in range(n - 2, -1, -1): | |
i2 = i1 + 1 | |
d[i1] = min(d[i1], d[i2] + 1) | |
t.stop("_propagate_1d_first_pass") | |
@njit("(f8[:], i4[:], f8[:], f8[:])", cache=True, nogil=True) | |
def _propagate_1d(d, v, z, f): | |
nx = len(d) | |
k = -1 | |
s = -np.inf | |
for x in range(nx): | |
d_yx = d[x] | |
if d_yx == np.inf: | |
continue | |
fx = d_yx * d_yx | |
f[x] = fx | |
while k >= 0: | |
vk = v[k] | |
s = 0.5 * (fx + x * x - f[vk] - vk * vk) / (x - vk) | |
if s > z[k]: | |
break | |
k -= 1 | |
k += 1 | |
v[k] = x | |
z[k] = s | |
z[k + 1] = np.inf | |
if k < 0: | |
return | |
k = 0 | |
for x in range(nx): | |
while z[k + 1] < x: | |
k += 1 | |
vk = v[k] | |
dx = x - vk | |
d[x] = np.sqrt(dx * dx + f[vk]) | |
t.stop("_propagate_1d") | |
@njit("(f8[:, :],)", cache=True, parallel=True, nogil=True) | |
def _propagate_distance(distance): | |
ny, nx = distance.shape | |
for x in prange(nx): | |
_propagate_1d_first_pass(distance[:, x]) | |
v = np.zeros((ny, nx), dtype=np.int32) | |
z = np.zeros((ny, nx + 1)) | |
f = np.zeros((ny, nx)) | |
for y in prange(ny): | |
_propagate_1d(distance[y], v[y], z[y], f[y]) | |
t.stop("_propagate_distance") | |
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True) | |
def boxfilter_rows_valid(src, r): | |
m, n = src.shape | |
dst = np.zeros((m, n - 2 * r)) | |
for i in prange(m): | |
for j_dst in range(1): | |
s = 0.0 | |
for j_src in range(j_dst, j_dst + 2 * r + 1): | |
s += src[i, j_src] | |
dst[i, j_dst] = s | |
for j_dst in range(1, dst.shape[1]): | |
j_src = j_dst - 1 | |
s -= src[i, j_src] | |
j_src = j_dst + 2 * r | |
s += src[i, j_src] | |
dst[i, j_dst] = s | |
return dst | |
t.stop("boxfilter_rows_valid") | |
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True) | |
def boxfilter_rows_same(src, r): | |
m, n = src.shape | |
dst = np.zeros((m, n)) | |
for i in prange(m): | |
for j_dst in range(1): | |
s = 0.0 | |
for j_src in range(j_dst + r + 1): | |
s += src[i, j_src] | |
dst[i, j_dst] = s | |
for j_dst in range(1, r + 1): | |
s += src[i, j_dst + r] | |
dst[i, j_dst] = s | |
for j_dst in range(r + 1, n - r): | |
s -= src[i, j_dst - r - 1] | |
s += src[i, j_dst + r] | |
dst[i, j_dst] = s | |
for j_dst in range(n - r, n): | |
s -= src[i, j_dst - r - 1] | |
dst[i, j_dst] = s | |
return dst | |
t.stop("boxfilter_rows_same") | |
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True) | |
def boxfilter_rows_full(src, r): | |
m, n = src.shape | |
dst = np.zeros((m, n + 2 * r)) | |
for i in prange(m): | |
for j_dst in range(1): | |
s = 0.0 | |
for j_src in range(j_dst + r + 1 - r): | |
s += src[i, j_src] | |
dst[i, j_dst] = s | |
for j_dst in range(1, 2 * r + 1): | |
s += src[i, j_dst] | |
dst[i, j_dst] = s | |
for j_dst in range(2 * r + 1, dst.shape[1] - 2 * r): | |
s -= src[i, j_dst - r - r - 1] | |
s += src[i, j_dst] | |
dst[i, j_dst] = s | |
for j_dst in range(dst.shape[1] - 2 * r, dst.shape[1]): | |
s -= src[i, j_dst - r - r - 1] | |
dst[i, j_dst] = s | |
return dst | |
t.stop("boxfilter_rows_full") | |
@njit("void(f8[:, :, :], f8, i8, f8[:, :, :], i8[:], i8[:], b1[:, :])", cache=True, nogil=True) | |
def _cf_laplacian(image, epsilon, r, values, indices, indptr, is_known): | |
h, w, d = image.shape | |
assert d == 3 | |
size = 2 * r + 1 | |
window_area = size * size | |
for yi in range(h): | |
for xi in range(w): | |
i = xi + yi * w | |
k = i * (4 * r + 1) ** 2 | |
for yj in range(yi - 2 * r, yi + 2 * r + 1): | |
for xj in range(xi - 2 * r, xi + 2 * r + 1): | |
j = xj + yj * w | |
if 0 <= xj < w and 0 <= yj < h: | |
indices[k] = j | |
k += 1 | |
indptr[i + 1] = k | |
# Centered and normalized window colors | |
c = np.zeros((2 * r + 1, 2 * r + 1, 3)) | |
# For each pixel of image | |
for y in range(r, h - r): | |
for x in range(r, w - r): | |
if np.all(is_known[y - r : y + r + 1, x - r : x + r + 1]): | |
continue | |
# For each color channel | |
for dc in range(3): | |
# Calculate sum of color channel in window | |
s = 0.0 | |
for dy in range(size): | |
for dx in range(size): | |
s += image[y + dy - r, x + dx - r, dc] | |
# Calculate centered window color | |
for dy in range(2 * r + 1): | |
for dx in range(2 * r + 1): | |
c[dy, dx, dc] = ( | |
image[y + dy - r, x + dx - r, dc] - s / window_area | |
) | |
# Calculate covariance matrix over color channels with epsilon regularization | |
a00 = epsilon | |
a01 = 0.0 | |
a02 = 0.0 | |
a11 = epsilon | |
a12 = 0.0 | |
a22 = epsilon | |
for dy in range(size): | |
for dx in range(size): | |
a00 += c[dy, dx, 0] * c[dy, dx, 0] | |
a01 += c[dy, dx, 0] * c[dy, dx, 1] | |
a02 += c[dy, dx, 0] * c[dy, dx, 2] | |
a11 += c[dy, dx, 1] * c[dy, dx, 1] | |
a12 += c[dy, dx, 1] * c[dy, dx, 2] | |
a22 += c[dy, dx, 2] * c[dy, dx, 2] | |
a00 /= window_area | |
a01 /= window_area | |
a02 /= window_area | |
a11 /= window_area | |
a12 /= window_area | |
a22 /= window_area | |
det = ( | |
a00 * a12 * a12 | |
+ a01 * a01 * a22 | |
+ a02 * a02 * a11 | |
- a00 * a11 * a22 | |
- 2 * a01 * a02 * a12 | |
) | |
inv_det = 1.0 / det | |
# Calculate inverse covariance matrix | |
m00 = (a12 * a12 - a11 * a22) * inv_det | |
m01 = (a01 * a22 - a02 * a12) * inv_det | |
m02 = (a02 * a11 - a01 * a12) * inv_det | |
m11 = (a02 * a02 - a00 * a22) * inv_det | |
m12 = (a00 * a12 - a01 * a02) * inv_det | |
m22 = (a01 * a01 - a00 * a11) * inv_det | |
# For each pair ((xi, yi), (xj, yj)) in a (2 r + 1)x(2 r + 1) window | |
for dyi in range(2 * r + 1): | |
for dxi in range(2 * r + 1): | |
s = c[dyi, dxi, 0] | |
t = c[dyi, dxi, 1] | |
u = c[dyi, dxi, 2] | |
c0 = m00 * s + m01 * t + m02 * u | |
c1 = m01 * s + m11 * t + m12 * u | |
c2 = m02 * s + m12 * t + m22 * u | |
for dyj in range(2 * r + 1): | |
for dxj in range(2 * r + 1): | |
xi = x + dxi - r | |
yi = y + dyi - r | |
xj = x + dxj - r | |
yj = y + dyj - r | |
i = xi + yi * w | |
j = xj + yj * w | |
# Calculate contribution of pixel pair to L_ij | |
temp = ( | |
c0 * c[dyj, dxj, 0] | |
+ c1 * c[dyj, dxj, 1] | |
+ c2 * c[dyj, dxj, 2] | |
) | |
value = (1.0 if (i == j) else 0.0) - ( | |
1 + temp | |
) / window_area | |
dx = xj - xi + 2 * r | |
dy = yj - yi + 2 * r | |
values[i, dy, dx] += value | |
t.stop("_cf_laplacian") | |
@njit("void(f4[:, :, :], f4[:, :, :])", cache=True, nogil=True, parallel=True) | |
def _resize_nearest_multichannel(dst, src): | |
""" | |
Internal method. | |
Resize image src to dst using nearest neighbors filtering. | |
Images must have multiple color channels, i.e. :code:`len(shape) == 3`. | |
Parameters | |
---------- | |
dst: numpy.ndarray of type np.float32 | |
output image | |
src: numpy.ndarray of type np.float32 | |
input image | |
""" | |
h_src, w_src, depth = src.shape | |
h_dst, w_dst, depth = dst.shape | |
for y_dst in prange(h_dst): | |
for x_dst in range(w_dst): | |
x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst)) | |
y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst)) | |
for c in range(depth): | |
dst[y_dst, x_dst, c] = src[y_src, x_src, c] | |
t.stop("_resize_nearest_multichannel") | |
@njit("void(f4[:, :], f4[:, :])", cache=True, nogil=True, parallel=True) | |
def _resize_nearest(dst, src): | |
""" | |
Internal method. | |
Resize image src to dst using nearest neighbors filtering. | |
Images must be grayscale, i.e. :code:`len(shape) == 3`. | |
Parameters | |
---------- | |
dst: numpy.ndarray of type np.float32 | |
output image | |
src: numpy.ndarray of type np.float32 | |
input image | |
""" | |
h_src, w_src = src.shape | |
h_dst, w_dst = dst.shape | |
for y_dst in prange(h_dst): | |
for x_dst in range(w_dst): | |
x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst)) | |
y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst)) | |
dst[y_dst, x_dst] = src[y_src, x_src] | |
t.stop("_resize_nearest") | |
@njit("Tuple((f4[:, :, :], f4[:, :, :]))(f4[:, :, :], f4[:, :], f4, i4, i4, i4, f4)", cache=True, nogil=True) | |
def _estimate_fb_ml( | |
input_image, | |
input_alpha, | |
regularization, | |
n_small_iterations, | |
n_big_iterations, | |
small_size, | |
gradient_weight, | |
): | |
h0, w0, depth = input_image.shape | |
dtype = np.float32 | |
w_prev = 1 | |
h_prev = 1 | |
# Compute average foreground and background color | |
F_mean_color = np.zeros(depth, dtype=dtype) | |
B_mean_color = np.zeros(depth, dtype=dtype) | |
F_count = 0 | |
B_count = 0 | |
for y in range(h0): | |
for x in range(w0): | |
if input_alpha[y, x] > 0.9: | |
for c in range(depth): | |
F_mean_color[c] += input_image[y, x, c] | |
F_count += 1 | |
if input_alpha[y, x] < 0.1: | |
for c in range(depth): | |
B_mean_color[c] += input_image[y, x, c] | |
B_count += 1 | |
F_mean_color /= F_count + 1e-5 | |
B_mean_color /= B_count + 1e-5 | |
# Initialize initial foreground and background with average color | |
F_prev = np.zeros((h_prev, w_prev, depth), dtype=dtype) + F_mean_color | |
B_prev = np.zeros((h_prev, w_prev, depth), dtype=dtype) + B_mean_color | |
n_levels = int(np.ceil(np.log2(max(w0, h0)))) | |
for i_level in range(n_levels + 1): | |
w = round(w0 ** (i_level / n_levels)) | |
h = round(h0 ** (i_level / n_levels)) | |
image = np.empty((h, w, depth), dtype=dtype) | |
alpha = np.empty((h, w), dtype=dtype) | |
_resize_nearest_multichannel(image, input_image) | |
_resize_nearest(alpha, input_alpha) | |
F = np.empty((h, w, depth), dtype=dtype) | |
B = np.empty((h, w, depth), dtype=dtype) | |
_resize_nearest_multichannel(F, F_prev) | |
_resize_nearest_multichannel(B, B_prev) | |
if w <= small_size and h <= small_size: | |
n_iter = n_small_iterations | |
else: | |
n_iter = n_big_iterations | |
b = np.zeros((2, depth), dtype=dtype) | |
dx = [-1, 1, 0, 0] | |
dy = [0, 0, -1, 1] | |
for i_iter in range(n_iter): | |
for y in prange(h): | |
for x in range(w): | |
a0 = alpha[y, x] | |
a1 = 1.0 - a0 | |
a00 = a0 * a0 | |
a01 = a0 * a1 | |
# a10 = a01 can be omitted due to symmetry of matrix | |
a11 = a1 * a1 | |
for c in range(depth): | |
b[0, c] = a0 * image[y, x, c] | |
b[1, c] = a1 * image[y, x, c] | |
for d in range(4): | |
x2 = max(0, min(w - 1, x + dx[d])) | |
y2 = max(0, min(h - 1, y + dy[d])) | |
gradient = abs(a0 - alpha[y2, x2]) | |
da = regularization + gradient_weight * gradient | |
a00 += da | |
a11 += da | |
for c in range(depth): | |
b[0, c] += da * F[y2, x2, c] | |
b[1, c] += da * B[y2, x2, c] | |
determinant = a00 * a11 - a01 * a01 | |
inv_det = 1.0 / determinant | |
b00 = inv_det * a11 | |
b01 = inv_det * -a01 | |
b11 = inv_det * a00 | |
for c in range(depth): | |
F_c = b00 * b[0, c] + b01 * b[1, c] | |
B_c = b01 * b[0, c] + b11 * b[1, c] | |
F_c = max(0.0, min(1.0, F_c)) | |
B_c = max(0.0, min(1.0, B_c)) | |
F[y, x, c] = F_c | |
B[y, x, c] = B_c | |
F_prev = F | |
B_prev = B | |
w_prev = w | |
h_prev = h | |
return F, B | |
t.stop("_estimate_fb_ml") | |
@njit("f4(f4[::1], f4[::1], f4[::1])", cache=True, nogil=True) | |
def estimate_alpha(I, F, B): | |
fb0 = F[0] - B[0] | |
fb1 = F[1] - B[1] | |
fb2 = F[2] - B[2] | |
ib0 = I[0] - B[0] | |
ib1 = I[1] - B[1] | |
ib2 = I[2] - B[2] | |
denom = fb0 * fb0 + fb1 * fb1 + fb2 * fb2 + 1e-5 | |
alpha = (ib0 * fb0 + ib1 * fb1 + ib2 * fb2) / denom | |
alpha = max(0.0, min(1.0, alpha)) | |
return alpha | |
t.stop("estimate_alpha") | |
@njit("f4(f4[::1], f4[::1])", cache=True, nogil=True) | |
def inner(a, b): | |
s = 0.0 | |
for i in range(len(a)): | |
s += a[i] * b[i] | |
return s | |
t.stop("inner") | |
@njit("f4(f4[::1], f4[::1], f4[::1])", cache=True, nogil=True) | |
def Mp2(I, F, B): | |
a = estimate_alpha(I, F, B) | |
d0 = a * F[0] + (1.0 - a) * B[0] - I[0] | |
d1 = a * F[1] + (1.0 - a) * B[1] - I[1] | |
d2 = a * F[2] + (1.0 - a) * B[2] - I[2] | |
return d0 * d0 + d1 * d1 + d2 * d2 | |
t.stop("Mp2") | |
@njit("f4(f4[:, :, ::1], i8, i8, f4[::1], f4[::1], i8)", cache=True, nogil=True) | |
def Np(image, x, y, F, B, r): | |
h, w = image.shape[:2] | |
result = 0.0 | |
for y2 in range(y - r, y + r + 1): | |
y2 = max(0, min(h - 1, y2)) | |
for x2 in range(x - r, x + r + 1): | |
x2 = max(0, min(w - 1, x2)) | |
result += Mp2(image[y2, x2], F, B) | |
return result | |
t.stop("Np") | |
@njit("f4(f4[:, :, ::1], i8, i8, i8, i8)", cache=True, nogil=True) | |
def Ep(image, px, py, sx, sy): | |
result = 0.0 | |
spx = sx - px | |
spy = sy - py | |
d = np.hypot(spx, spy) | |
if d == 0.0: return 0.0 | |
num_steps = int(np.ceil(d)) | |
num_steps = max(1, min(10, num_steps)) | |
step_x = spx / num_steps | |
step_y = spy / num_steps | |
h, w = image.shape[:2] | |
for i in range(num_steps + 1): | |
qx = int(px + float(i) * step_x) | |
qy = int(py + float(i) * step_x) | |
q_l = max(0, min(w - 1, qx - 1)) | |
q_r = max(0, min(w - 1, qx + 1)) | |
q_u = max(0, min(h - 1, qy + 1)) | |
q_d = max(0, min(h - 1, qy - 1)) | |
qx = max(0, min(w - 1, qx)) | |
qy = max(0, min(h - 1, qy)) | |
Ix0 = 0.5 * (image[qy, q_r, 0] - image[qy, q_l, 0]) | |
Ix1 = 0.5 * (image[qy, q_r, 1] - image[qy, q_l, 1]) | |
Ix2 = 0.5 * (image[qy, q_r, 2] - image[qy, q_l, 2]) | |
Iy0 = 0.5 * (image[q_u, qx, 0] - image[q_d, qx, 0]) | |
Iy1 = 0.5 * (image[q_u, qx, 1] - image[q_d, qx, 1]) | |
Iy2 = 0.5 * (image[q_u, qx, 2] - image[q_d, qx, 2]) | |
v0 = step_x * Ix0 + step_y * Iy0 | |
v1 = step_x * Ix1 + step_y * Iy1 | |
v2 = step_x * Ix2 + step_y * Iy2 | |
result += np.sqrt(v0 * v0 + v1 * v1 + v2 * v2) | |
return result | |
t.stop("Ep") | |
@njit("f4(f4[::1], f4[::1])", cache=True, nogil=True) | |
def dist(a, b): | |
d2 = 0.0 | |
for i in range(a.shape[0]): | |
d2 += (a[i] - b[i]) ** 2 | |
return np.sqrt(d2) | |
t.stop("dist") | |
@njit("f4(f4[::1])", cache=True, nogil=True) | |
def length(a): | |
return np.sqrt(inner(a, a)) | |
t.stop("length") | |
@njit("void(f4[:, ::1], f4[:, ::1], f4[:, :, ::1], i8, f4)", cache=True, parallel=True, nogil=True) | |
def expand_trimap(expanded_trimap, trimap, image, k_i, k_c): | |
# NB: Description in paper does not match published test images. | |
# The radius appears to be larger and expanded trimap is sparser. | |
h, w = trimap.shape | |
for y, x in pndindex((h, w)): | |
if trimap[y, x] == 0 or trimap[y, x] == 1: continue | |
closest = np.inf | |
for y2 in range(y - k_i, y + k_i + 1): | |
for x2 in range(x - k_i, x + k_i + 1): | |
if x2 < 0 or x2 >= w or y2 < 0 or y2 >= h: continue | |
if trimap[y2, x2] != 0 and trimap[y2, x2] != 1: continue | |
dr = image[y, x, 0] - image[y2, x2, 0] | |
dg = image[y, x, 1] - image[y2, x2, 1] | |
db = image[y, x, 2] - image[y2, x2, 2] | |
color_distance = np.sqrt(dr * dr + dg * dg + db * db) | |
spatial_distance = np.hypot(x - x2, y - y2) | |
if color_distance > k_c: continue | |
if spatial_distance > k_i: continue | |
if spatial_distance < closest: | |
closest = spatial_distance | |
expanded_trimap[y, x] = trimap[y2, x2] | |
t.stop("expand_trimap") | |
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, ::1], i8, f4, f4, f4, f4, i8)", cache=True, parallel=True, nogil=True) | |
def sample_gathering( | |
gathering_F, | |
gathering_B, | |
gathering_alpha, | |
image, | |
trimap, | |
num_angles, | |
eN, | |
eA, | |
ef, | |
eb, | |
Np_radius, | |
): | |
h, w = trimap.shape | |
max_steps = 2 * max(w, h) | |
for y, x in pndindex((h, w)): | |
fg_samples = np.zeros((num_angles, 3), dtype=np.float32) | |
fg_samples_xy = np.zeros((num_angles, 2), dtype=np.int32) | |
bg_samples = np.zeros((num_angles, 3), dtype=np.float32) | |
bg_samples_xy = np.zeros((num_angles, 2), dtype=np.int32) | |
C_p = image[y, x] | |
gathering_alpha[y, x] = trimap[y, x] | |
if trimap[y, x] == 0: | |
gathering_B[y, x] = C_p | |
continue | |
if trimap[y, x] == 1: | |
gathering_F[y, x] = C_p | |
continue | |
# Fixed start angles in 8-by-8 grid for reproducible tests | |
n = 8 | |
i = (x % n) + (y % n) * n | |
# Shuffle (99991 is a prime number) | |
i = (i * 99991) % (n * n) | |
start_angle = 2.0 * np.pi / (n * n) * i | |
num_fg_samples = 0 | |
num_bg_samples = 0 | |
for i in range(num_angles): | |
angle = 2.0 * np.pi / num_angles * i + start_angle | |
c = np.cos(angle) | |
s = np.sin(angle) | |
has_fg = False | |
has_bg = False | |
for step in range(max_steps): | |
if has_fg and has_bg: break | |
x2 = int(x + step * c) | |
y2 = int(y + step * s) | |
if x2 < 0 or y2 < 0 or x2 >= w or y2 >= h: break | |
if not has_fg and trimap[y2, x2] == 1: | |
fg_samples[num_fg_samples] = image[y2, x2] | |
fg_samples_xy[num_fg_samples, 0] = x2 | |
fg_samples_xy[num_fg_samples, 1] = y2 | |
num_fg_samples += 1 | |
has_fg = True | |
if not has_bg and trimap[y2, x2] == 0: | |
bg_samples[num_bg_samples] = image[y2, x2] | |
bg_samples_xy[num_bg_samples, 0] = x2 | |
bg_samples_xy[num_bg_samples, 1] = y2 | |
num_bg_samples += 1 | |
has_bg = True | |
if num_fg_samples == 0: | |
fg_samples[num_fg_samples] = gathering_F[y, x] | |
fg_samples_xy[num_fg_samples, 0] = x | |
fg_samples_xy[num_fg_samples, 1] = y | |
num_fg_samples += 1 | |
if num_bg_samples == 0: | |
bg_samples[num_bg_samples] = gathering_B[y, x] | |
bg_samples_xy[num_bg_samples, 0] = x | |
bg_samples_xy[num_bg_samples, 1] = y | |
num_bg_samples += 1 | |
min_Ep_f = np.inf | |
min_Ep_b = np.inf | |
for i in range(num_fg_samples): | |
Ep_f = Ep(image, x, y, fg_samples_xy[i, 0], fg_samples_xy[i, 1]) | |
min_Ep_f = min(min_Ep_f, Ep_f) | |
for j in range(num_bg_samples): | |
Ep_b = Ep(image, x, y, bg_samples_xy[j, 0], bg_samples_xy[j, 1]) | |
min_Ep_b = min(min_Ep_b, Ep_b) | |
PF_p = min_Ep_b / (min_Ep_f + min_Ep_b + 1e-5) | |
min_cost = np.inf | |
# Find best foreground/background pair | |
for i in range(num_fg_samples): | |
for j in range(num_bg_samples): | |
F = fg_samples[i] | |
B = bg_samples[j] | |
alpha_p = estimate_alpha(C_p, F, B) | |
Ap = PF_p + (1.0 - 2.0 * PF_p) * alpha_p | |
Dp_f = np.hypot(x - fg_samples_xy[i, 0], y - fg_samples_xy[i, 1]) | |
Dp_b = np.hypot(x - bg_samples_xy[j, 0], y - bg_samples_xy[j, 1]) | |
g_p = ( | |
Np(image, x, y, F, B, Np_radius)**eN * | |
Ap**eA * | |
Dp_f**ef * | |
Dp_b**eb) | |
if min_cost > g_p: | |
min_cost = g_p | |
gathering_alpha[y, x] = alpha_p | |
gathering_F[y, x] = F | |
gathering_B[y, x] = B | |
t.stop("sample_gathering") | |
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], i8)", cache=True, parallel=False, nogil=True) | |
def sample_refinement( | |
refined_F, | |
refined_B, | |
refined_alpha, | |
gathering_F, | |
gathering_B, | |
image, | |
trimap, | |
radius, | |
): | |
h, w = trimap.shape | |
refined_F[:] = gathering_F | |
refined_B[:] = gathering_B | |
refined_alpha[:] = trimap | |
for y, x in pndindex((h, w)): | |
C_p = image[y, x] | |
if trimap[y, x] == 0 or trimap[y, x] == 1: | |
continue | |
max_samples = 3 | |
sample_F = np.zeros((max_samples, 3), dtype=np.float32) | |
sample_B = np.zeros((max_samples, 3), dtype=np.float32) | |
sample_cost = np.zeros(max_samples, dtype=np.float32) | |
sample_cost[:] = np.inf | |
for dy in range(-radius, radius + 1): | |
for dx in range(-radius, radius + 1): | |
x2 = x + dx | |
y2 = y + dy | |
if 0 <= x2 < w and 0 <= y2 < h: | |
F_q = gathering_F[y2, x2] | |
B_q = gathering_B[y2, x2] | |
cost = Mp2(C_p, F_q, B_q) | |
i = np.argmax(sample_cost) | |
if cost < sample_cost[i]: | |
sample_cost[i] = cost | |
sample_F[i] = F_q | |
sample_B[i] = B_q | |
F_mean = sample_F.sum(axis=0) / max_samples | |
B_mean = sample_B.sum(axis=0) / max_samples | |
refined_F[y, x] = F_mean | |
refined_B[y, x] = B_mean | |
refined_alpha[y, x] = estimate_alpha(C_p, F_mean, B_mean) | |
t.stop("sample_refinement") | |
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, ::1], i8, i8, i8, f4, f4, f4)", cache=True, parallel=True, nogil=True) | |
def local_smoothing( | |
final_F, | |
final_B, | |
final_alpha, | |
refined_F, | |
refined_B, | |
refined_alpha, | |
image, | |
trimap, | |
radius1, | |
radius2, | |
radius3, | |
sigma_sq1, | |
sigma_sq2, | |
sigma_sq3, | |
): | |
h, w = trimap.shape | |
final_confidence = np.zeros((h, w), dtype=np.float32) | |
W_FB = np.zeros((h, w), dtype=np.float32) | |
low_frequency_alpha = np.zeros((h, w), dtype=np.float32) | |
final_F[:] = refined_F | |
final_B[:] = refined_B | |
for y, x in pndindex((h, w)): | |
C_p = image[y, x] | |
if trimap[y, x] == 0 or trimap[y, x] == 1: | |
continue | |
F_p = np.zeros(3, dtype=np.float32) | |
B_p = np.zeros(3, dtype=np.float32) | |
sum_F = 0.0 | |
sum_B = 0.0 | |
alpha_p = refined_alpha[y, x] | |
for dy in range(-radius1, radius1 + 1): | |
for dx in range(-radius1, radius1 + 1): | |
x2 = x + dx | |
y2 = y + dy | |
if 0 <= x2 < w and 0 <= y2 < h: | |
# NB: Gaussian not normalized, not using confidence | |
Wc_pq = np.exp(-1.0 / sigma_sq1 * (dx * dx + dy * dy)) | |
if x != x2 or y != y2: | |
Wc_pq *= abs(refined_alpha[y, x] - refined_alpha[y2, x2]) | |
alpha_q = refined_alpha[y2, x2] | |
W_F = Wc_pq * alpha_q | |
W_B = Wc_pq * (1.0 - alpha_q) | |
W_F = max(W_F, 1e-5) | |
W_B = max(W_B, 1e-5) | |
sum_F += W_F | |
sum_B += W_B | |
for c in range(3): | |
F_p[c] += W_F * refined_F[y2, x2, c] | |
B_p[c] += W_B * refined_B[y2, x2, c] | |
F_p /= sum_F | |
B_p /= sum_B | |
final_F[y, x] = F_p | |
final_B[y, x] = B_p | |
final_alpha[y, x] = estimate_alpha(C_p, F_p, B_p) | |
# NB: Not using confidence | |
W_FB[y, x] = alpha_p * (1.0 - alpha_p) | |
for y, x in pndindex((h, w)): | |
C_p = image[y, x] | |
if trimap[y, x] == 0 or trimap[y, x] == 1: | |
final_confidence[y, x] = trimap[y, x] | |
continue | |
D_FB = 0.0 | |
weight_sum = 0.0 | |
for dy in range(-radius2, radius2 + 1): | |
for dx in range(-radius2, radius2 + 1): | |
x2 = x + dx | |
y2 = y + dy | |
if 0 <= x2 < w and 0 <= y2 < h: | |
weight_sum += W_FB[y2, x2] | |
D_FB += W_FB[y2, x2] * dist(final_F[y2, x2], final_B[y2, x2]) | |
D_FB /= weight_sum + 1e-5 | |
D_FB += 1e-5 | |
FB_dist = dist(final_F[y, x], final_B[y, x]) | |
F_p = final_F[y, x] | |
B_p = final_B[y, x] | |
final_confidence[y, x] = min(1.0, FB_dist / D_FB) * np.exp(-1.0 / sigma_sq2 * np.sqrt(Mp2(C_p, F_p, B_p))) | |
for y, x in pndindex((h, w)): | |
if trimap[y, x] == 0 or trimap[y, x] == 1: | |
final_alpha[y, x] = trimap[y, x] | |
continue | |
alpha_sum = 0.0 | |
weight_sum = 0.0 | |
for dy in range(-radius3, radius3 + 1): | |
for dx in range(-radius3, radius3 + 1): | |
x2 = x + dx | |
y2 = y + dy | |
if 0 <= x2 < w and 0 <= y2 < h: | |
# NB: Gaussian not normalized, not using final_confidence(x2, y2) | |
D_image_squared = dx * dx + dy * dy | |
is_known = trimap[y2, x2] == 0 or trimap[y2, x2] == 1 | |
W_alpha = np.exp(-1.0 / sigma_sq3 * (dx * dx + dy * dy)) + is_known | |
alpha_sum += W_alpha * refined_alpha[y2, x2] | |
weight_sum += W_alpha | |
low_frequency_alpha[y, x] = alpha_sum / weight_sum | |
final_alpha[y, x] = final_confidence[y, x] * final_alpha[y, x] + (1.0 - final_confidence[y, x]) * low_frequency_alpha[y, x] | |
t.stop("local_smoothing") | |
# Benchmark matmul to get CPU baseline performance
n = 1024
A = np.random.rand(n, n)
t.stop("rand")
for _ in range(10):
    A @ A
t.stop("matmul")