import shutil
# Remove Numba's on-disk cache (stored in __pycache__) so the timings below measure cold compilation
try:
shutil.rmtree("__pycache__")
except FileNotFoundError:
pass
try:
with open("/proc/cpuinfo") as f:
print("CPU info:")
print("-" * 60)
print(f.read().strip().split("\n\n")[-1])
print("-" * 60)
print("\n\n")
except FileNotFoundError as e:
print(e)
import time
class Timer:
def __init__(self):
self.start()
def start(self):
self.t = time.perf_counter()
def stop(self, msg):
elapsed_time = time.perf_counter() - self.t
print(f"{elapsed_time:.3f} sec - {msg}")
self.t = time.perf_counter()
t = Timer()
import numpy as np
t.stop("import numpy")
from numba import njit, prange, pndindex
t.stop("import numba")
@njit("i8(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], i8[:], i8)", cache=True, nogil=True)
def _make_tree(
i0_inds,
i1_inds,
less_inds,
more_inds,
split_dims,
bounds,
split_values,
points,
indices,
min_leaf_size,
):
dimension = points.shape[1]
# Expect log2(len(points) / min_leaf_size) depth, 1000 should be plenty
stack = np.empty(1000, np.int64)
stack_size = 0
n_nodes = 0
    # min_leaf_size <= leaf_node_size <= max_leaf_size
max_leaf_size = 2 * min_leaf_size
# Push i0, i1, i_node
stack[stack_size] = 0
stack_size += 1
stack[stack_size] = points.shape[0]
stack_size += 1
stack[stack_size] = n_nodes
n_nodes += 1
stack_size += 1
# While there are more tree nodes to process recursively
while stack_size > 0:
stack_size -= 1
i_node = stack[stack_size]
stack_size -= 1
i1 = stack[stack_size]
stack_size -= 1
i0 = stack[stack_size]
lo = bounds[i_node, 0]
hi = bounds[i_node, 1]
for d in range(dimension):
lo[d] = points[i0, d]
hi[d] = points[i0, d]
# Find lower and upper bounds of points for each dimension
for i in range(i0 + 1, i1):
for d in range(dimension):
lo[d] = min(lo[d], points[i, d])
hi[d] = max(hi[d], points[i, d])
# Done if node is small
if i1 - i0 <= max_leaf_size:
i0_inds[i_node] = i0
i1_inds[i_node] = i1
less_inds[i_node] = -1
more_inds[i_node] = -1
split_dims[i_node] = -1
split_values[i_node] = 0.0
else:
# Split on largest dimension
lengths = hi - lo
split_dim = np.argmax(lengths)
split_value = lo[split_dim] + 0.5 * lengths[split_dim]
            # Partition the i0:i1 range: points with points[i, split_dim] < split_value first, the rest after
i = i0
j = i1 - 1
while i < j:
while i < j and points[i, split_dim] < split_value:
i += 1
while i < j and points[j, split_dim] >= split_value:
j -= 1
# Swap points
if i < j:
for d in range(dimension):
temp = points[i, d]
points[i, d] = points[j, d]
points[j, d] = temp
temp_i_node = indices[i]
indices[i] = indices[j]
indices[j] = temp_i_node
if points[i, split_dim] < split_value:
i += 1
i_split = i
# Now it holds that:
# for i in range(i0, i_split): assert(points[i, split_dim] < split_value)
# for i in range(i_split, i1): assert(points[i, split_dim] >= split_value)
            # Ensure that each child node gets at least min_leaf_size points
i_split = max(i0 + min_leaf_size, min(i1 - min_leaf_size, i_split))
less = n_nodes
n_nodes += 1
more = n_nodes
n_nodes += 1
# push i0, i_split, less
stack[stack_size] = i0
stack_size += 1
stack[stack_size] = i_split
stack_size += 1
stack[stack_size] = less
stack_size += 1
# push i_split, i1, more
stack[stack_size] = i_split
stack_size += 1
stack[stack_size] = i1
stack_size += 1
stack[stack_size] = more
stack_size += 1
i0_inds[i_node] = i0
i1_inds[i_node] = i1
less_inds[i_node] = less
more_inds[i_node] = more
split_dims[i_node] = split_dim
split_values[i_node] = split_value
return n_nodes
t.stop("_make_tree")
@njit("void(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], f4[:, :], i8[:, :], f4[:, :], i8)", cache=True, nogil=True, parallel=True)
def _find_knn(
i0_inds,
i1_inds,
less_inds,
more_inds,
split_dims,
bounds,
split_values,
points,
query_points,
out_indices,
out_distances,
k,
):
dimension = points.shape[1]
# For each query point
for i_query in prange(query_points.shape[0]):
query_point = query_points[i_query]
distances = out_distances[i_query]
indices = out_indices[i_query]
# Expect log2(len(points) / min_leaf_size) depth, 1000 should be plenty
stack = np.empty(1000, np.int64)
n_neighbors = 0
stack[0] = 0
stack_size = 1
# While there are nodes to visit
while stack_size > 0:
stack_size -= 1
i_node = stack[stack_size]
# If we found more neighbors than we need
if n_neighbors >= k:
# Calculate distance to bounding box of node
dist = 0.0
for d in range(dimension):
p = query_point[d]
dp = p - max(bounds[i_node, 0, d], min(bounds[i_node, 1, d], p))
dist += dp * dp
# Do nothing with this node if all points we have found so far
# are closer than the bounding box of the node.
if dist > distances[n_neighbors - 1]:
continue
# If leaf node
if split_dims[i_node] == -1:
# For each point in leaf node
for i in range(i0_inds[i_node], i1_inds[i_node]):
# Calculate distance between query point and point in node
distance = 0.0
for d in range(dimension):
dd = query_point[d] - points[i, d]
distance += dd * dd
# Find insert position
insert_pos = n_neighbors
for j in range(n_neighbors - 1, -1, -1):
if distances[j] > distance:
insert_pos = j
# Insert found point in a sorted order
if insert_pos < k:
# Move [insert_pos:k-1] one to the right to make space
for j in range(min(n_neighbors, k - 1), insert_pos, -1):
indices[j] = indices[j - 1]
distances[j] = distances[j - 1]
# Insert new neighbor
indices[insert_pos] = i
distances[insert_pos] = distance
n_neighbors = min(n_neighbors + 1, k)
else:
                # Descend into the child nodes
less = less_inds[i_node]
more = more_inds[i_node]
split_dim = split_dims[i_node]
                # Push the far child first and the near child (the side containing the
                # query point) last, so the near child is popped and visited first
if query_point[split_dim] < split_values[i_node]:
stack[stack_size] = more
stack_size += 1
stack[stack_size] = less
stack_size += 1
else:
                    # Query point lies on the "more" side of the split
stack[stack_size] = less
stack_size += 1
stack[stack_size] = more
stack_size += 1
# Workaround for https://github.com/numba/numba/issues/5156
stack_size += 0
t.stop("_find_knn")
@njit("(f8[:],)", cache=True, nogil=True)
def _propagate_1d_first_pass(d):
n = len(d)
for i1 in range(1, n):
i2 = i1 - 1
d[i1] = min(d[i1], d[i2] + 1)
for i1 in range(n - 2, -1, -1):
i2 = i1 + 1
d[i1] = min(d[i1], d[i2] + 1)
t.stop("_propagate_1d_first_pass")
@njit("(f8[:], i4[:], f8[:], f8[:])", cache=True, nogil=True)
def _propagate_1d(d, v, z, f):
nx = len(d)
k = -1
s = -np.inf
for x in range(nx):
d_yx = d[x]
if d_yx == np.inf:
continue
fx = d_yx * d_yx
f[x] = fx
while k >= 0:
vk = v[k]
s = 0.5 * (fx + x * x - f[vk] - vk * vk) / (x - vk)
if s > z[k]:
break
k -= 1
k += 1
v[k] = x
z[k] = s
z[k + 1] = np.inf
if k < 0:
return
k = 0
for x in range(nx):
while z[k + 1] < x:
k += 1
vk = v[k]
dx = x - vk
d[x] = np.sqrt(dx * dx + f[vk])
t.stop("_propagate_1d")
@njit("(f8[:, :],)", cache=True, parallel=True, nogil=True)
def _propagate_distance(distance):
ny, nx = distance.shape
for x in prange(nx):
_propagate_1d_first_pass(distance[:, x])
v = np.zeros((ny, nx), dtype=np.int32)
z = np.zeros((ny, nx + 1))
f = np.zeros((ny, nx))
for y in prange(ny):
_propagate_1d(distance[y], v[y], z[y], f[y])
t.stop("_propagate_distance")
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True)
def boxfilter_rows_valid(src, r):
m, n = src.shape
dst = np.zeros((m, n - 2 * r))
for i in prange(m):
for j_dst in range(1):
s = 0.0
for j_src in range(j_dst, j_dst + 2 * r + 1):
s += src[i, j_src]
dst[i, j_dst] = s
for j_dst in range(1, dst.shape[1]):
j_src = j_dst - 1
s -= src[i, j_src]
j_src = j_dst + 2 * r
s += src[i, j_src]
dst[i, j_dst] = s
return dst
t.stop("boxfilter_rows_valid")
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True)
def boxfilter_rows_same(src, r):
m, n = src.shape
dst = np.zeros((m, n))
for i in prange(m):
for j_dst in range(1):
s = 0.0
for j_src in range(j_dst + r + 1):
s += src[i, j_src]
dst[i, j_dst] = s
for j_dst in range(1, r + 1):
s += src[i, j_dst + r]
dst[i, j_dst] = s
for j_dst in range(r + 1, n - r):
s -= src[i, j_dst - r - 1]
s += src[i, j_dst + r]
dst[i, j_dst] = s
for j_dst in range(n - r, n):
s -= src[i, j_dst - r - 1]
dst[i, j_dst] = s
return dst
t.stop("boxfilter_rows_same")
@njit("f8[:, :](f8[:, :], i8)", cache=True, parallel=True)
def boxfilter_rows_full(src, r):
m, n = src.shape
dst = np.zeros((m, n + 2 * r))
for i in prange(m):
for j_dst in range(1):
s = 0.0
for j_src in range(j_dst + r + 1 - r):
s += src[i, j_src]
dst[i, j_dst] = s
for j_dst in range(1, 2 * r + 1):
s += src[i, j_dst]
dst[i, j_dst] = s
for j_dst in range(2 * r + 1, dst.shape[1] - 2 * r):
s -= src[i, j_dst - r - r - 1]
s += src[i, j_dst]
dst[i, j_dst] = s
for j_dst in range(dst.shape[1] - 2 * r, dst.shape[1]):
s -= src[i, j_dst - r - r - 1]
dst[i, j_dst] = s
return dst
t.stop("boxfilter_rows_full")
@njit("void(f8[:, :, :], f8, i8, f8[:, :, :], i8[:], i8[:], b1[:, :])", cache=True, nogil=True)
def _cf_laplacian(image, epsilon, r, values, indices, indptr, is_known):
h, w, d = image.shape
assert d == 3
size = 2 * r + 1
window_area = size * size
for yi in range(h):
for xi in range(w):
i = xi + yi * w
k = i * (4 * r + 1) ** 2
for yj in range(yi - 2 * r, yi + 2 * r + 1):
for xj in range(xi - 2 * r, xi + 2 * r + 1):
j = xj + yj * w
if 0 <= xj < w and 0 <= yj < h:
indices[k] = j
k += 1
indptr[i + 1] = k
# Centered and normalized window colors
c = np.zeros((2 * r + 1, 2 * r + 1, 3))
# For each pixel of image
for y in range(r, h - r):
for x in range(r, w - r):
if np.all(is_known[y - r : y + r + 1, x - r : x + r + 1]):
continue
# For each color channel
for dc in range(3):
# Calculate sum of color channel in window
s = 0.0
for dy in range(size):
for dx in range(size):
s += image[y + dy - r, x + dx - r, dc]
# Calculate centered window color
for dy in range(2 * r + 1):
for dx in range(2 * r + 1):
c[dy, dx, dc] = (
image[y + dy - r, x + dx - r, dc] - s / window_area
)
# Calculate covariance matrix over color channels with epsilon regularization
a00 = epsilon
a01 = 0.0
a02 = 0.0
a11 = epsilon
a12 = 0.0
a22 = epsilon
for dy in range(size):
for dx in range(size):
a00 += c[dy, dx, 0] * c[dy, dx, 0]
a01 += c[dy, dx, 0] * c[dy, dx, 1]
a02 += c[dy, dx, 0] * c[dy, dx, 2]
a11 += c[dy, dx, 1] * c[dy, dx, 1]
a12 += c[dy, dx, 1] * c[dy, dx, 2]
a22 += c[dy, dx, 2] * c[dy, dx, 2]
a00 /= window_area
a01 /= window_area
a02 /= window_area
a11 /= window_area
a12 /= window_area
a22 /= window_area
det = (
a00 * a12 * a12
+ a01 * a01 * a22
+ a02 * a02 * a11
- a00 * a11 * a22
- 2 * a01 * a02 * a12
)
inv_det = 1.0 / det
# Calculate inverse covariance matrix
m00 = (a12 * a12 - a11 * a22) * inv_det
m01 = (a01 * a22 - a02 * a12) * inv_det
m02 = (a02 * a11 - a01 * a12) * inv_det
m11 = (a02 * a02 - a00 * a22) * inv_det
m12 = (a00 * a12 - a01 * a02) * inv_det
m22 = (a01 * a01 - a00 * a11) * inv_det
# For each pair ((xi, yi), (xj, yj)) in a (2 r + 1)x(2 r + 1) window
for dyi in range(2 * r + 1):
for dxi in range(2 * r + 1):
s = c[dyi, dxi, 0]
t = c[dyi, dxi, 1]
u = c[dyi, dxi, 2]
c0 = m00 * s + m01 * t + m02 * u
c1 = m01 * s + m11 * t + m12 * u
c2 = m02 * s + m12 * t + m22 * u
for dyj in range(2 * r + 1):
for dxj in range(2 * r + 1):
xi = x + dxi - r
yi = y + dyi - r
xj = x + dxj - r
yj = y + dyj - r
i = xi + yi * w
j = xj + yj * w
# Calculate contribution of pixel pair to L_ij
temp = (
c0 * c[dyj, dxj, 0]
+ c1 * c[dyj, dxj, 1]
+ c2 * c[dyj, dxj, 2]
)
value = (1.0 if (i == j) else 0.0) - (
1 + temp
) / window_area
dx = xj - xi + 2 * r
dy = yj - yi + 2 * r
values[i, dy, dx] += value
t.stop("_cf_laplacian")
@njit("void(f4[:, :, :], f4[:, :, :])", cache=True, nogil=True, parallel=True)
def _resize_nearest_multichannel(dst, src):
"""
Internal method.
Resize image src to dst using nearest neighbors filtering.
Images must have multiple color channels, i.e. :code:`len(shape) == 3`.
Parameters
----------
dst: numpy.ndarray of type np.float32
output image
src: numpy.ndarray of type np.float32
input image
"""
h_src, w_src, depth = src.shape
h_dst, w_dst, depth = dst.shape
for y_dst in prange(h_dst):
for x_dst in range(w_dst):
x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
for c in range(depth):
dst[y_dst, x_dst, c] = src[y_src, x_src, c]
t.stop("_resize_nearest_multichannel")
@njit("void(f4[:, :], f4[:, :])", cache=True, nogil=True, parallel=True)
def _resize_nearest(dst, src):
"""
Internal method.
Resize image src to dst using nearest neighbors filtering.
    Images must be grayscale, i.e. :code:`len(shape) == 2`.
Parameters
----------
dst: numpy.ndarray of type np.float32
output image
src: numpy.ndarray of type np.float32
input image
"""
h_src, w_src = src.shape
h_dst, w_dst = dst.shape
for y_dst in prange(h_dst):
for x_dst in range(w_dst):
x_src = max(0, min(w_src - 1, x_dst * w_src // w_dst))
y_src = max(0, min(h_src - 1, y_dst * h_src // h_dst))
dst[y_dst, x_dst] = src[y_src, x_src]
t.stop("_resize_nearest")
@njit("Tuple((f4[:, :, :], f4[:, :, :]))(f4[:, :, :], f4[:, :], f4, i4, i4, i4, f4)", cache=True, nogil=True)
def _estimate_fb_ml(
input_image,
input_alpha,
regularization,
n_small_iterations,
n_big_iterations,
small_size,
gradient_weight,
):
h0, w0, depth = input_image.shape
dtype = np.float32
w_prev = 1
h_prev = 1
# Compute average foreground and background color
F_mean_color = np.zeros(depth, dtype=dtype)
B_mean_color = np.zeros(depth, dtype=dtype)
F_count = 0
B_count = 0
for y in range(h0):
for x in range(w0):
if input_alpha[y, x] > 0.9:
for c in range(depth):
F_mean_color[c] += input_image[y, x, c]
F_count += 1
if input_alpha[y, x] < 0.1:
for c in range(depth):
B_mean_color[c] += input_image[y, x, c]
B_count += 1
F_mean_color /= F_count + 1e-5
B_mean_color /= B_count + 1e-5
    # Initialize foreground and background with the average color
F_prev = np.zeros((h_prev, w_prev, depth), dtype=dtype) + F_mean_color
B_prev = np.zeros((h_prev, w_prev, depth), dtype=dtype) + B_mean_color
n_levels = int(np.ceil(np.log2(max(w0, h0))))
for i_level in range(n_levels + 1):
w = round(w0 ** (i_level / n_levels))
h = round(h0 ** (i_level / n_levels))
image = np.empty((h, w, depth), dtype=dtype)
alpha = np.empty((h, w), dtype=dtype)
_resize_nearest_multichannel(image, input_image)
_resize_nearest(alpha, input_alpha)
F = np.empty((h, w, depth), dtype=dtype)
B = np.empty((h, w, depth), dtype=dtype)
_resize_nearest_multichannel(F, F_prev)
_resize_nearest_multichannel(B, B_prev)
if w <= small_size and h <= small_size:
n_iter = n_small_iterations
else:
n_iter = n_big_iterations
b = np.zeros((2, depth), dtype=dtype)
dx = [-1, 1, 0, 0]
dy = [0, 0, -1, 1]
for i_iter in range(n_iter):
for y in prange(h):
for x in range(w):
a0 = alpha[y, x]
a1 = 1.0 - a0
a00 = a0 * a0
a01 = a0 * a1
# a10 = a01 can be omitted due to symmetry of matrix
a11 = a1 * a1
for c in range(depth):
b[0, c] = a0 * image[y, x, c]
b[1, c] = a1 * image[y, x, c]
for d in range(4):
x2 = max(0, min(w - 1, x + dx[d]))
y2 = max(0, min(h - 1, y + dy[d]))
gradient = abs(a0 - alpha[y2, x2])
da = regularization + gradient_weight * gradient
a00 += da
a11 += da
for c in range(depth):
b[0, c] += da * F[y2, x2, c]
b[1, c] += da * B[y2, x2, c]
determinant = a00 * a11 - a01 * a01
inv_det = 1.0 / determinant
b00 = inv_det * a11
b01 = inv_det * -a01
b11 = inv_det * a00
for c in range(depth):
F_c = b00 * b[0, c] + b01 * b[1, c]
B_c = b01 * b[0, c] + b11 * b[1, c]
F_c = max(0.0, min(1.0, F_c))
B_c = max(0.0, min(1.0, B_c))
F[y, x, c] = F_c
B[y, x, c] = B_c
F_prev = F
B_prev = B
w_prev = w
h_prev = h
return F, B
t.stop("_estimate_fb_ml")
@njit("f4(f4[::1], f4[::1], f4[::1])", cache=True, nogil=True)
def estimate_alpha(I, F, B):
fb0 = F[0] - B[0]
fb1 = F[1] - B[1]
fb2 = F[2] - B[2]
ib0 = I[0] - B[0]
ib1 = I[1] - B[1]
ib2 = I[2] - B[2]
denom = fb0 * fb0 + fb1 * fb1 + fb2 * fb2 + 1e-5
alpha = (ib0 * fb0 + ib1 * fb1 + ib2 * fb2) / denom
alpha = max(0.0, min(1.0, alpha))
return alpha
t.stop("estimate_alpha")
@njit("f4(f4[::1], f4[::1])", cache=True, nogil=True)
def inner(a, b):
s = 0.0
for i in range(len(a)):
s += a[i] * b[i]
return s
t.stop("inner")
@njit("f4(f4[::1], f4[::1], f4[::1])", cache=True, nogil=True)
def Mp2(I, F, B):
a = estimate_alpha(I, F, B)
d0 = a * F[0] + (1.0 - a) * B[0] - I[0]
d1 = a * F[1] + (1.0 - a) * B[1] - I[1]
d2 = a * F[2] + (1.0 - a) * B[2] - I[2]
return d0 * d0 + d1 * d1 + d2 * d2
t.stop("Mp2")
@njit("f4(f4[:, :, ::1], i8, i8, f4[::1], f4[::1], i8)", cache=True, nogil=True)
def Np(image, x, y, F, B, r):
h, w = image.shape[:2]
result = 0.0
for y2 in range(y - r, y + r + 1):
y2 = max(0, min(h - 1, y2))
for x2 in range(x - r, x + r + 1):
x2 = max(0, min(w - 1, x2))
result += Mp2(image[y2, x2], F, B)
return result
t.stop("Np")
@njit("f4(f4[:, :, ::1], i8, i8, i8, i8)", cache=True, nogil=True)
def Ep(image, px, py, sx, sy):
result = 0.0
spx = sx - px
spy = sy - py
d = np.hypot(spx, spy)
if d == 0.0: return 0.0
num_steps = int(np.ceil(d))
num_steps = max(1, min(10, num_steps))
step_x = spx / num_steps
step_y = spy / num_steps
h, w = image.shape[:2]
for i in range(num_steps + 1):
qx = int(px + float(i) * step_x)
        qy = int(py + float(i) * step_y)
q_l = max(0, min(w - 1, qx - 1))
q_r = max(0, min(w - 1, qx + 1))
q_u = max(0, min(h - 1, qy + 1))
q_d = max(0, min(h - 1, qy - 1))
qx = max(0, min(w - 1, qx))
qy = max(0, min(h - 1, qy))
Ix0 = 0.5 * (image[qy, q_r, 0] - image[qy, q_l, 0])
Ix1 = 0.5 * (image[qy, q_r, 1] - image[qy, q_l, 1])
Ix2 = 0.5 * (image[qy, q_r, 2] - image[qy, q_l, 2])
Iy0 = 0.5 * (image[q_u, qx, 0] - image[q_d, qx, 0])
Iy1 = 0.5 * (image[q_u, qx, 1] - image[q_d, qx, 1])
Iy2 = 0.5 * (image[q_u, qx, 2] - image[q_d, qx, 2])
v0 = step_x * Ix0 + step_y * Iy0
v1 = step_x * Ix1 + step_y * Iy1
v2 = step_x * Ix2 + step_y * Iy2
result += np.sqrt(v0 * v0 + v1 * v1 + v2 * v2)
return result
t.stop("Ep")
@njit("f4(f4[::1], f4[::1])", cache=True, nogil=True)
def dist(a, b):
d2 = 0.0
for i in range(a.shape[0]):
d2 += (a[i] - b[i]) ** 2
return np.sqrt(d2)
t.stop("dist")
@njit("f4(f4[::1])", cache=True, nogil=True)
def length(a):
return np.sqrt(inner(a, a))
t.stop("length")
@njit("void(f4[:, ::1], f4[:, ::1], f4[:, :, ::1], i8, f4)", cache=True, parallel=True, nogil=True)
def expand_trimap(expanded_trimap, trimap, image, k_i, k_c):
# NB: Description in paper does not match published test images.
# The radius appears to be larger and expanded trimap is sparser.
h, w = trimap.shape
for y, x in pndindex((h, w)):
if trimap[y, x] == 0 or trimap[y, x] == 1: continue
closest = np.inf
for y2 in range(y - k_i, y + k_i + 1):
for x2 in range(x - k_i, x + k_i + 1):
if x2 < 0 or x2 >= w or y2 < 0 or y2 >= h: continue
if trimap[y2, x2] != 0 and trimap[y2, x2] != 1: continue
dr = image[y, x, 0] - image[y2, x2, 0]
dg = image[y, x, 1] - image[y2, x2, 1]
db = image[y, x, 2] - image[y2, x2, 2]
color_distance = np.sqrt(dr * dr + dg * dg + db * db)
spatial_distance = np.hypot(x - x2, y - y2)
if color_distance > k_c: continue
if spatial_distance > k_i: continue
if spatial_distance < closest:
closest = spatial_distance
expanded_trimap[y, x] = trimap[y2, x2]
t.stop("expand_trimap")
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, ::1], i8, f4, f4, f4, f4, i8)", cache=True, parallel=True, nogil=True)
def sample_gathering(
gathering_F,
gathering_B,
gathering_alpha,
image,
trimap,
num_angles,
eN,
eA,
ef,
eb,
Np_radius,
):
h, w = trimap.shape
max_steps = 2 * max(w, h)
for y, x in pndindex((h, w)):
fg_samples = np.zeros((num_angles, 3), dtype=np.float32)
fg_samples_xy = np.zeros((num_angles, 2), dtype=np.int32)
bg_samples = np.zeros((num_angles, 3), dtype=np.float32)
bg_samples_xy = np.zeros((num_angles, 2), dtype=np.int32)
C_p = image[y, x]
gathering_alpha[y, x] = trimap[y, x]
if trimap[y, x] == 0:
gathering_B[y, x] = C_p
continue
if trimap[y, x] == 1:
gathering_F[y, x] = C_p
continue
# Fixed start angles in 8-by-8 grid for reproducible tests
n = 8
i = (x % n) + (y % n) * n
# Shuffle (99991 is a prime number)
i = (i * 99991) % (n * n)
start_angle = 2.0 * np.pi / (n * n) * i
num_fg_samples = 0
num_bg_samples = 0
for i in range(num_angles):
angle = 2.0 * np.pi / num_angles * i + start_angle
c = np.cos(angle)
s = np.sin(angle)
has_fg = False
has_bg = False
for step in range(max_steps):
if has_fg and has_bg: break
x2 = int(x + step * c)
y2 = int(y + step * s)
if x2 < 0 or y2 < 0 or x2 >= w or y2 >= h: break
if not has_fg and trimap[y2, x2] == 1:
fg_samples[num_fg_samples] = image[y2, x2]
fg_samples_xy[num_fg_samples, 0] = x2
fg_samples_xy[num_fg_samples, 1] = y2
num_fg_samples += 1
has_fg = True
if not has_bg and trimap[y2, x2] == 0:
bg_samples[num_bg_samples] = image[y2, x2]
bg_samples_xy[num_bg_samples, 0] = x2
bg_samples_xy[num_bg_samples, 1] = y2
num_bg_samples += 1
has_bg = True
if num_fg_samples == 0:
fg_samples[num_fg_samples] = gathering_F[y, x]
fg_samples_xy[num_fg_samples, 0] = x
fg_samples_xy[num_fg_samples, 1] = y
num_fg_samples += 1
if num_bg_samples == 0:
bg_samples[num_bg_samples] = gathering_B[y, x]
bg_samples_xy[num_bg_samples, 0] = x
bg_samples_xy[num_bg_samples, 1] = y
num_bg_samples += 1
min_Ep_f = np.inf
min_Ep_b = np.inf
for i in range(num_fg_samples):
Ep_f = Ep(image, x, y, fg_samples_xy[i, 0], fg_samples_xy[i, 1])
min_Ep_f = min(min_Ep_f, Ep_f)
for j in range(num_bg_samples):
Ep_b = Ep(image, x, y, bg_samples_xy[j, 0], bg_samples_xy[j, 1])
min_Ep_b = min(min_Ep_b, Ep_b)
PF_p = min_Ep_b / (min_Ep_f + min_Ep_b + 1e-5)
min_cost = np.inf
# Find best foreground/background pair
for i in range(num_fg_samples):
for j in range(num_bg_samples):
F = fg_samples[i]
B = bg_samples[j]
alpha_p = estimate_alpha(C_p, F, B)
Ap = PF_p + (1.0 - 2.0 * PF_p) * alpha_p
Dp_f = np.hypot(x - fg_samples_xy[i, 0], y - fg_samples_xy[i, 1])
Dp_b = np.hypot(x - bg_samples_xy[j, 0], y - bg_samples_xy[j, 1])
g_p = (
Np(image, x, y, F, B, Np_radius)**eN *
Ap**eA *
Dp_f**ef *
Dp_b**eb)
if min_cost > g_p:
min_cost = g_p
gathering_alpha[y, x] = alpha_p
gathering_F[y, x] = F
gathering_B[y, x] = B
t.stop("sample_gathering")
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], i8)", cache=True, parallel=False, nogil=True)
def sample_refinement(
refined_F,
refined_B,
refined_alpha,
gathering_F,
gathering_B,
image,
trimap,
radius,
):
h, w = trimap.shape
refined_F[:] = gathering_F
refined_B[:] = gathering_B
refined_alpha[:] = trimap
for y, x in pndindex((h, w)):
C_p = image[y, x]
if trimap[y, x] == 0 or trimap[y, x] == 1:
continue
max_samples = 3
sample_F = np.zeros((max_samples, 3), dtype=np.float32)
sample_B = np.zeros((max_samples, 3), dtype=np.float32)
sample_cost = np.zeros(max_samples, dtype=np.float32)
sample_cost[:] = np.inf
for dy in range(-radius, radius + 1):
for dx in range(-radius, radius + 1):
x2 = x + dx
y2 = y + dy
if 0 <= x2 < w and 0 <= y2 < h:
F_q = gathering_F[y2, x2]
B_q = gathering_B[y2, x2]
cost = Mp2(C_p, F_q, B_q)
i = np.argmax(sample_cost)
if cost < sample_cost[i]:
sample_cost[i] = cost
sample_F[i] = F_q
sample_B[i] = B_q
F_mean = sample_F.sum(axis=0) / max_samples
B_mean = sample_B.sum(axis=0) / max_samples
refined_F[y, x] = F_mean
refined_B[y, x] = B_mean
refined_alpha[y, x] = estimate_alpha(C_p, F_mean, B_mean)
t.stop("sample_refinement")
@njit("void(f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, :, ::1], f4[:, ::1], f4[:, :, ::1], f4[:, ::1], i8, i8, i8, f4, f4, f4)", cache=True, parallel=True, nogil=True)
def local_smoothing(
final_F,
final_B,
final_alpha,
refined_F,
refined_B,
refined_alpha,
image,
trimap,
radius1,
radius2,
radius3,
sigma_sq1,
sigma_sq2,
sigma_sq3,
):
h, w = trimap.shape
final_confidence = np.zeros((h, w), dtype=np.float32)
W_FB = np.zeros((h, w), dtype=np.float32)
low_frequency_alpha = np.zeros((h, w), dtype=np.float32)
final_F[:] = refined_F
final_B[:] = refined_B
for y, x in pndindex((h, w)):
C_p = image[y, x]
if trimap[y, x] == 0 or trimap[y, x] == 1:
continue
F_p = np.zeros(3, dtype=np.float32)
B_p = np.zeros(3, dtype=np.float32)
sum_F = 0.0
sum_B = 0.0
alpha_p = refined_alpha[y, x]
for dy in range(-radius1, radius1 + 1):
for dx in range(-radius1, radius1 + 1):
x2 = x + dx
y2 = y + dy
if 0 <= x2 < w and 0 <= y2 < h:
# NB: Gaussian not normalized, not using confidence
Wc_pq = np.exp(-1.0 / sigma_sq1 * (dx * dx + dy * dy))
if x != x2 or y != y2:
Wc_pq *= abs(refined_alpha[y, x] - refined_alpha[y2, x2])
alpha_q = refined_alpha[y2, x2]
W_F = Wc_pq * alpha_q
W_B = Wc_pq * (1.0 - alpha_q)
W_F = max(W_F, 1e-5)
W_B = max(W_B, 1e-5)
sum_F += W_F
sum_B += W_B
for c in range(3):
F_p[c] += W_F * refined_F[y2, x2, c]
B_p[c] += W_B * refined_B[y2, x2, c]
F_p /= sum_F
B_p /= sum_B
final_F[y, x] = F_p
final_B[y, x] = B_p
final_alpha[y, x] = estimate_alpha(C_p, F_p, B_p)
# NB: Not using confidence
W_FB[y, x] = alpha_p * (1.0 - alpha_p)
for y, x in pndindex((h, w)):
C_p = image[y, x]
if trimap[y, x] == 0 or trimap[y, x] == 1:
final_confidence[y, x] = trimap[y, x]
continue
D_FB = 0.0
weight_sum = 0.0
for dy in range(-radius2, radius2 + 1):
for dx in range(-radius2, radius2 + 1):
x2 = x + dx
y2 = y + dy
if 0 <= x2 < w and 0 <= y2 < h:
weight_sum += W_FB[y2, x2]
D_FB += W_FB[y2, x2] * dist(final_F[y2, x2], final_B[y2, x2])
D_FB /= weight_sum + 1e-5
D_FB += 1e-5
FB_dist = dist(final_F[y, x], final_B[y, x])
F_p = final_F[y, x]
B_p = final_B[y, x]
final_confidence[y, x] = min(1.0, FB_dist / D_FB) * np.exp(-1.0 / sigma_sq2 * np.sqrt(Mp2(C_p, F_p, B_p)))
for y, x in pndindex((h, w)):
if trimap[y, x] == 0 or trimap[y, x] == 1:
final_alpha[y, x] = trimap[y, x]
continue
alpha_sum = 0.0
weight_sum = 0.0
for dy in range(-radius3, radius3 + 1):
for dx in range(-radius3, radius3 + 1):
x2 = x + dx
y2 = y + dy
if 0 <= x2 < w and 0 <= y2 < h:
# NB: Gaussian not normalized, not using final_confidence(x2, y2)
D_image_squared = dx * dx + dy * dy
is_known = trimap[y2, x2] == 0 or trimap[y2, x2] == 1
W_alpha = np.exp(-1.0 / sigma_sq3 * (dx * dx + dy * dy)) + is_known
alpha_sum += W_alpha * refined_alpha[y2, x2]
weight_sum += W_alpha
low_frequency_alpha[y, x] = alpha_sum / weight_sum
final_alpha[y, x] = final_confidence[y, x] * final_alpha[y, x] + (1.0 - final_confidence[y, x]) * low_frequency_alpha[y, x]
t.stop("local_smoothing")
# Benchmark matmul to get CPU baseline performance
n = 1024
A = np.random.rand(n, n)
t.stop("rand")
for _ in range(10):
A @ A
t.stop("matmul")