Skip to content

Instantly share code, notes, and snippets.

@yberreby
Created June 21, 2025 19:51
Show Gist options
  • Save yberreby/c684f23a59d931ae9fe09415052176dd to your computer and use it in GitHub Desktop.
Save yberreby/c684f23a59d931ae9fe09415052176dd to your computer and use it in GitHub Desktop.
# %% [markdown]
# ## Diffeomorphic T → Y morph with OT loss (ott-jax)
# %%
# %matplotlib widget
# %% 0 · Imports ---------------------------------------------------------------
import numpy as np, matplotlib.pyplot as plt
from skimage.draw import line
from skimage.measure import find_contours
from scipy.ndimage import gaussian_filter
import jax, jax.numpy as jnp
import optax
from ott.geometry import pointcloud
from ott.solvers import linear
import ipywidgets as widgets
from IPython.display import display
from jaxtyping import Float, Array
# %% 1 · Draw glyphs -----------------------------------------------------------
def draw_T(img):
    """Rasterise a 'T' glyph into ``img`` (modified in place) and return it.

    The glyph is a horizontal bar on row 10 (cols 10..90) plus a vertical
    stem on col 50 (rows 10..90); covered pixels are set to 1.
    """
    segments = (
        (10, 10, 10, 90),  # top bar
        (10, 50, 90, 50),  # stem
    )
    for r_start, c_start, r_end, c_end in segments:
        rows, cols = line(r_start, c_start, r_end, c_end)
        img[rows, cols] = 1
    return img
def draw_Y(img):
    """Rasterise a 'Y' glyph into ``img`` (modified in place) and return it.

    Two arms run from (10, 20) and (10, 80) down to the junction at (50, 50),
    and a stem continues from the junction to (90, 50); covered pixels are set to 1.
    """
    segments = (
        (10, 20, 50, 50),  # left arm
        (10, 80, 50, 50),  # right arm
        (50, 50, 90, 50),  # stem
    )
    for r_start, c_start, r_end, c_end in segments:
        rows, cols = line(r_start, c_start, r_end, c_end)
        img[rows, cols] = 1
    return img
shape = (96, 96)
# Thicken the 1-px glyph strokes: blur with sigma=1.2, then threshold at 0.1.
T_bin, Y_bin = [
    gaussian_filter(glyph(np.zeros(shape)), 1.2) > 0.1
    for glyph in (draw_T, draw_Y)
]
# %% 2 · Point-cloud sampling --------------------------------------------------
def sample_pts(mask, n=1500, seed=0):
    """Return the (row, col) coordinates of ``mask``'s nonzero pixels as float32.

    When there are more than ``n`` such pixels, a subset of exactly ``n`` is
    drawn without replacement from a generator seeded with ``seed``.
    Otherwise every pixel is returned, in row-major order.
    """
    coords = np.argwhere(mask)
    if len(coords) <= n:
        return jnp.array(coords, jnp.float32)
    picked = np.random.default_rng(seed).choice(len(coords), n, replace=False)
    return jnp.array(coords[picked], jnp.float32)
# Source cloud (T glyph) and target cloud (Y glyph), in pixel coordinates.
xs, ys = sample_pts(T_bin), sample_pts(Y_bin)
# %% 3 · Bilinear sampler ------------------------------------------------------
def bilinear(img: "Float[Array, 'H W']", xy: "Float[Array, 'N 2']") -> "Float[Array, 'N']":
    """Bilinearly sample ``img`` at fractional (row, col) positions.

    Args:
        img: 2-D array indexed as ``img[row, col]``.
        xy: ``(N, 2)`` positions; column 0 is the row (y), column 1 the col (x).

    Returns:
        ``(N,)`` interpolated values. Corner taps that fall outside the grid
        contribute 0 (zero padding), so samples fade smoothly to 0 beyond
        the image border.
    """
    H, W = img.shape
    y, x = xy[:, 0], xy[:, 1]
    y0f, x0f = jnp.floor(y), jnp.floor(x)
    # Weights come from the *unclipped* fractional offsets. The previous
    # version derived them from clipped indices, so in the last row/column
    # (where y1 == y0 or x1 == x0) they summed to 0 and the sample collapsed
    # to 0 instead of the pixel value.
    fy, fx = y - y0f, x - x0f
    y0, x0 = y0f.astype(jnp.int32), x0f.astype(jnp.int32)
    y1, x1 = y0 + 1, x0 + 1

    def tap(yi, xi):
        # Clip only for a safe gather; out-of-range taps are masked to 0.
        inside = (yi >= 0) & (yi < H) & (xi >= 0) & (xi < W)
        val = img[jnp.clip(yi, 0, H - 1), jnp.clip(xi, 0, W - 1)]
        return jnp.where(inside, val, 0.0)

    return ((1 - fy) * (1 - fx) * tap(y0, x0)
            + (1 - fy) * fx * tap(y0, x1)
            + fy * (1 - fx) * tap(y1, x0)
            + fy * fx * tap(y1, x1))
# %% 4 · Stationary-velocity flow ---------------------------------------------
def integrate(v: "Float[Array, 'H W 2']", pts: "Float[Array, 'N 2']",
              t: float = 1.0, steps: int = 32) -> "Float[Array, 'N 2']":
    """Advect ``pts`` through the stationary velocity field ``v`` up to time ``t``.

    Uses ``steps`` forward-Euler steps of size ``t / steps``. At each step the
    velocity is bilinearly sampled at the current positions: channel 0 of
    ``v`` moves the row (y) coordinate and channel 1 moves the col (x)
    coordinate.

    ``t`` may be traced (e.g. vmapped over time). ``steps`` is a static
    argument because it sets the loop's trip count.
    """
    dt = t / steps

    def euler_step(_, coords):
        vy = bilinear(v[..., 0], coords)
        vx = bilinear(v[..., 1], coords)
        return coords + dt * jnp.stack([vy, vx], -1)

    # fori_loop with a static trip count lowers to a scan. It stays
    # reverse-differentiable and avoids unrolling `steps` copies of the body.
    return jax.lax.fori_loop(0, steps, euler_step, pts)


# `steps` must be static: under a plain @jax.jit an explicit `steps=` would
# be traced, and a data-dependent trip count cannot be compiled.
integrate = jax.jit(integrate, static_argnames=("steps",))
# %% 5 · OT loss + smoothness --------------------------------------------------
def smooth_penalty(v):
    """Return the mean squared forward difference of ``v`` along axis 0 plus that along axis 1."""
    grad_rows = v[1:] - v[:-1]
    grad_cols = v[:, 1:] - v[:, :-1]
    return jnp.square(grad_rows).mean() + jnp.square(grad_cols).mean()
def ot_loss(v, ε=0.05, λ=1e-2):
    """Return the entropic OT cost between the warped source and ``ys``, plus a smoothness term.

    The module-level source cloud ``xs`` is advected through ``v`` to t=1.
    The entropic OT problem to the target cloud ``ys`` is solved with at most
    200 iterations. ``λ * smooth_penalty(v)`` is then added.

    NOTE(review): ``ε`` is given directly to ``PointCloud``, while the
    coordinates are pixel indices (0..95). Confirm this ε is appropriate for
    that cost scale and that ``converged`` is True within 200 iterations.
    """
    warped = integrate(v, xs)
    sinkhorn = linear.solve(
        pointcloud.PointCloud(warped, ys, epsilon=ε),
        max_iterations=200,
    )
    regulariser = λ * smooth_penalty(v)
    return sinkhorn.reg_ot_cost + regulariser
# %% 6 · Optimise velocity field ----------------------------------------------
# Trainable parameters: a (*shape, 2) velocity field, i.e. one (vy, vx) per pixel.
v = jax.random.normal(jax.random.key(0), (*shape, 2), jnp.float32)
opt = optax.adam(1e-2)
opt_state = opt.init(v)


@jax.jit
def step(v, state):
    """Take one Adam step on ``ot_loss``; return (new v, new state, loss before the step)."""
    loss, grads = jax.value_and_grad(ot_loss)(v)
    updates, new_state = opt.update(grads, state)
    return optax.apply_updates(v, updates), new_state, loss


for it in range(3000):
    v, opt_state, L = step(v, opt_state)
    if it % 50 == 0:
        print(f"iter {it:3d} · loss {L:.4f}")
# %% 7 · Generate morph frames -------------------------------------------------
# Evaluate the learned flow at Ts evenly spaced times in [0, 1].
Ts = 64
tvec = jnp.linspace(0, 1, Ts)
# Batch over the time argument only; v and xs are shared across all times.
frames_pts = jax.vmap(integrate, in_axes=(None, None, 0))(v, xs, tvec)
def raster(pts):
    """Splat points onto a ``shape``-sized grid and return a binary mask.

    Positions are rounded to the nearest pixel, and those outside the grid
    are dropped. The hit map is blurred (sigma=0.8) and thresholded at 0.1
    to thicken isolated points.
    """
    grid = np.zeros(shape, np.float32)
    idx = np.round(np.asarray(pts)).astype(int)
    rows, cols = idx[:, 0], idx[:, 1]
    inside = (rows >= 0) & (rows < shape[0]) & (cols >= 0) & (cols < shape[1])
    grid[rows[inside], cols[inside]] = 1
    return gaussian_filter(grid, 0.8) > 0.1
# Rasterise each time step's point cloud and stack the masks along a new leading (time) axis.
frames = np.stack(list(map(raster, np.asarray(frames_pts))))
# %% 8 · Interactive slider ----------------------------------------------------
fig, ax = plt.subplots(figsize=(3.5, 3.5))
im = ax.imshow(frames[0], cmap='gray', vmin=0, vmax=1)
ax.axis('off')
slider = widgets.IntSlider(value=0, min=0, max=Ts - 1, description='t')


def _on_slider_change(change):
    """Show the frame selected by the slider and request a redraw."""
    im.set_data(frames[change['new']])
    fig.canvas.draw_idle()


slider.observe(_on_slider_change, names='value')
display(slider)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment