Created
June 23, 2025 22:49
-
-
Save yberreby/4c8a8370293a7cd092bcfca3603d57d3 to your computer and use it in GitHub Desktop.
Optimization of a stationary velocity field with Adam to morph one shape (a circle) into another (a square). JAX-based interactive notebook, shared in `jupytext` percent format.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# --- | |
# jupyter: | |
# jupytext: | |
# formats: ipynb,py:percent | |
# text_representation: | |
# extension: .py | |
# format_name: percent | |
# format_version: '1.3' | |
# jupytext_version: 1.17.2 | |
# kernelspec: | |
# display_name: Python 3 (ipykernel) | |
# language: python | |
# name: python3 | |
# --- | |
# %% [markdown] | |
# # Direct Velocity Field Optimization | |
# | |
# Optimize a stationary 2-D velocity field, stored as a grid of vectors, so that advecting a circle through it yields a square.
# %% | |
# %matplotlib widget | |
# %% | |
import jax | |
import jax.numpy as jnp | |
from jax import vmap, jit, grad, value_and_grad | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.animation import FuncAnimation | |
from IPython.display import HTML | |
from typing import Tuple, Optional | |
from jaxtyping import Array, Float, ScalarLike, jaxtyped | |
from beartype import beartype as typechecker | |
from functools import partial | |
import optax | |
# %% | |
class Config:
    """Experiment-wide constants, read as class attributes (e.g. ``Config.N_POINTS``)."""

    # --- Source / target shapes ---
    N_POINTS = 60          # points sampled on each shape
    CIRCLE_RADIUS = 0.8    # radius of the source circle
    SQUARE_SIZE = 1.4      # side length of the target square

    # --- Velocity-field grid ---
    GRID_SIZE = 20         # the field is a GRID_SIZE x GRID_SIZE array of 2-D vectors
    GRID_EXTENT = 3.0      # grid nodes span [-GRID_EXTENT, GRID_EXTENT] on both axes

    # --- ODE integration ---
    N_INTEGRATION_STEPS = 20  # forward-Euler steps over unit time

    # --- Optimization ---
    LEARNING_RATE = 0.01
    N_STEPS = 200
    REG_WEIGHT = 0.1       # weight of the smoothness penalty in the total loss

    # --- Visualization ---
    VIZ_EVERY = 1          # redraw the live figure every N training steps
    ARROW_SCALE = 0.3
    FIG_SIZE = (15, 5)
# Root PRNG key for the experiment (fixed seed so runs are reproducible).
key = jax.random.PRNGKey(seed=42)
# %% [markdown] | |
# ## Shape Generation | |
# %% | |
@partial(jit, static_argnames=['n_points'])
def create_circle(n_points: int, radius: ScalarLike, center: Float[Array, "2"]) -> Float[Array, "n 2"]:
    """Sample ``n_points`` equally spaced points on a circle.

    Sampling starts at angle 0 and proceeds counter-clockwise. The angle
    ``2*pi`` is excluded, so the first point is not repeated.

    Args:
        n_points: number of samples (static under ``jit``).
        radius: circle radius.
        center: ``(2,)`` offset added to every point.

    Returns:
        ``(n_points, 2)`` array of ``(x, y)`` coordinates.
    """
    angles = jnp.linspace(0, 2 * jnp.pi, n_points, endpoint=False)
    unit_circle = jnp.column_stack((jnp.cos(angles), jnp.sin(angles)))
    return center + radius * unit_circle
@partial(jit, static_argnames=['n_points'])
def create_square(n_points: int, size: ScalarLike, center: Float[Array, "2"]) -> Float[Array, "n 2"]:
    """Sample ``n_points`` equally spaced points on the boundary of an axis-aligned square.

    The boundary is walked counter-clockwise from the bottom-left corner
    (bottom -> right -> top -> left side). Consecutive points are
    ``4 * size / n_points`` apart along the perimeter. Each corner appears at
    most once. Exactly ``n_points`` points are returned, even when
    ``n_points`` is not a multiple of 4.

    Args:
        n_points: number of samples (static under ``jit``).
        size: side length of the square.
        center: ``(2,)`` offset added to every point.

    Returns:
        ``(n_points, 2)`` array of ``(x, y)`` coordinates.
    """
    half_size = size / 2
    k = jnp.arange(n_points)
    # Perimeter coordinate s = 4k / n_points in [0, 4). It is split into
    # (side, fraction) with integer arithmetic so side boundaries are exact
    # and no float rounding can misplace a corner.
    side = (4 * k) // n_points
    frac = ((4 * k) % n_points) / n_points
    u = -half_size + frac * size          # position along the current side, in [-h, h)
    h = jnp.ones_like(u) * half_size
    # Candidate position of every sample on each of the four sides; each side
    # is traversed counter-clockwise so the walk is continuous.
    per_side = jnp.stack([
        jnp.stack([u, -h], axis=1),   # bottom: left -> right
        jnp.stack([h, u], axis=1),    # right: bottom -> top
        jnp.stack([-u, h], axis=1),   # top: right -> left
        jnp.stack([-h, -u], axis=1),  # left: top -> bottom
    ])  # (4, n_points, 2)
    points = per_side[side, k]
    return points + center
# Source (circle) and target (square) point clouds, both centred at the origin.
source_shape = create_circle(n_points=Config.N_POINTS, radius=Config.CIRCLE_RADIUS, center=jnp.zeros(2))
target_shape = create_square(n_points=Config.N_POINTS, size=Config.SQUARE_SIZE, center=jnp.zeros(2))
# %% [markdown] | |
# ## Grid-based Velocity Field | |
# %% | |
def create_velocity_grid(key: jax.Array, grid_size: int, scale: float = 0.1) -> Float[Array, "grid_size grid_size 2"]:
    """Randomly initialize a ``(grid_size, grid_size, 2)`` grid of velocity vectors.

    Each component is drawn independently from a normal distribution with
    mean 0 and standard deviation ``scale``.

    Args:
        key: JAX PRNG key, e.g. the array returned by ``jax.random.PRNGKey``.
            The annotation is ``jax.Array`` because ``jax.random.PRNGKey`` is a
            function, not a type.
        grid_size: number of grid nodes along each axis.
        scale: standard deviation of the initial velocities.

    Returns:
        ``(grid_size, grid_size, 2)`` array of initial velocities.
    """
    return scale * jax.random.normal(key, (grid_size, grid_size, 2))
def create_position_grid(grid_size: int, extent: float) -> Tuple[Float[Array, "grid_size grid_size"], Float[Array, "grid_size grid_size"]]:
    """Return the ``(X, Y)`` node coordinates of the velocity grid.

    Nodes are evenly spaced over ``[-extent, extent]`` on both axes. ``'xy'``
    indexing is used, so ``X[i, j]`` varies with ``j`` and ``Y[i, j]`` varies
    with ``i``. This matches ``velocity_grid[i, j]`` being the vector at
    ``(X[i, j], Y[i, j])``, because ``interpolate_velocity`` indexes
    ``velocity_grid[row=y, col=x]``.

    Args:
        grid_size: number of nodes along each axis.
        extent: half-width of the square covered by the grid.

    Returns:
        Tuple ``(X, Y)`` of ``(grid_size, grid_size)`` arrays.
    """
    axis = jnp.linspace(-extent, extent, grid_size)  # same samples on x and y
    X, Y = jnp.meshgrid(axis, axis, indexing='xy')
    # jnp.meshgrid returns a list; return a real tuple to honour the annotation.
    return X, Y
@jit
def interpolate_velocity(
    point: Float[Array, "2"],
    velocity_grid: Float[Array, "grid_size grid_size 2"],
    grid_extent: float
) -> Float[Array, "2"]:
    """Bilinearly interpolate the grid's velocity at ``point``.

    Let ``E = grid_extent`` and ``N = velocity_grid.shape[0]``; the grid is
    assumed square. Node ``velocity_grid[i, j]`` sits at
    ``x = linspace(-E, E, N)[j]`` and ``y = linspace(-E, E, N)[i]``. Outside
    the grid, node indices are clamped to the border, so the velocity is held
    at the nearest edge values.

    Args:
        point: ``(x, y)`` query position.
        velocity_grid: ``(N, N, 2)`` array of velocity vectors.
        grid_extent: half-width of the square covered by the grid.

    Returns:
        ``(2,)`` interpolated velocity.
    """
    last = velocity_grid.shape[0] - 1
    # Continuous (column, row) coordinates of the point in index space [0, last].
    idx = (point + grid_extent) / (2 * grid_extent) * last
    lower = jnp.floor(idx)
    frac_x, frac_y = idx - lower          # weights toward the upper neighbour
    base = lower.astype(jnp.int32)
    col_lo, row_lo = jnp.clip(base, 0, last)
    col_hi, row_hi = jnp.clip(base + 1, 0, last)
    # Blend along x on the lower and upper rows, then blend those along y.
    bottom = (1 - frac_x) * velocity_grid[row_lo, col_lo] + frac_x * velocity_grid[row_lo, col_hi]
    top = (1 - frac_x) * velocity_grid[row_hi, col_lo] + frac_x * velocity_grid[row_hi, col_hi]
    return (1 - frac_y) * bottom + frac_y * top
# %% [markdown] | |
# ## ODE Integration | |
# %% | |
@partial(jit, static_argnames=['n_steps'])
def integrate_velocity_field(
    point: Float[Array, "2"],
    velocity_grid: Float[Array, "grid_size grid_size 2"],
    n_steps: int,
    grid_extent: float
) -> Float[Array, "2"]:
    """Advect ``point`` through the stationary velocity field for unit time.

    Uses ``n_steps`` forward-Euler steps of size ``dt = 1 / n_steps``. The
    velocity at each step comes from ``interpolate_velocity``.

    Args:
        point: ``(x, y)`` start position.
        velocity_grid: ``(N, N, 2)`` velocity grid.
        n_steps: number of Euler steps (static under ``jit``).
        grid_extent: half-width of the square covered by the grid.

    Returns:
        ``(2,)`` position at ``t = 1``.
    """
    dt = 1.0 / n_steps

    def euler_step(pos, _):
        vel = interpolate_velocity(pos, velocity_grid, grid_extent)
        # No per-step output. Only the final position is needed, so scan
        # should not stack an (n_steps, 2) trajectory that would be discarded.
        return pos + vel * dt, None

    final_pos, _ = jax.lax.scan(euler_step, point, xs=None, length=n_steps)
    return final_pos
# %% [markdown] | |
# ## Loss Functions | |
# %% | |
def chamfer_distance(points1: Float[Array, "n 2"], points2: Float[Array, "m 2"]) -> Float[Array, ""]:
    """Symmetric Chamfer distance between two point clouds.

    For each point in one set, take the squared Euclidean distance to its
    nearest neighbour in the other set. Average these distances over each
    set, then add the two directional averages.

    Returns:
        Scalar ``mean_i min_j |p1_i - p2_j|^2 + mean_j min_i |p1_i - p2_j|^2``.
    """
    offsets = points1[:, None, :] - points2[None, :, :]   # (n, m, 2)
    sq_dist = (offsets ** 2).sum(axis=2)                  # (n, m)
    forward = sq_dist.min(axis=1).mean()    # each point of points1 -> closest in points2
    backward = sq_dist.min(axis=0).mean()   # each point of points2 -> closest in points1
    return forward + backward
def smoothness_regularization(velocity_grid: Float[Array, "grid_size grid_size 2"]) -> Float[Array, ""]:
    """First-order smoothness penalty on the velocity grid.

    Returns the mean squared difference between vertically adjacent nodes
    plus the mean squared difference between horizontally adjacent nodes.
    Axis 0 of the grid is the row (y) index and axis 1 is the column (x) index.
    """
    # Forward differences along rows (axis 0, y) and along columns (axis 1, x).
    diff_y = jnp.diff(velocity_grid, axis=0)
    diff_x = jnp.diff(velocity_grid, axis=1)
    return jnp.mean(diff_y ** 2) + jnp.mean(diff_x ** 2)
def create_loss_fn(source_points, target_points, config):
    """Build the jitted training objective for a fixed source/target pair.

    Args:
        source_points: ``(n, 2)`` points that are advected through the field.
        target_points: ``(m, 2)`` points the advected cloud should match.
        config: object providing ``N_INTEGRATION_STEPS``, ``GRID_EXTENT`` and
            ``REG_WEIGHT`` (e.g. the ``Config`` class).

    Returns:
        ``loss_fn(velocity_grid) -> (total, (data, reg, transformed))``, where:

        - ``data`` is the Chamfer distance between the advected source points
          and the target;
        - ``reg`` is the smoothness penalty;
        - ``transformed`` holds the advected points;
        - ``total = data + REG_WEIGHT * reg``.

        The auxiliary tuple lets it be used with
        ``value_and_grad(..., has_aux=True)``.
    """
    @jit
    def loss_fn(velocity_grid):
        advect = partial(
            integrate_velocity_field,
            n_steps=config.N_INTEGRATION_STEPS,
            grid_extent=config.GRID_EXTENT,
        )
        # Map over the source points; the grid is shared (not batched).
        transformed = vmap(advect, in_axes=(0, None))(source_points, velocity_grid)
        data_loss = chamfer_distance(transformed, target_points)
        reg_loss = smoothness_regularization(velocity_grid)
        return data_loss + config.REG_WEIGHT * reg_loss, (data_loss, reg_loss, transformed)

    return loss_fn
# %% [markdown] | |
# ## Interactive Training Visualization | |
# %% | |
# Short aliases used by the training and plotting cells below.
config, source_points, target_points = Config, source_shape, target_shape
# %% | |
# Initialize velocity grid
velocity_grid = create_velocity_grid(key, config.GRID_SIZE, scale=0.1)
# Create loss function and optimizer
loss_fn = create_loss_fn(source_points, target_points, config)
optimizer = optax.adam(config.LEARNING_RATE)
opt_state = optimizer.init(velocity_grid)
# Per-step metrics (Python floats), appended to by the training loop
metrics = {'loss': [], 'data_loss': [], 'reg_loss': []}
# Interactive figure: velocity field | deformed grid | velocity magnitude
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=config.FIG_SIZE)
# Node positions of the velocity grid
X, Y = create_position_grid(config.GRID_SIZE, config.GRID_EXTENT)

# Ax1: velocity-field arrows, with the source and target shapes overlaid
quiver = ax1.quiver(X, Y, velocity_grid[:, :, 0], velocity_grid[:, :, 1],
                    scale=1/config.ARROW_SCALE, alpha=0.6)
source_scatter1 = ax1.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=20, alpha=0.5, label='Source')
target_scatter1 = ax1.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5, label='Target')
ax1.set_title('Velocity Field')
ax1.set_aspect('equal')
ax1.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax1.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax1.legend()

# Ax2: deformed grid. Every 3rd line of a viz_grid_size x viz_grid_size
# lattice is drawn straight here; the training loop redraws it warped.
viz_grid_size = 30
x_fine = jnp.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, viz_grid_size)
y_fine = jnp.linspace(-config.GRID_EXTENT, config.GRID_EXTENT, viz_grid_size)
# Entries: (orientation 'v' or 'h', index into x_fine[::3] / y_fine[::3], line artist)
grid_lines = []
# Vertical lines
for i, x in enumerate(x_fine[::3]):
    line_points = jnp.stack([jnp.full_like(y_fine, x), y_fine], axis=1)
    line, = ax2.plot(line_points[:, 0], line_points[:, 1], 'gray', alpha=0.3, linewidth=0.8)
    grid_lines.append(('v', i, line))
# Horizontal lines
for i, y in enumerate(y_fine[::3]):
    line_points = jnp.stack([x_fine, jnp.full_like(x_fine, y)], axis=1)
    line, = ax2.plot(line_points[:, 0], line_points[:, 1], 'gray', alpha=0.3, linewidth=0.8)
    grid_lines.append(('h', i, line))
target_scatter2 = ax2.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5, label='Target')
transformed_scatter = ax2.scatter([], [], c='green', s=30, label='Transformed')
ax2.set_title('Deformed Grid')
ax2.set_aspect('equal')
ax2.set_xlim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax2.set_ylim(-config.GRID_EXTENT, config.GRID_EXTENT)
ax2.legend()

# Ax3: velocity magnitude. The grid nodes lie exactly at +/-GRID_EXTENT,
# but imshow's `extent` gives the outer pixel *edges*. Pad by half a cell so
# each pixel is centred on its node; otherwise the heat map is shifted and
# stretched relative to the scatter points.
velocity_magnitude = jnp.sqrt(velocity_grid[:, :, 0]**2 + velocity_grid[:, :, 1]**2)
half_cell = config.GRID_EXTENT / max(config.GRID_SIZE - 1, 1)
im = ax3.imshow(velocity_magnitude,
                extent=[-config.GRID_EXTENT - half_cell, config.GRID_EXTENT + half_cell,
                        -config.GRID_EXTENT - half_cell, config.GRID_EXTENT + half_cell],
                origin='lower', cmap='viridis', vmin=0, vmax=2)
ax3.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=20, alpha=0.5)
ax3.scatter(target_points[:, 0], target_points[:, 1], c='red', s=20, alpha=0.5)
ax3.set_title('Velocity Magnitude')
ax3.set_aspect('equal')
plt.colorbar(im, ax=ax3, label='|v|')
fig.suptitle('Training Progress')
plt.tight_layout()
# %% | |
# Train the model with interactive visualization
# NOTE(review): subkey is never used; training itself is deterministic.
key, subkey = jax.random.split(key)

transform_fn = partial(integrate_velocity_field,
                       n_steps=config.N_INTEGRATION_STEPS,
                       grid_extent=config.GRID_EXTENT)
# Advect a batch of (k, 2) points through one velocity grid in a single compiled call.
advect_points = jit(vmap(transform_fn, in_axes=(0, None)))


@jit
def train_step(velocity_grid, opt_state):
    """Run one Adam step.

    Returns ``(new_grid, new_opt_state, loss, (data_loss, reg_loss, transformed))``.
    The loss and auxiliary values are evaluated at the *incoming* grid.
    """
    (loss, aux), grads = value_and_grad(loss_fn, has_aux=True)(velocity_grid)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(velocity_grid, updates), opt_state, loss, aux


def _undeformed_line(line_type, idx):
    """Points of one straight grid line before deformation ('v' = vertical, 'h' = horizontal)."""
    if line_type == 'v':
        return jnp.stack([jnp.full_like(y_fine, x_fine[::3][idx]), y_fine], axis=1)
    return jnp.stack([x_fine, jnp.full_like(x_fine, y_fine[::3][idx])], axis=1)


# The undeformed lines never change, so build them once: (n_lines, viz_grid_size, 2).
base_line_points = jnp.stack([_undeformed_line(t, i) for t, i, _ in grid_lines])

for step in range(config.N_STEPS):
    # Loss and aux are evaluated at the grid *before* this step's update.
    velocity_grid, opt_state, loss, (data_loss, reg_loss, transformed) = train_step(velocity_grid, opt_state)
    loss_f, data_f, reg_f = float(loss), float(data_loss), float(reg_loss)

    # Record metrics
    metrics['loss'].append(loss_f)
    metrics['data_loss'].append(data_f)
    metrics['reg_loss'].append(reg_f)

    # Update visualization. Every panel, including the green transformed
    # points, shows the grid *after* this step's update.
    if step % config.VIZ_EVERY == 0 or step == config.N_STEPS - 1:
        quiver.set_UVC(velocity_grid[:, :, 0], velocity_grid[:, :, 1])

        # Warp all grid lines in one batched call.
        moved_lines = advect_points(base_line_points.reshape(-1, 2), velocity_grid)
        moved_lines = moved_lines.reshape(base_line_points.shape)
        for (_, _, line), pts in zip(grid_lines, moved_lines):
            line.set_data(pts[:, 0], pts[:, 1])

        transformed_scatter.set_offsets(advect_points(source_points, velocity_grid))

        velocity_magnitude = jnp.sqrt(velocity_grid[:, :, 0]**2 + velocity_grid[:, :, 1]**2)
        im.set_data(velocity_magnitude)
        im.set_clim(vmin=0, vmax=velocity_magnitude.max())

        ax1.set_title(f'Velocity Field (Step {step})')
        ax2.set_title(f'Deformed Grid (Step {step})')
        ax3.set_title(f'Velocity Magnitude (Step {step})')
        fig.suptitle(f'Step {step} | Loss: {loss_f:.4f} | Data: {data_f:.4f} | Reg: {reg_f:.4f}')
        fig.canvas.draw()

    if step % (config.VIZ_EVERY * 5) == 0:
        print(f"Step {step:3d} | Loss: {loss_f:.4f} | Data: {data_f:.4f} | Reg: {reg_f:.4f}")
# %% [markdown] | |
# ## Plot Training Metrics | |
# %% | |
def plot_training_metrics(metrics):
    """Plot the total, data and regularisation loss curves recorded during training.

    Args:
        metrics: dict with lists under ``'loss'``, ``'data_loss'`` and
            ``'reg_loss'``, one entry per training step. The x-axis length is
            taken from ``'loss'``. ``'reg_loss'`` holds the unweighted
            smoothness term (before multiplication by ``REG_WEIGHT``).
    """
    curves = (
        ('loss', 'b-', 'Total Loss'),
        ('data_loss', 'g--', 'Data Loss'),
        ('reg_loss', 'r--', 'Reg Loss'),
    )
    fig, ax = plt.subplots(figsize=(8, 5))
    steps = np.arange(len(metrics['loss']))
    for name, fmt, label in curves:
        ax.plot(steps, metrics[name], fmt, label=label, linewidth=2)
    ax.set(xlabel='Training Step', ylabel='Loss', title='Training Progress')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


plot_training_metrics(metrics)
# %% [markdown] | |
# ## Final Visualization of Trained Field | |
# %% | |
def visualize_final_result(velocity_grid, source_points, target_points, config):
    """Plot the trained velocity field and the deformation it produces.

    Left panel: the grid's velocity vectors with the source (blue) and target
    (red) shapes. Right panel: streamlines of the interpolated field, with the
    source, the target, and the source advected through the field (green).

    Args:
        velocity_grid: ``(GRID_SIZE, GRID_SIZE, 2)`` trained velocity grid.
        source_points: ``(n, 2)`` source shape.
        target_points: ``(m, 2)`` target shape.
        config: object providing ``GRID_SIZE``, ``GRID_EXTENT``, ``ARROW_SCALE``
            and ``N_INTEGRATION_STEPS`` (e.g. the ``Config`` class).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    extent = config.GRID_EXTENT

    # Plot 1: final velocity field at the grid nodes
    X, Y = create_position_grid(config.GRID_SIZE, extent)
    ax1.quiver(X, Y, velocity_grid[:, :, 0], velocity_grid[:, :, 1],
               scale=1/config.ARROW_SCALE, alpha=0.6)
    ax1.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=30, alpha=0.7, label='Source')
    ax1.scatter(target_points[:, 0], target_points[:, 1], c='red', s=30, alpha=0.7, label='Target')
    ax1.set_title('Final Velocity Field')
    ax1.set_aspect('equal')
    ax1.set_xlim(-extent, extent)
    ax1.set_ylim(-extent, extent)
    ax1.legend()
    ax1.grid(True, alpha=0.2)

    # Plot 2: streamlines. Sample the interpolated field on a regular 30x30
    # grid with one vmapped call rather than 900 separate jitted dispatches.
    x_stream = np.linspace(-extent, extent, 30)
    y_stream = np.linspace(-extent, extent, 30)
    X_stream, Y_stream = np.meshgrid(x_stream, y_stream)
    query = jnp.stack([X_stream.ravel(), Y_stream.ravel()], axis=1)
    vel = np.asarray(vmap(lambda p: interpolate_velocity(p, velocity_grid, extent))(query))
    U = vel[:, 0].reshape(X_stream.shape)
    V = vel[:, 1].reshape(X_stream.shape)
    ax2.streamplot(x_stream, y_stream, U, V, color='gray', linewidth=1, density=1.5)

    # Advect the source shape through the trained field
    transform_fn = partial(integrate_velocity_field,
                           n_steps=config.N_INTEGRATION_STEPS,
                           grid_extent=extent)
    transformed = vmap(transform_fn, in_axes=(0, None))(source_points, velocity_grid)
    ax2.scatter(source_points[:, 0], source_points[:, 1], c='blue', s=30, alpha=0.7, label='Source')
    ax2.scatter(target_points[:, 0], target_points[:, 1], c='red', s=30, alpha=0.7, label='Target')
    ax2.scatter(transformed[:, 0], transformed[:, 1], c='green', s=40, label='Transformed', edgecolors='black', linewidth=0.5)
    ax2.set_title('Velocity Field Streamlines')
    ax2.set_aspect('equal')
    ax2.set_xlim(-extent, extent)
    ax2.set_ylim(-extent, extent)
    ax2.legend()
    ax2.grid(True, alpha=0.2)
    plt.tight_layout()
    plt.show()


visualize_final_result(velocity_grid, source_shape, target_shape, Config)
# %% [markdown] | |
# ## Summary | |
# | |
# This notebook demonstrates direct optimization of a velocity field represented as a grid: | |
# | |
# - **No Neural Network**: The velocity field is directly parameterized as a grid of vectors | |
# - **Bilinear Interpolation**: Continuous (piecewise-bilinear) velocity values between grid points
# - **Grid Deformation Visualization**: Shows how space is warped during optimization | |
# - **Efficient**: Only GRID_SIZE x GRID_SIZE x 2 parameters (800 for the default 20 x 20 grid), typically fewer than a neural network
# | |
# It shows how the optimization process discovers a deformation transforming the source shape into the target shape. | |
# %% | |
# %% |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment