Skip to content

Instantly share code, notes, and snippets.

@yberreby
Last active April 19, 2025 00:06
Show Gist options
  • Save yberreby/ebfefe1c087d68f6c280e9f6e120d4fd to your computer and use it in GitHub Desktop.
Save yberreby/ebfefe1c087d68f6c280e9f6e120d4fd to your computer and use it in GitHub Desktop.
Cahn-Hilliard - Run this for visually-interesting non-linear PDE (phase-separation) behavior, rendered as an animation.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
# ///
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
# -- 1. Wikipedia formulation ------------------------------------------------
# On Wikipedia, the Cahn–Hilliard equation is presented as:
# ∂c/∂t = D ∇² (c³ - c - γ ∇² c),
# where:
# • c(x,t) is the concentration, with ±1 indicating pure domains (values can
#   slightly overshoot [-1, 1], since the equation has no maximum principle),
# • D with units [Length²/Time] is a diffusion coefficient,
# • γ controls interface width: √γ gives the transition-layer length,
# • μ = c³ - c - γ ∇² c is identified as the chemical potential.
# See: Wikipedia, "Cahn–Hilliard equation": https://en.wikipedia.org/wiki/Cahn%E2%80%93Hilliard_equation
# We choose D = 1 for simplicity, rename c → u, and set γ = ε².
# Thus our PDE becomes:
# ∂u/∂t = ∇² (u³ - u - ε² ∇² u)
# -- 2. Free-energy functional & Lyapunov property -------------------------------
# Wikipedia also defines the free energy:
# F[u] = ∫ [ (u² - 1)²/4 + (ε²/2) |∇u|² ] dx,
# and proves:
# dF/dt = - ∫ |∇μ|² dx ≤ 0,
# so the energy never increases. Because the double-well term (u² - 1)²/4 is
# minimized at u = ±1, lowering F drives the field into ±1 domains.
# (dF/dt formula sourced from the same article).
# -- 3. Domain & spectral setup ------------------------------------------------
# We simulate on a periodic square [0,L]^2 with N×N grid points.
N, L = 512, 9.0
# Angular wave numbers k_m = 2π m / L, with m = 0, 1, ..., N/2-1, -N/2, ..., -1.
# fftfreq(N, d) returns m / (N d) (cycles per unit length); with sample spacing
# d = L/N that is m / L, so multiplying by 2π yields k_m.
k = jnp.fft.fftfreq(N, d=L/N) * (2*jnp.pi)
KX, KY = jnp.meshgrid(k, k, indexing="ij")
# Fourier symbols of the Laplacian and bi-Laplacian: ∇² ↔ -K2, ∇⁴ ↔ K4 = K2².
K2 = KX**2 + KY**2
K4 = K2**2
# -- 4. Time-stepping: semi-implicit (IMEX) ------------------------------------
# The stiff linear term ε² ∇⁴ u is advanced implicitly and the nonlinear term
# ∇²(u³ - u) explicitly. In Fourier space (∇² ↔ -K², ∇⁴ ↔ K⁴) one step reads
#   (1 + Δt ε² K⁴) ûⁿ⁺¹ = ûⁿ - Δt K² 𝓕[u³ - u],
# a first-order scheme that stays stable for reasonable Δt.
# Reference: Chen & Shen (1998), semi-implicit Fourier-spectral method for
# phase-field equations.
ε, Δt = 0.02, 5e-5
DENOM = 1 + Δt * ε**2 * K4  # symbol of the implicit operator, built once


@jax.jit
def spectral_step(u: jnp.ndarray) -> jnp.ndarray:
    """Advance the field ``u`` by one semi-implicit IMEX step of size Δt.

    The field and its explicit nonlinearity ``u³ - u`` are moved to Fourier
    space, combined as ``(û - Δt K2 ŵ) / (1 + Δt ε² K4)``, transformed back
    (real part kept) and finally clamped to [-1.5, 1.5].
    """
    spectrum = jnp.fft.fft2(u)
    forcing = jnp.fft.fft2(u**3 - u)
    advanced = (spectrum - Δt * K2 * forcing) / DENOM
    field = jnp.fft.ifft2(advanced).real
    # Guard against numerical overshoot. The bounds sit well outside the ±1
    # wells, so this should only touch outliers; whenever it does fire, it
    # alters the field's mean (the clamp is not mass-conserving).
    return jnp.clip(field, -1.5, 1.5)
# -- 5. Time integration via JAX scan ------------------------------------------
STEPS = 10  # solver sub-steps advanced per animation frame


@jax.jit
def evolve(u: jnp.ndarray) -> jnp.ndarray:
    """Apply ``spectral_step`` ``STEPS`` times and return the resulting field.

    ``jax.lax.scan`` traces the step once and compiles the repetition into a
    single loop, so there is no per-step Python overhead.
    """
    def advance(state, _unused):
        # Nothing is collected per step; only the carried field matters.
        return spectral_step(state), None

    final_state, _history = jax.lax.scan(advance, u, xs=None, length=STEPS)
    return final_state
# -- 6. Initial condition & diagnostics ----------------------------------------
# Small white noise around u = 0. The sample mean of random noise is only
# approximately zero, so subtract it to make mean(u) = 0 (up to float rounding),
# i.e. equal phase fractions.
key = jax.random.PRNGKey(0)
u = 0.2 * jax.random.normal(key, (N, N))
u = u - u.mean()
# Reference value for the mass-drift diagnostic: the spectral step leaves the
# k = 0 Fourier mode (the mean) unchanged, except where its clamp fires.
mass0 = float(u.mean())
# -- 7. Visualization & animation ----------------------------------------------
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(np.array(u), cmap='viridis', origin='lower',
               vmin=-1, vmax=1, interpolation='bilinear')
ax.set_axis_off()


def init_frame():
    """Draw the initial condition without advancing the simulation.

    Without an ``init_func``, FuncAnimation draws its initial (clear) frame by
    calling ``frame(0)``. Because ``frame`` advances the global field ``u``,
    those extra calls would push every saved frame ahead of the time shown in
    its title and repeat the "[Frame 1]" diagnostic.
    """
    im.set_data(np.asarray(u))
    return [im]


def frame(i: int):
    """Advance the simulation by one frame (``STEPS`` solver steps) and redraw.

    Mutates the module-level field ``u``. Prints the variance and the drift of
    the mean from ``mass0`` as diagnostics, then returns the updated artists.
    """
    global u
    u = evolve(u)
    u_np = np.asarray(u)
    # mass conservation check
    mass_err = abs(u_np.mean() - mass0)
    # domain growth indicated by variance
    var = u_np.var()
    print(f"[Frame {i+1}] var={var:.4f}, mass_err={mass_err:.2e}")
    im.set_data(u_np)
    # Frame times are multiples of STEPS*Δt = 5e-4; two decimals would show
    # "0.00" for the first frames, so print four.
    ax.set_title(f"t ≃ {(i+1)*STEPS*Δt:.4f}, var={var:.3f}")
    return [im]


# blit=False: the title changes every frame but is not among the returned
# artists, so a blitted live view would never redraw it.
ani = FuncAnimation(fig, frame, frames=240, init_func=init_frame,
                    interval=30, blit=False, repeat=False)
output_filename = "cahn_hilliard_simulation.mp4"
output_path = Path(output_filename).resolve()
ani.save(output_path, fps=30, extra_args=['-vcodec', 'libx264'])
print(f"Animation saved to: {output_path}")
This file has been truncated, but you can view the full file.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment