Last active
May 25, 2021 13:17
-
-
Save arquolo/2cc04195e9b1597df1f453fbd383b037 to your computer and use it in GitHub Desktop.
Forked from cheind/py-thin-plate-spline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import numpy as np | |
class TPS:
    """Thin-plate-spline (TPS) fitting and evaluation helpers.

    All point arrays use ``(x, y)`` column order: column 0 is x, column 1 is y.
    A fitted parameter matrix ``theta`` stacks the N kernel weights followed by
    the three affine coefficient rows ``[a0, a_x, a_y]``, with one column per
    output channel.
    """

    @staticmethod
    def fit(c: np.ndarray, delta: np.ndarray, lambd: float = 0., reduced: bool = False) -> np.ndarray:
        """Solve for TPS parameters that interpolate ``delta`` at control points ``c``.

        Args:
            c: (N, 2) control points as (x, y).
            delta: (N, 2) values to reproduce at each control point.
            lambd: Value added to the kernel diagonal; 0 gives exact interpolation.
            reduced: If True, drop the first kernel weight. ``TPS.z`` recovers it
                as minus the sum of the remaining weights.

        Returns:
            (N + 3, 2) float32 array ``[w_0 .. w_{N-1}, a0, a_x, a_y]``, or
            (N + 2, 2) when ``reduced`` is True.
        """
        c = np.asarray(c)
        n = len(c)
        k = TPS.ud(c, c)
        k += np.eye(n, dtype='f4') * lambd  # (N, N)
        # Affine basis rows [1, x, y] for every control point.
        p = np.ones((n, 3), dtype='f4')
        p[:, 1:] = c
        # Block system [[K, P], [P^T, 0]]: the lower rows force the kernel
        # weights to be orthogonal to the affine basis (e.g. they sum to 0).
        a = np.zeros((n + 3, n + 3), dtype='f4')
        a[:n, :n] = k
        a[:n, -3:] = p
        a[-3:, :n] = p.T
        v = np.zeros((n + 3, 2), dtype='f4')
        v[:n] = delta
        theta = np.linalg.solve(a, v)  # (N + 3, 2)
        return theta[1:] if reduced else theta

    @staticmethod
    def ud(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Pairwise TPS kernel between two point sets.

        Args:
            a: (N, 2) points.
            b: (M, 2) points.

        Returns:
            (N, M) array of ``0.5 * r^2 * log(r^2)`` (i.e. ``r^2 log r``) for
            every pair. The ``1e-12`` keeps the log finite at r = 0, where the
            product evaluates to 0.
        """
        r2 = np.square(a[:, None] - b[None, :]).sum(-1)
        return 0.5 * r2 * np.log(r2 + 1e-12)

    @staticmethod
    def ud2(dy: np.ndarray, dx: np.ndarray, dst: np.ndarray) -> np.ndarray:
        """TPS kernel between a separable (y, x) grid and the control points.

        Args:
            dy: (H, 1, 1) y coordinates of the grid rows.
            dx: (1, W, 1) x coordinates of the grid columns.
            dst: (N, 2) control points as (x, y).

        Returns:
            (H, W, N) kernel values, computed by broadcasting instead of
            materializing the full (H * W, 2) point list.
        """
        r2y = np.square(dy - dst[..., 1])  # (H, 1, N)
        r2x = np.square(dx - dst[..., 0])  # (1, W, N)
        r2 = r2y + r2x  # (H, W, N)
        return 0.5 * r2 * np.log(r2 + 1e-12)

    @staticmethod
    def z(dy: np.ndarray, dx: np.ndarray, dst: np.ndarray, theta: np.ndarray) -> np.ndarray:
        """Evaluate a fitted TPS on a separable grid.

        Args:
            dy: (H, 1, 1) y coordinates of the grid rows.
            dx: (1, W, 1) x coordinates of the grid columns.
            dst: (N, 2) control points the spline was fitted on, as (x, y).
            theta: (N + 3, 2) or reduced (N + 2, 2) output of ``TPS.fit``.

        Returns:
            (H, W, 2) spline values at every grid point.
        """
        u = TPS.ud2(dy, dx, dst)  # (H, W, N)
        w = theta[:-3]  # (N, 2), or (N - 1, 2) when reduced
        if theta.shape[0] == dst.shape[0] + 2:
            # Reduced form: the dropped first weight is minus the sum of the rest.
            w = np.concatenate((-w.sum(0, keepdims=True), w), axis=0)
        b = np.dot(u, w)  # (H, W, 2)
        a0, a_x, a_y = theta[-3:]  # each (2,)
        # fit() orders the affine basis as [1, x, y], so the second affine row
        # multiplies x (dx) and the third multiplies y (dy).
        # (2) + (2) x (1, W, 1) + (2) x (H, 1, 1) + (H, W, 2)
        return a0 + a_x * dx + a_y * dy + b
def tps_warp(src, dst, src_hw: tuple[int, int], dst_hw: tuple[int, int],
             reduced: bool = False,
             pool: int = 1) -> tuple[np.ndarray, np.ndarray]:
    """Build per-pixel source coordinate maps for a thin-plate-spline warp.

    ``src`` and ``dst`` are matching control points in normalized ``(x, y)``
    coordinates. Normalized 0 is the first pixel and 1 the last pixel along
    each axis, matching the ``np.linspace(0, 1, n)`` output grid used below.
    The spline is fitted on ``dst`` and models the displacement ``src - dst``.
    Each output pixel is therefore mapped back to a source location.

    Args:
        src: (N, 2) control points in the source image.
        dst: (N, 2) corresponding control points in the output image.
        src_hw: (height, width) of the source image.
        dst_hw: (height, width) of the output maps.
        reduced: Forwarded to ``TPS.fit``.
        pool: Evaluate the spline on a grid ``pool`` times coarser per axis,
            then upsample bicubically with OpenCV. This trades precision for
            speed. Must be >= 1.

    Returns:
        ``(mx, my)``: float32 arrays of source-pixel x and y coordinates for
        every output pixel.

    Raises:
        ValueError: If ``pool`` is smaller than 1.
    """
    if pool < 1:
        raise ValueError(f'pool must be >= 1, got {pool}')
    src = np.asarray(src)
    dst = np.asarray(dst)
    theta = TPS.fit(dst, src - dst, reduced=reduced)

    h2, w2 = dst_hw
    h2p, w2p = h2 // pool, w2 // pool  # Use lower scale for perf
    dy = np.linspace(0, 1, h2p, dtype='f4')[:, None, None]  # (H', 1, 1)
    dx = np.linspace(0, 1, w2p, dtype='f4')[None, :, None]  # (1, W', 1)
    grid = TPS.z(dy, dx, dst, theta)  # (H', W', 2) displacement src - dst

    # The spline models an offset, so add the output coordinates back.
    grid[..., 1] += dy.squeeze(-1)
    grid[..., 0] += dx.squeeze(-1)

    if pool > 1:
        # Imported lazily: OpenCV is only required for the pooled path.
        import cv2
        grid = cv2.resize(grid, (w2, h2), interpolation=cv2.INTER_CUBIC)

    # Normalized [0, 1] spans the first..last pixel, i.e. [0, size - 1] in
    # pixel units, so an identity warp yields an identity map.
    h1, w1 = src_hw
    my = (grid[..., 1] * (h1 - 1)).astype('f4')
    mx = (grid[..., 0] * (w1 - 1)).astype('f4')
    return mx, my
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment