Skip to content

Instantly share code, notes, and snippets.

@yberreby
Created June 23, 2025 22:49
Show Gist options
  • Save yberreby/4c8a8370293a7cd092bcfca3603d57d3 to your computer and use it in GitHub Desktop.
Save yberreby/4c8a8370293a7cd092bcfca3603d57d3 to your computer and use it in GitHub Desktop.
Optimization of a stationary velocity field with Adam to morph one shape into another. A JAX-based interactive notebook, shared in `jupytext` format.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# %% [markdown]
# # Direct Velocity Field Optimization
#
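# Optimize a stationary 2-D velocity field, represented directly as a grid of vectors, so that integrating it carries a circle of points onto a square.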
# %%
# %matplotlib widget
# %%
import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, value_and_grad
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from typing import Tuple, Optional
from jaxtyping import Array, Float, ScalarLike, jaxtyped
from beartype import beartype as typechecker
from functools import partial
import optax
# %%
class Config:
    """Experiment-wide settings, read as class attributes (never instantiated)."""

    # --- Source / target point clouds -------------------------------------
    N_POINTS: int = 60
    CIRCLE_RADIUS: float = 0.8
    SQUARE_SIZE: float = 1.4

    # --- Velocity-field grid ------------------------------------------------
    GRID_SIZE: int = 20      # nodes per axis (GRID_SIZE x GRID_SIZE vectors)
    GRID_EXTENT: float = 3.0  # nodes span [-GRID_EXTENT, GRID_EXTENT] on both axes

    # --- ODE integration ----------------------------------------------------
    N_INTEGRATION_STEPS: int = 20

    # --- Optimisation -------------------------------------------------------
    LEARNING_RATE: float = 0.01
    N_STEPS: int = 200
    REG_WEIGHT: float = 0.1

    # --- Plotting -----------------------------------------------------------
    VIZ_EVERY: int = 1  # redraw the live figure every VIZ_EVERY steps
    ARROW_SCALE: float = 0.3
    FIG_SIZE: Tuple[int, int] = (15, 5)
# Root PRNG key for the whole notebook; a fixed seed keeps runs reproducible.
key = jax.random.PRNGKey(seed=42)
# %% [markdown]
# ## Shape Generation
# %%
@partial(jit, static_argnames=['n_points'])
def create_circle(n_points: int, radius: ScalarLike, center: Float[Array, "2"]) -> Float[Array, "n 2"]:
    """Sample `n_points` equally spaced points on a circle.

    Angles start at 0 (the +x direction) and increase counter-clockwise; the
    angle 2*pi is excluded so the first point is not repeated.

    Args:
        n_points: number of samples (static under `jit`).
        radius: circle radius.
        center: (x, y) centre of the circle.

    Returns:
        Array of shape (n_points, 2).
    """
    angles = jnp.linspace(0, 2 * jnp.pi, n_points, endpoint=False)
    unit_circle = jnp.column_stack([jnp.cos(angles), jnp.sin(angles)])
    scaled = radius * unit_circle
    return scaled + center
@partial(jit, static_argnames=['n_points'])
def create_square(n_points: int, size: ScalarLike, center: Float[Array, "2"]) -> Float[Array, "n 2"]:
    """Sample `n_points` evenly spaced points on the perimeter of an axis-aligned square.

    Points are laid out counter-clockwise starting at the bottom-left corner:
    bottom side (left -> right), right side (bottom -> top), top side
    (right -> left), left side (top -> bottom). Consecutive points are
    `4 * size / n_points` apart along the perimeter.

    When `n_points` is a multiple of 4, every corner is sampled exactly once.
    Other values of `n_points` are also supported and always yield exactly
    `n_points` points.

    Args:
        n_points: number of samples (static under `jit`).
        size: side length of the square.
        center: (x, y) centre of the square.

    Returns:
        Array of shape (n_points, 2).
    """
    half_size = size / 2

    # Perimeter coordinate measured in side lengths, in [0, 4).
    # Writing it as 4*k / n (not k * (4/n)) keeps side boundaries exact when
    # n_points is a multiple of 4, so floor() picks the right side.
    perimeter_pos = 4.0 * jnp.arange(n_points) / n_points
    side = jnp.floor(perimeter_pos).astype(jnp.int32)
    frac = perimeter_pos - side  # fraction travelled along the current side

    # Start corner and travel direction of each side (bottom, right, top, left).
    # Each side starts exactly on its own corner, so no corner is skipped or
    # duplicated.
    starts = jnp.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]]) * half_size
    directions = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]) * size

    points = starts[side] + frac[:, None] * directions[side]
    return points + center
# Source (circle) and target (square) point clouds, both centred at the origin.
source_shape = create_circle(n_points=Config.N_POINTS, radius=Config.CIRCLE_RADIUS, center=jnp.zeros(2))
target_shape = create_square(n_points=Config.N_POINTS, size=Config.SQUARE_SIZE, center=jnp.zeros(2))
# %% [markdown]
# ## Grid-based Velocity Field
# %%
def create_velocity_grid(key: jax.Array, grid_size: int, scale: float = 0.1) -> Float[Array, "grid_size grid_size 2"]:
    """Randomly initialize a grid of 2-D velocity vectors.

    Args:
        key: JAX PRNG key, e.g. the result of ``jax.random.PRNGKey(seed)``.
            (``jax.random.PRNGKey`` itself is a function, not a type, so it
            cannot be used as the annotation.)
        grid_size: number of grid nodes per axis.
        scale: standard deviation of the i.i.d. Gaussian velocity components.

    Returns:
        Array of shape (grid_size, grid_size, 2). Elsewhere in this notebook,
        the first index is the row (y) and the second is the column (x).
    """
    return scale * jax.random.normal(key, (grid_size, grid_size, 2))
def create_position_grid(grid_size: int, extent: float) -> Tuple[Float[Array, "grid_size grid_size"], Float[Array, "grid_size grid_size"]]:
    """Return the (X, Y) node coordinates of the square velocity grid.

    Both axes use `grid_size` evenly spaced nodes spanning [-extent, extent].

    Args:
        grid_size: number of nodes per axis.
        extent: half-width of the square covered by the grid.

    Returns:
        Tuple (X, Y), each of shape (grid_size, grid_size), with
        X[i, j] = x_j and Y[i, j] = y_i.
    """
    ticks = jnp.linspace(-extent, extent, grid_size)
    # 'xy' indexing: row i is a fixed y and column j a fixed x. This matches the
    # velocity_grid[y, x] lookup used by interpolate_velocity and the
    # quiver/imshow calls that pair these arrays with velocity_grid.
    X, Y = jnp.meshgrid(ticks, ticks, indexing='xy')
    return X, Y
@jit
def interpolate_velocity(
    point: Float[Array, "2"],
    velocity_grid: Float[Array, "grid_size grid_size 2"],
    grid_extent: float
) -> Float[Array, "2"]:
    """Bilinearly sample the gridded velocity field at a single (x, y) point.

    Node (row i, column j) of `velocity_grid` sits at
    (x, y) = (-grid_extent + j*h, -grid_extent + i*h), where
    h = 2*grid_extent / (n - 1) and n = velocity_grid.shape[0].
    The grid is indexed [y, x]. `shape[0]` is used for both axes, so the grid is
    assumed to be square.

    Cell indices are clamped to the grid. Points beyond the border therefore
    receive the velocity of the nearest border node(s).

    Args:
        point: (x, y) query location.
        velocity_grid: velocity vectors stored at the grid nodes.
        grid_extent: half-width of the square covered by the grid.

    Returns:
        Interpolated (vx, vy) velocity.
    """
    n = velocity_grid.shape[0]
    last = n - 1

    # Continuous grid coordinates: 0 at -grid_extent, `last` at +grid_extent.
    continuous = (point + grid_extent) / (2 * grid_extent) * last
    col_f, row_f = continuous[0], continuous[1]

    col_floor = jnp.floor(col_f)
    row_floor = jnp.floor(row_f)

    # Interpolation weights, taken from the *unclamped* floor.
    tx = col_f - col_floor
    ty = row_f - row_floor

    col_base = col_floor.astype(jnp.int32)
    row_base = row_floor.astype(jnp.int32)
    c0, c1 = jnp.clip(col_base, 0, last), jnp.clip(col_base + 1, 0, last)
    r0, r1 = jnp.clip(row_base, 0, last), jnp.clip(row_base + 1, 0, last)

    # Blend along x on the two bracketing rows, then blend those along y.
    lower_row = (1 - tx) * velocity_grid[r0, c0] + tx * velocity_grid[r0, c1]
    upper_row = (1 - tx) * velocity_grid[r1, c0] + tx * velocity_grid[r1, c1]
    return (1 - ty) * lower_row + ty * upper_row
# %% [markdown]
# ## ODE Integration
# %%
@partial(jit, static_argnames=['n_steps'])
def integrate_velocity_field(
    point: Float[Array, "2"],
    velocity_grid: Float[Array, "grid_size grid_size 2"],
    n_steps: int,
    grid_extent: float
) -> Float[Array, "2"]:
    """Advect one point through the stationary velocity field for unit time.

    Uses forward Euler with `n_steps` equal steps of size dt = 1 / n_steps.
    Velocities come from `interpolate_velocity`.

    Args:
        point: (x, y) starting position.
        velocity_grid: velocity vectors on the grid nodes.
        n_steps: number of Euler steps; static under `jit` and must be positive.
        grid_extent: half-width of the square covered by the grid.

    Returns:
        (x, y) position after integrating from t=0 to t=1.
    """
    dt = 1.0 / n_steps

    def euler_step(pos, _):
        vel = interpolate_velocity(pos, velocity_grid, grid_extent)
        # Emit no per-step output: only the final position is needed, so
        # there is no point materialising a (n_steps, 2) trajectory buffer.
        return pos + vel * dt, None

    final_pos, _ = jax.lax.scan(euler_step, point, None, length=n_steps)
    return final_pos
# %% [markdown]
# ## Loss Functions
# %%
def chamfer_distance(points1: Float[Array, "n 2"], points2: Float[Array, "m 2"]) -> Float[Array, ""]:
    """Symmetric Chamfer distance between two point clouds (squared-distance form).

    Computes the mean, over `points1`, of the squared distance to the nearest
    point of `points2`, plus the same quantity with the roles swapped. It
    builds the full (n, m) distance matrix.

    Args:
        points1: first cloud, shape (n, 2).
        points2: second cloud, shape (m, 2).

    Returns:
        Scalar loss.
    """
    offsets = points1[:, None, :] - points2[None, :, :]
    sq_dist = jnp.sum(offsets ** 2, axis=2)  # (n, m)
    forward = sq_dist.min(axis=1).mean()   # each points1 -> nearest points2
    backward = sq_dist.min(axis=0).mean()  # each points2 -> nearest points1
    return forward + backward
def smoothness_regularization(velocity_grid: Float[Array, "grid_size grid_size 2"]) -> Float[Array, ""]:
    """First-order smoothness penalty on the velocity grid.

    Returns the mean squared forward difference between vertically adjacent
    nodes, plus the same between horizontally adjacent nodes. The differences
    are not divided by the grid spacing.

    Args:
        velocity_grid: velocity vectors indexed [row (y), column (x), component].

    Returns:
        Scalar penalty.
    """
    # Axis 0 runs along y (rows); axis 1 runs along x (columns).
    diff_along_y = velocity_grid[1:] - velocity_grid[:-1]
    diff_along_x = velocity_grid[:, 1:] - velocity_grid[:, :-1]
    return jnp.mean(diff_along_y ** 2) + jnp.mean(diff_along_x ** 2)
def create_loss_fn(source_points, target_points, config):
    """Build the jitted training objective for a given source/target pair.

    The returned function maps `velocity_grid` to
    ``(total_loss, (data_loss, reg_loss, transformed))``, a layout suited to
    ``value_and_grad(..., has_aux=True)``:

    - `transformed`: `source_points` advected through the field by
      `integrate_velocity_field`.
    - `data_loss`: Chamfer distance from `transformed` to `target_points`.
    - `reg_loss`: smoothness penalty on the grid.
    - `total_loss`: data_loss + config.REG_WEIGHT * reg_loss.

    `config` attributes are read when the returned function is traced.
    """
    @jit
    def loss_fn(velocity_grid):
        advect = partial(integrate_velocity_field,
                         n_steps=config.N_INTEGRATION_STEPS,
                         grid_extent=config.GRID_EXTENT)
        # Map over points only; the grid is shared by every point.
        transformed = vmap(advect, in_axes=(0, None))(source_points, velocity_grid)

        data_loss = chamfer_distance(transformed, target_points)
        reg_loss = smoothness_regularization(velocity_grid)
        total_loss = data_loss + config.REG_WEIGHT * reg_loss
        return total_loss, (data_loss, reg_loss, transformed)

    return loss_fn
# %% [markdown]
# ## Interactive Training Visualization
# %%
# Short aliases used by the training and visualization cells below.
config, source_points, target_points = Config, source_shape, target_shape
# %%
# Initialize velocity grid
velocity_grid = create_velocity_grid(key, config.GRID_SIZE, scale=0.1)

# Create loss function and optimizer
loss_fn = create_loss_fn(source_points, target_points, config)
optimizer = optax.adam(config.LEARNING_RATE)
opt_state = optimizer.init(velocity_grid)

# Storage for metrics
metrics = {'loss': [], 'data_loss': [], 'reg_loss': []}

# Setup interactive figure
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=config.FIG_SIZE)

# Get grid positions
X, Y = create_position_grid(config.GRID_SIZE, config.GRID_EXTENT)

# --- Ax1: velocity field ---------------------------------------------------
quiver = ax1.quiver(X, Y, velocity_grid[:, :, 0], velocity_grid[:, :, 1],
                    scale=1/config.ARROW_SCALE, alpha=0.6)
source_scatter1 = ax1.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=20, alpha=0.5, label='Source')
target_scatter1 = ax1.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5, label='Target')
ax1.set_title('Velocity Field')
ax1.set_aspect('equal')
ax1.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax1.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax1.legend()

# --- Ax2: deformed grid, initialised with straight lines -------------------
viz_grid_size = 30
x_fine = jnp.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, viz_grid_size)
y_fine = jnp.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, viz_grid_size)

# Each entry is (orientation 'v'/'h', index into the strided ticks, Line2D);
# the training loop uses it to redraw the warped lines.
grid_lines = []
# Vertical lines (every 3rd x tick)
for i, x in enumerate(x_fine[::3]):
    line_points = jnp.stack([jnp.full_like(y_fine, x), y_fine], axis=1)
    line, = ax2.plot(line_points[:, 0], line_points[:, 1], 'gray', alpha=0.3, linewidth=0.8)
    grid_lines.append(('v', i, line))
# Horizontal lines (every 3rd y tick)
for i, y in enumerate(y_fine[::3]):
    line_points = jnp.stack([x_fine, jnp.full_like(x_fine, y)], axis=1)
    line, = ax2.plot(line_points[:, 0], line_points[:, 1], 'gray', alpha=0.3, linewidth=0.8)
    grid_lines.append(('h', i, line))

target_scatter2 = ax2.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5, label='Target')
transformed_scatter = ax2.scatter([], [], c='green', s=30, label='Transformed')
ax2.set_title('Deformed Grid')
ax2.set_aspect('equal')
ax2.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax2.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax2.legend()

# --- Ax3: velocity magnitude -----------------------------------------------
velocity_magnitude = jnp.sqrt(velocity_grid[:, :, 0]**2 + velocity_grid[:, :, 1]**2)
# imshow's `extent` gives the outer *edges* of the pixel array, but the
# magnitudes live on the grid *nodes*, which include +/-GRID_EXTENT. Pad by
# half a node spacing so each pixel is centred on its node and lines up
# with the scattered shapes (and with the quiver arrows in ax1).
half_cell = config.GRID_EXTENT / max(config.GRID_SIZE - 1, 1)
magnitude_extent = [-config.GRID_EXTENT - half_cell, config.GRID_EXTENT + half_cell,
                    -config.GRID_EXTENT - half_cell, config.GRID_EXTENT + half_cell]
im = ax3.imshow(velocity_magnitude, extent=magnitude_extent,
                origin='lower', cmap='viridis', vmin=0, vmax=2)
ax3.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=20, alpha=0.5)
ax3.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5)
ax3.set_title('Velocity Magnitude')
ax3.set_aspect('equal')
# Use the same view window as the other two panels.
ax3.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax3.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
plt.colorbar(im, ax=ax3, label='|v|')

fig.suptitle('Training Progress')
plt.tight_layout()
# %%
# Train the model with interactive visualization
# `subkey` is unused below; the split is kept so that `key` advances exactly as
# before for any later cell that draws from it.
key, subkey = jax.random.split(key)

transform_fn = partial(integrate_velocity_field,
                       n_steps=config.N_INTEGRATION_STEPS,
                       grid_extent=config.GRID_EXTENT)

# Build the gradient function once. Re-wrapping `loss_fn` in `value_and_grad`
# inside the loop would re-apply the transformation in Python every step.
loss_and_grad_fn = jit(value_and_grad(loss_fn, has_aux=True))


@jit
def warp_points(points, grid):
    """Advect every row of `points` (shape (k, 2)) through the field `grid`."""
    return vmap(lambda p: transform_fn(p, grid))(points)


# Undeformed vertices of every overlay line, in the same order as `grid_lines`.
# All lines have `viz_grid_size` vertices, so they stack into a single
# (n_lines, viz_grid_size, 2) array. That lets each redraw warp all lines with
# one compiled call instead of one vmap per line.
straight_lines = []
for line_type, idx, _ in grid_lines:
    if line_type == 'v':
        x = x_fine[::3][idx]
        straight_lines.append(jnp.stack([jnp.full_like(y_fine, x), y_fine], axis=1))
    else:  # 'h'
        y = y_fine[::3][idx]
        straight_lines.append(jnp.stack([x_fine, jnp.full_like(x_fine, y)], axis=1))
grid_line_points = jnp.stack(straight_lines)

# Training loop with live updates
for step in range(config.N_STEPS):
    # Compute loss and gradients
    (loss, (data_loss, reg_loss, transformed)), grads = loss_and_grad_fn(velocity_grid)

    # Update parameters
    updates, opt_state = optimizer.update(grads, opt_state)
    velocity_grid = optax.apply_updates(velocity_grid, updates)

    # Record metrics
    metrics['loss'].append(float(loss))
    metrics['data_loss'].append(float(data_loss))
    metrics['reg_loss'].append(float(reg_loss))

    # Update visualization
    if step % config.VIZ_EVERY == 0 or step == config.N_STEPS - 1:
        # Velocity field arrows (post-update grid)
        quiver.set_UVC(velocity_grid[:, :, 0], velocity_grid[:, :, 1])

        # Deformed grid: warp all lines at once, then copy to host once.
        warped = warp_points(grid_line_points.reshape(-1, 2), velocity_grid)
        warped = np.asarray(warped).reshape(grid_line_points.shape)
        for (_, _, line), line_xy in zip(grid_lines, warped):
            line.set_data(line_xy[:, 0], line_xy[:, 1])

        # Transformed points (note: from the pre-update grid used by the loss)
        transformed_scatter.set_offsets(transformed)

        # Velocity magnitude
        velocity_magnitude = jnp.sqrt(velocity_grid[:, :, 0]**2 + velocity_grid[:, :, 1]**2)
        im.set_data(velocity_magnitude)
        im.set_clim(vmin=0, vmax=float(velocity_magnitude.max()))

        # Titles
        ax1.set_title(f'Velocity Field (Step {step})')
        ax2.set_title(f'Deformed Grid (Step {step})')
        ax3.set_title(f'Velocity Magnitude (Step {step})')
        fig.suptitle(f'Step {step} | Loss: {loss:.4f} | Data: {data_loss:.4f} | Reg: {reg_loss:.4f}')

        # Draw updates
        fig.canvas.draw()

    if step % (config.VIZ_EVERY * 5) == 0:
        print(f"Step {step:3d} | Loss: {loss:.4f} | Data: {data_loss:.4f} | Reg: {reg_loss:.4f}")
# %% [markdown]
# ## Plot Training Metrics
# %%
def plot_training_metrics(metrics):
    """Plot the recorded loss curves against training step.

    Args:
        metrics: dict holding equal-length per-step lists under the keys
            'loss', 'data_loss' and 'reg_loss'.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    steps = np.arange(len(metrics['loss']))

    curves = (
        ('loss', 'b-', 'Total Loss'),
        ('data_loss', 'g--', 'Data Loss'),
        ('reg_loss', 'r--', 'Reg Loss'),
    )
    for metric_name, line_fmt, curve_label in curves:
        ax.plot(steps, metrics[metric_name], line_fmt, label=curve_label, linewidth=2)

    ax.set_xlabel('Training Step')
    ax.set_ylabel('Loss')
    ax.set_title('Training Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
plot_training_metrics(metrics=metrics)
# %% [markdown]
# ## Final Visualization of Trained Field
# %%
def visualize_final_result(velocity_grid, source_points, target_points, config):
    """Visualize the trained velocity field and the deformation it induces.

    Left panel: quiver plot of the grid velocities, with source and target
    points. Right panel: streamlines of the interpolated field on a 30x30
    sampling grid, with source, target and transformed source points.

    Args:
        velocity_grid: trained velocity vectors indexed [row (y), column (x), component].
        source_points: (n, 2) source shape.
        target_points: (m, 2) target shape.
        config: object providing GRID_SIZE, GRID_EXTENT, ARROW_SCALE and
            N_INTEGRATION_STEPS.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    # Get grid positions
    X, Y = create_position_grid(config.GRID_SIZE, config.GRID_EXTENT)

    # Plot 1: Final velocity field
    ax1.quiver(X, Y, velocity_grid[:, :, 0], velocity_grid[:, :, 1],
               scale=1/config.ARROW_SCALE, alpha=0.6)
    ax1.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=30, alpha=0.7, label='Source')
    ax1.scatter(target_points[:, 0], target_points[:, 1], c='red', s=30, alpha=0.7, label='Target')
    ax1.set_title('Final Velocity Field')
    ax1.set_aspect('equal')
    ax1.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
    ax1.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
    ax1.legend()
    ax1.grid(True, alpha=0.2)

    # Plot 2: Flow visualization via streamlines
    x_stream = np.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, 30)
    y_stream = np.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, 30)
    X_stream, Y_stream = np.meshgrid(x_stream, y_stream)

    # Sample the field at every streamline-grid node in one vectorised call.
    # Looping over the 900 nodes would issue 900 separate jitted calls, each
    # followed by a device->host copy. Row-major ravel/reshape keeps U[i, j]
    # paired with (X_stream[i, j], Y_stream[i, j]).
    query = jnp.stack([X_stream.ravel(), Y_stream.ravel()], axis=1)
    sampled = vmap(lambda p: interpolate_velocity(p, velocity_grid, config.GRID_EXTENT))(query)
    sampled = np.asarray(sampled)
    U = sampled[:, 0].reshape(X_stream.shape)
    V = sampled[:, 1].reshape(X_stream.shape)
    ax2.streamplot(x_stream, y_stream, U, V, color='gray', linewidth=1, density=1.5)

    # Transform and plot shapes
    transform_fn = partial(integrate_velocity_field,
                           n_steps=config.N_INTEGRATION_STEPS,
                           grid_extent=config.GRID_EXTENT)
    transformed = vmap(lambda p: transform_fn(p, velocity_grid))(source_points)
    ax2.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=30, alpha=0.7, label='Source')
    ax2.scatter(target_points[:, 0], target_points[:, 1], c='red', s=30, alpha=0.7, label='Target')
    ax2.scatter(transformed[:, 0], transformed[:, 1], c='green', s=40, label='Transformed', edgecolors='black', linewidth=0.5)
    ax2.set_title('Velocity Field Streamlines')
    ax2.set_aspect('equal')
    ax2.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
    ax2.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
    ax2.legend()
    ax2.grid(True, alpha=0.2)

    plt.tight_layout()
    plt.show()
visualize_final_result(velocity_grid=velocity_grid, source_points=source_shape,
                       target_points=target_shape, config=Config)
# %% [markdown]
# ## Summary
#
# This notebook demonstrates direct optimization of a velocity field represented as a grid:
#
# - **No Neural Network**: The velocity field is directly parameterized as a grid of vectors
# - **Bilinear Interpolation**: Continuous (piecewise-bilinear) velocity values between grid points
# - **Grid Deformation Visualization**: Shows how space is warped during optimization
# - **Compact**: `GRID_SIZE × GRID_SIZE × 2` parameters (800 with the default 20×20 grid), typically fewer than a neural network
#
# It shows how the optimization process discovers a deformation transforming the source shape into the target shape.
# %%
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment