Skip to content

Instantly share code, notes, and snippets.

@yberreby
Created June 21, 2025 19:49
Show Gist options
  • Save yberreby/bf1c768537c9b0c5aef7c163cfd4f793 to your computer and use it in GitHub Desktop.
Save yberreby/bf1c768537c9b0c5aef7c163cfd4f793 to your computer and use it in GitHub Desktop.
# %%
# %matplotlib widget
# %%
# --- Imports ---------------------------------------------------------------
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import diags
from skimage.measure import find_contours
from scipy.ndimage import gaussian_filter
# %%
# --- 1. Prepare example glyph boundaries (replace with your data) -----------
def binarize(img):
    """Smooth ``img`` and threshold it into a float32 {0, 1} mask.

    A Gaussian blur (sigma 1.2) is applied first so thin strokes thicken
    into solid regions; pixels whose blurred value exceeds 0.1 become 1.0,
    all others 0.0.
    """
    smoothed = gaussian_filter(img, 1.2)
    foreground = smoothed > 0.1
    return foreground.astype(np.float32)
# %%
def get_curve(img, level=0.5):
    """Return the longest iso-contour of ``img`` as an (M, 2) float32 array.

    Points are (row, col) == (y, x) image coordinates in traversal order.
    Per the scikit-image docs, a contour that does not touch the image
    border is closed (its last point repeats the first).

    Parameters
    ----------
    img : 2-D array
        Image to contour, e.g. the mask produced by ``binarize``.
    level : float, default 0.5
        Iso-value passed to ``find_contours``; 0.5 lies halfway between
        the background (0) and foreground (1) of a binary mask.

    Raises
    ------
    ValueError
        If ``img`` has no contour at ``level`` (e.g. an all-blank image).
    """
    contours = find_contours(img, level)
    if not contours:
        # max() on an empty list would raise a far less helpful ValueError.
        raise ValueError(f"no contour found at level={level}")
    longest = max(contours, key=len)
    return jnp.array(longest, jnp.float32)
# %%
def make_T(shape=(96,96)):
    """Render a synthetic "T" glyph and return its smoothed binary mask.

    The stroke rectangles are fixed pixel positions laid out for the
    default 96 x 96 canvas.
    """
    canvas = np.zeros(shape, np.float32)
    crossbar = (slice(10, 14), slice(10, 86))
    stem = (slice(14, 86), slice(45, 51))
    for region in (crossbar, stem):
        canvas[region] = 1
    return binarize(canvas)
# %%
def make_Y(shape=(96,96)):
    """Render a synthetic "Y" glyph and return its smoothed binary mask.

    Each stroke is rasterised as 60 evenly spaced points (truncated to
    integer pixel indices) between its two (row, col) endpoints; the
    coordinates are laid out for the default 96 x 96 canvas.
    """
    strokes = (
        ((10, 20), (50, 48)),  # left arm
        ((10, 75), (50, 48)),  # right arm
        ((50, 48), (86, 48)),  # stem
    )
    canvas = np.zeros(shape, np.float32)
    for (r0, c0), (r1, c1) in strokes:
        rows = np.linspace(r0, r1, 60).astype(int)
        cols = np.linspace(c0, c1, 60).astype(int)
        canvas[rows, cols] = 1
    return binarize(canvas)
# %%
# Raw boundary contours of the two example glyphs, as (row, col) points.
curve_A, curve_B = get_curve(make_T()), get_curve(make_Y())
# %%
# --- 2. Reparametrize: Sample N points along arc length ---------------------
def uniform_sample(curve, N=256):
    """Resample a closed curve at N points equally spaced in arc length.

    Parameters
    ----------
    curve : array-like, shape (M, 2)
        Ordered (row, col) points. The curve is treated as closed: if the
        last point differs from the first, the closing segment back to the
        first point is included in the arc length.
    N : int
        Number of output samples.

    Returns
    -------
    jax.Array, shape (N, 2)
        Points spaced ``total_length / N`` apart along the closed curve,
        starting at ``curve[0]``; the wrap-around end point is not repeated.

    Raises
    ------
    ValueError
        If ``curve`` contains no points.
    """
    pts = np.asarray(curve)
    if len(pts) == 0:
        raise ValueError("curve must contain at least one point")
    # Samples are meant to cover a closed (periodic) loop. Close an open
    # polyline so its final gap is sampled like every other segment instead
    # of being skipped. Already-closed curves are left untouched.
    if not np.array_equal(pts[0], pts[-1]):
        pts = np.concatenate([pts, pts[:1]], axis=0)
    seg_lengths = np.sqrt(np.sum(np.diff(pts, axis=0) ** 2, axis=1))
    cumulative = np.concatenate([[0], np.cumsum(seg_lengths)])
    # endpoint=False: the point at total length coincides with pts[0].
    samples = np.linspace(0, cumulative[-1], N, endpoint=False)
    rows = np.interp(samples, cumulative, pts[:, 0])
    cols = np.interp(samples, cumulative, pts[:, 1])
    return jnp.stack([rows, cols], axis=1)
# Resample both boundaries to the same number of arc-length-uniform points so
# both can be expanded in a single N-node ring basis.
N = 256
curve_A, curve_B = uniform_sample(curve_A, N), uniform_sample(curve_B, N)
# %%
# --- 3. Build Graph Laplacian for a closed chain ---------------------------
def laplacian_matrix(N):
    """Return the dense graph Laplacian of an N-node ring (closed chain).

    ``L @ x`` is the periodic second difference
    ``2*x[i] - x[i-1] - x[i+1]`` with indices taken mod N, so the matrix is
    symmetric and every row sums to zero.

    Parameters
    ----------
    N : int
        Number of nodes on the ring.

    Returns
    -------
    jax.Array, shape (N, N), float32
    """
    eye = np.eye(N)
    # Rolling the identity's columns gives the wrap-around neighbour
    # operators directly, which stays correct for tiny rings. For N = 2 both
    # neighbours are the same node (off-diagonal -2). For N = 1 the operator
    # is all zeros. Patching the corners of a tridiagonal matrix gets both
    # cases wrong.
    L = 2 * eye - np.roll(eye, 1, axis=1) - np.roll(eye, -1, axis=1)
    return jnp.array(L, jnp.float32)
# %%
# Single ring Laplacian shared by both curves, which were resampled to N points above.
L = laplacian_matrix(N)
# %%
# --- 4. Compute first k Laplacian eigenfunctions ---------------------------
def spectral_basis(L, k=16):
    """Return the k eigenvectors of symmetric ``L`` with the smallest eigenvalues.

    Parameters
    ----------
    L : array-like, shape (N, N)
        Symmetric matrix (here the ring Laplacian); converted to NumPy for
        ``np.linalg.eigh``.
    k : int
        Number of columns to keep; values above N yield all N columns.

    Returns
    -------
    jax.Array, shape (N, min(k, N))
        Orthonormal eigenvectors as columns, in ascending eigenvalue order.

    NOTE(review): the ring Laplacian's nonzero eigenvalues come in degenerate
    (cos, sin) pairs, so an even k keeps only one member of the last pair.
    Prefer an odd k.
    """
    eigvals, eigvecs = np.linalg.eigh(np.array(L))
    ascending = np.argsort(eigvals)
    return jnp.array(np.take(eigvecs, ascending[:k], axis=1))  # shape N x k
# %%
# Ring-Laplacian eigenvalues 2 - 2*cos(2*pi*j/N) are simple for j = 0 and come
# in degenerate (cos, sin) pairs for j and N - j. Keeping an odd number of
# modes (1 + 2*m) keeps whole pairs only. With an even count, the last kept
# column is an arbitrary vector inside a 2-D eigenspace chosen by the solver.
k = 33
phi_A = spectral_basis(L, k)  # Both curves use same connectivity
# %%
# --- 5. Spectral coordinates for each point ---------------------------------
def spectral_embedding(curve, basis):
    """Project a curve's (y, x) coordinates onto the columns of ``basis``.

    Parameters
    ----------
    curve : array, shape (N, 2)
        Points as (row, col) == (y, x).
    basis : array, shape (N, k)
        Basis vectors as columns; assumed orthonormal (true for the
        ``eigh`` output used here), so ``basis.T @ v`` are the expansion
        coefficients of ``v``.

    Returns
    -------
    array, shape (k, 2)
        Column 0 holds the y coefficients, column 1 the x coefficients.
    """
    projector = basis.T  # k x N
    y_coords = curve[:, 0]
    x_coords = curve[:, 1]
    return jnp.stack([projector @ y_coords, projector @ x_coords], axis=1)
# %%
# Spectral coefficients (k x 2) of both curves in the shared basis phi_A.
emb_A, emb_B = (spectral_embedding(c, phi_A) for c in (curve_A, curve_B))
# %%
# --- 6. Morph between curves in spectral space ------------------------------
def morph_spectral(emb_A, emb_B, t):
    """Linearly blend two spectral embeddings: ``(1 - t) * emb_A + t * emb_B``.

    t = 0 returns ``emb_A`` and t = 1 returns ``emb_B``; values in between
    mix the coefficients elementwise.
    """
    weight_a = 1 - t
    return weight_a * emb_A + t * emb_B
# %%
def reconstruct_curve(emb, basis):
    """Map spectral coefficients back to an (N, 2) array of (y, x) points.

    ``emb`` is (k, 2) with y coefficients in column 0 and x in column 1;
    ``basis`` is (N, k). Each coordinate is rebuilt as ``basis @ coeffs``.
    """
    coords = [basis @ emb[:, col] for col in (0, 1)]  # [y, x]
    return jnp.stack(coords, axis=1)
# %%
# --- 7. Animate morph -------------------------------------------------------
# Figure for the interactive morph: one solid line that update(t) redraws,
# plus faint dashed outlines of the two resampled source curves.
fig, ax = plt.subplots(figsize=(4,4))
ax.axis("equal")
ax.axis("off")
# Empty placeholder; its data is filled in on every update.
line, = ax.plot([], [], lw=2, c='k')
# Curves are stored as (row, col): plot column 1 as x and column 0 as y.
ax.plot(curve_A[:,1], curve_A[:,0], '--', alpha=0.2)
ax.plot(curve_B[:,1], curve_B[:,0], '--', alpha=0.2)
# Image rows grow downward; flip the y-axis so the glyphs are not upside-down.
ax.invert_yaxis()
# %%
def update(t):
    """Redraw the morph line at interpolation parameter ``t``.

    t = 0 shows the truncated spectral reconstruction of curve A and t = 1
    that of curve B, both in the shared basis ``phi_A``.
    """
    coeffs = morph_spectral(emb_A, emb_B, t)
    points = reconstruct_curve(coeffs, phi_A)
    # points are (y, x); matplotlib wants x first.
    line.set_data(points[:, 1], points[:, 0])
    fig.canvas.draw_idle()
# %%
# Interactive control: a slider over t in [0, 1] that re-renders the morph.
# NOTE(review): live redraws via fig.canvas.draw_idle() presumably need an
# interactive backend such as `%matplotlib widget` (commented out at the top).
from ipywidgets import FloatSlider, interact
slider = FloatSlider(value=0, min=0, max=1., step=0.01, description='t')
def on_change(t): update(t)
# NOTE(review): interact typically calls on_change once with the slider's
# initial value, which may make the explicit update(0) below redundant — confirm.
interact(on_change, t=slider)
# Draw the t = 0 state before showing the figure.
update(0)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment