Created
June 21, 2025 19:51
-
-
Save yberreby/c684f23a59d931ae9fe09415052176dd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% [markdown]
# ## Diffeomorphic T → Y morph via a learned velocity field and an optimal-transport loss (ott-jax)
# %%
# %matplotlib widget
# %% 0 · Imports ---------------------------------------------------------------
from functools import partial

import ipywidgets as widgets
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from IPython.display import display
from jaxtyping import Array, Float
from ott.geometry import pointcloud
from ott.solvers import linear
from scipy.ndimage import gaussian_filter
from skimage.draw import line
from skimage.measure import find_contours
# %% 1 · Draw glyphs -----------------------------------------------------------
def draw_T(img):
    """Draw a 'T' glyph into *img* in place and return *img*.

    Strokes are 1-px lines from ``skimage.draw.line`` set to 1, given as
    (row, col) end points; the largest index used is 90, so *img* must be at
    least 91 x 91.
    """
    strokes = (
        (10, 10, 10, 90),  # horizontal bar along row 10
        (10, 50, 90, 50),  # vertical stem down column 50
    )
    for r0, c0, r1, c1 in strokes:
        rows, cols = line(r0, c0, r1, c1)
        img[rows, cols] = 1
    return img
def draw_Y(img):
    """Draw a 'Y' glyph into *img* in place and return *img*.

    Two arms run from row 10 down to the junction at (50, 50), followed by a
    stem from the junction to row 90; stroke pixels are set to 1.
    """
    for r0, c0, r1, c1 in (
        (10, 20, 50, 50),  # left arm
        (10, 80, 50, 50),  # right arm
        (50, 50, 90, 50),  # stem
    ):
        # line() returns a (rows, cols) tuple, usable directly as an index.
        img[line(r0, c0, r1, c1)] = 1
    return img
shape = (96, 96)
# Thicken the 1-px strokes: blur each glyph with a Gaussian (sigma = 1.2 px)
# and re-binarise at 0.1, producing boolean masks T_bin and Y_bin.
T_bin, Y_bin = (
    gaussian_filter(draw(np.zeros(shape)), 1.2) > 0.1
    for draw in (draw_T, draw_Y)
)
# %% 2 · Point-cloud sampling --------------------------------------------------
def sample_pts(mask, n=1500, seed=0):
    """Return the coordinates of the True pixels of *mask* as a float32 JAX array.

    Each row is one pixel index (row, col for a 2-D mask), in NumPy's
    row-major ``nonzero`` order. If there are more than *n* such pixels, a
    random subset of exactly *n* rows is drawn without replacement using a
    ``numpy.random.default_rng(seed)`` generator.
    """
    coords = np.argwhere(mask)
    total = len(coords)
    if total <= n:
        return jnp.array(coords, jnp.float32)
    picks = np.random.default_rng(seed).choice(total, n, replace=False)
    return jnp.array(coords[picks], jnp.float32)
# Source (T) and target (Y) point clouds: float32 pixel coordinates,
# at most 1500 points each (sample_pts's default n).
xs, ys = sample_pts(T_bin), sample_pts(Y_bin)
# %% 3 · Bilinear sampler ------------------------------------------------------
def bilinear(img: "Float[Array, 'H W']", xy: "Float[Array, 'N 2']") -> "Float[Array, 'N']":
    """Bilinearly sample a 2-D image at fractional positions.

    Args:
        img: (H, W) image / scalar field defined on the pixel grid.
        xy: (N, 2) query points; column 0 is the row coordinate (y) and
            column 1 the column coordinate (x), in pixel units.

    Returns:
        (N,) interpolated values.

    Positions outside ``[0, H-1] x [0, W-1]`` are clamped to the border
    (edge values are replicated). The sampler is therefore continuous
    everywhere, including on the last row and column.
    """
    H, W = img.shape
    y, x = xy[:, 0], xy[:, 1]
    y_floor, x_floor = jnp.floor(y), jnp.floor(x)
    # Interpolation weights come from the *unclipped* floor; only the gather
    # indices are clipped. Deriving the weights from clipped indices would make
    # them cancel to 0 wherever the two clipped indices coincide (last
    # row/column, or anywhere outside the grid).
    fy, fx = y - y_floor, x - x_floor
    y0i, x0i = y_floor.astype(jnp.int32), x_floor.astype(jnp.int32)
    y0, y1 = jnp.clip(y0i, 0, H - 1), jnp.clip(y0i + 1, 0, H - 1)
    x0, x1 = jnp.clip(x0i, 0, W - 1), jnp.clip(x0i + 1, 0, W - 1)
    top = (1 - fx) * img[y0, x0] + fx * img[y0, x1]
    bottom = (1 - fx) * img[y1, x0] + fx * img[y1, x1]
    return (1 - fy) * top + fy * bottom
# %% 4 · Stationary-velocity flow ---------------------------------------------
@partial(jax.jit, static_argnames="steps")
def integrate(v: "Float[Array, 'H W 2']", pts: "Float[Array, 'N 2']",
              t: float = 1.0, steps: int = 32) -> "Float[Array, 'N 2']":
    """Advect *pts* through the stationary velocity field *v* up to time *t*.

    Uses forward Euler with ``steps`` equal sub-steps of size ``t / steps``.
    The velocity at off-grid positions is read with ``bilinear``.

    Args:
        v: velocity field sampled on the pixel grid; ``v[..., 0]`` moves the
            row coordinate and ``v[..., 1]`` the column coordinate.
        pts: (N, 2) start points in (row, col) pixel coordinates.
        t: total integration time; may be a traced value (e.g. under vmap).
        steps: number of Euler sub-steps. It is static under jit because it
            fixes the loop trip count, so each distinct value recompiles.

    Returns:
        The (N, 2) advected points.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")
    dt = t / steps

    def euler_step(_, coords):
        # Sample both velocity components at the current positions, then take
        # one explicit-Euler step.
        vel = jnp.stack(
            [bilinear(v[..., 0], coords), bilinear(v[..., 1], coords)], axis=-1
        )
        return coords + dt * vel

    # The loop carry must keep one dtype; promote pts the same way the update
    # `coords + dt * vel` would.
    start = jnp.asarray(pts, dtype=jnp.result_type(pts, v))
    # The bounds are static (steps is static under jit), so fori_loop lowers to
    # a scan. That keeps it reverse-mode differentiable for jax.grad in the
    # optimiser, while tracing the body once instead of unrolling `steps` copies.
    return jax.lax.fori_loop(0, steps, euler_step, start)
# %% 5 · OT loss + smoothness --------------------------------------------------
def smooth_penalty(v):
    """Return a first-order smoothness penalty for the field *v*.

    The penalty is the mean squared forward difference along axis 0 plus the
    mean squared forward difference along axis 1, taken over all remaining
    axes (e.g. both velocity channels).
    """
    grad_rows = v[1:] - v[:-1]
    grad_cols = v[:, 1:] - v[:, :-1]
    return (grad_rows ** 2).mean() + (grad_cols ** 2).mean()
def ot_loss(v, ε=0.05, λ=1e-2):
    """Objective for the velocity field *v*: an OT fit of the warped T to the Y, plus smoothness.

    The module-level source cloud ``xs`` is advected to t = 1 with
    ``integrate``. ott-jax's linear OT solver is then run between the warped
    cloud and the target cloud ``ys``, with entropic regularisation ``ε`` and
    at most 200 iterations. The result is the solver's ``reg_ot_cost`` plus
    ``λ * smooth_penalty(v)``.

    NOTE(review): the point coordinates are in pixels (up to ~96), so the
    squared-Euclidean costs are in the thousands. If ott-jax treats an
    explicitly passed ``epsilon`` as absolute (confirm for the pinned
    version), ε = 0.05 is tiny at that scale and 200 iterations may not
    converge. Consider normalising coordinates or using a relative epsilon.
    """
    warped = integrate(v, xs)
    geometry = pointcloud.PointCloud(warped, ys, epsilon=ε)
    fit_term = linear.solve(geometry, max_iterations=200).reg_ot_cost
    return fit_term + λ * smooth_penalty(v)
# %% 6 · Optimise velocity field ----------------------------------------------
# Parameters: one 2-vector per pixel (row / column velocity, as consumed by
# `integrate`), initialised with standard-normal noise from a fixed PRNG key.
v = jax.random.normal(jax.random.key(0), shape + (2,), dtype=jnp.float32)
opt = optax.adam(1e-2)
opt_state = opt.init(v)
@jax.jit
def step(v, state):
    """Take one optimiser step on the velocity field *v*.

    Returns ``(new_v, new_state, loss)``. ``loss`` is ``ot_loss`` evaluated at
    the *incoming* ``v``, i.e. before the update is applied.
    """
    loss, grad = jax.value_and_grad(ot_loss)(v)
    updates, next_state = opt.update(grad, state)
    next_v = optax.apply_updates(v, updates)
    return next_v, next_state, loss
for it in range(3000):
    v, opt_state, L = step(v, opt_state)
    # Log every 50th iteration; L is the loss before this iteration's update.
    if not it % 50:
        print(f"iter {it:3d} · loss {L:.4f}")
# %% 7 · Generate morph frames -------------------------------------------------
Ts = 64
tvec = jnp.linspace(0, 1, Ts)
# Advect the T cloud with the learned field to each of the Ts evenly spaced
# times in [0, 1]; vmap stacks the (N, 2) results into (Ts, N, 2).
frames_pts = jax.vmap(integrate, in_axes=(None, None, 0))(v, xs, tvec)
def raster(pts, *, out_shape=None, sigma=0.8, thresh=0.1):
    """Splat points onto a pixel grid and return a thickened boolean mask.

    Each point is rounded to its nearest pixel. Points that fall outside the
    grid, or that have non-finite coordinates, are dropped. The resulting
    1-px occupancy image is blurred with a Gaussian of width ``sigma`` pixels
    and re-binarised at ``thresh``, so sparse samples render as a connected
    glyph.

    Args:
        pts: (N, 2) array-like of (row, col) coordinates (NumPy or JAX);
            an empty input yields an empty mask.
        out_shape: (H, W) of the output; defaults to the module-level ``shape``.
        sigma: Gaussian blur width in pixels.
        thresh: threshold applied after blurring.

    Returns:
        Boolean array of shape ``out_shape``.
    """
    H, W = shape if out_shape is None else out_shape
    occupancy = np.zeros((H, W), np.float32)
    xy = np.asarray(pts, dtype=float).reshape(-1, 2)
    # Casting NaN/inf to int is undefined, so drop such points before rounding.
    xy = xy[np.isfinite(xy).all(axis=1)]
    ij = np.round(xy).astype(int)
    r, c = ij[:, 0], ij[:, 1]
    keep = (0 <= r) & (r < H) & (0 <= c) & (c < W)
    occupancy[r[keep], c[keep]] = 1
    return gaussian_filter(occupancy, sigma) > thresh
# Rasterise every time slice into a boolean frame, giving a (Ts, H, W) stack.
frames = np.stack(list(map(raster, np.array(frames_pts))))
# %% 8 · Interactive slider ----------------------------------------------------
fig, ax = plt.subplots(figsize=(3.5, 3.5))
im = ax.imshow(frames[0], cmap='gray', vmin=0, vmax=1)
ax.axis('off')
slider = widgets.IntSlider(value=0, min=0, max=Ts - 1, description='t')


def _upd(change):
    """Show the frame selected by the slider and request a canvas redraw."""
    im.set_data(frames[change['new']])
    fig.canvas.draw_idle()


slider.observe(_upd, names='value')
display(slider)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment