Last active
April 19, 2025 00:06
-
-
Save yberreby/ebfefe1c087d68f6c280e9f6e120d4fd to your computer and use it in GitHub Desktop.
Cahn-Hilliard - Run this for visually-interesting non-linear ODE behavior, in an animated plot.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env -S uv run --script | |
# /// script | |
# requires-python = ">=3.12" | |
# dependencies = [ | |
# "jax==0.5.0", | |
# "matplotlib>=3.10.1", | |
# "pyqt6>=6.9.0", # for matplotlib gui | |
# ] | |
# /// | |
from pathlib import Path | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from matplotlib.animation import FuncAnimation | |
# -- 1. Wikipedia formulation ------------------------------------------------ | |
# On Wikipedia, the Cahn–Hilliard equation is presented as: | |
# ∂c/∂t = D ∇² (c³ - c - γ ∇² c), | |
# where: | |
# • c(x,t) ∈ [-1, 1] is the concentration (±1 indicate pure domains), | |
# • D with units [Length²/Time] is a diffusion coefficient, | |
# • γ controls interface width: √γ gives the transition-layer length, | |
# • μ = c³ - c - γ ∇² c is identified as the chemical potential. | |
# See Wikipedia, "Cahn–Hilliard equation": https://en.wikipedia.org/wiki/Cahn%E2%80%93Hilliard_equation
# We choose D = 1 for simplicity, rename c → u, and set γ = ε². | |
# Thus our PDE becomes: | |
# ∂u/∂t = ∇² (u³ - u - ε² ∇² u) | |
# -- 2. Free-energy functional & Lyapunov property ------------------------------- | |
# Wikipedia also defines the free energy: | |
# F[u] = ∫ [ (u² - 1)²/4 + (ε²/2) |∇u|² ] dx, | |
# and proves: | |
# dF/dt = - ∫ |∇μ|² dx ≤ 0, | |
# guaranteeing monotonic energy decay and hence segregation into domains. | |
# (dF/dt formula sourced from the same article). | |
# -- 3. Domain & spectral setup ------------------------------------------------ | |
# Periodic square domain [0, L]^2, sampled on an N×N grid (spacing L/N).
N = 512
L = 9.0
# Angular wave numbers along one axis: k = 2π · fftfreq(N, d=L/N).
k = 2 * jnp.pi * jnp.fft.fftfreq(N, d=L / N)
# 2-D wave-number grids; with indexing="ij", KX varies along axis 0 and KY
# along axis 1.
KX, KY = jnp.meshgrid(k, k, indexing="ij")
# |k|² and |k|⁴: the factors that appear when ∇² and ∇⁴ are written in
# Fourier space in the time-stepping formula.
K2 = KX**2 + KY**2
K4 = K2**2
# -- 4. Time-stepping: semi-implicit (IMEX) ------------------------------------ | |
# We discretize in time Δt, treating the stiff linear term (ε² ∇⁴ u) implicitly | |
# and the nonlinear term (∇²(u³ - u)) explicitly: | |
# (1 + Δt ε² K⁴) ûⁿ⁺¹ = ûⁿ - Δt K² 𝓕[u³ - u] | |
# giving a stable first-order scheme for reasonable Δt. | |
# See Chen & Shen (1998) for analysis of this Fourier-spectral IMEX approach. | |
ε = 0.02  # interface parameter (γ = ε² in the PDE)
Δt = 5e-5  # time-step size
# Left-hand-side factor of the semi-implicit update: 1 + Δt ε² K⁴.
# It depends only on constants, so it is computed once up front.
DENOM = 1 + Δt * ε**2 * K4
@jax.jit
def spectral_step(u: jnp.ndarray) -> jnp.ndarray:
    """Advance the field by one semi-implicit (IMEX) step of size Δt.

    In Fourier space the update is

        û_new = (û - Δt · K2 · FFT[u³ - u]) / DENOM

    where the nonlinear term u³ - u is evaluated at the current field
    (explicit part) and the division by DENOM = 1 + Δt ε² K4 applies the
    implicit part. The real part of the inverse transform is then clipped
    to [-1.5, 1.5].

    Args:
        u: Real-space field. It is combined elementwise with K2 and DENOM,
            so its 2-D shape must be compatible with theirs.

    Returns:
        The field after one time step, clipped to [-1.5, 1.5].
    """
    spectrum = jnp.fft.fft2(u)
    nonlinear_spectrum = jnp.fft.fft2(u**3 - u)
    updated_spectrum = (spectrum - Δt * K2 * nonlinear_spectrum) / DENOM
    updated_field = jnp.real(jnp.fft.ifft2(updated_spectrum))
    # Clamp to avoid numerical overshoot.
    # NOTE(review): clipping is not mass-conserving in general — confirm
    # this is acceptable alongside the mass-error diagnostic.
    return jnp.clip(updated_field, -1.5, 1.5)
# -- 5. Time integration via JAX scan ------------------------------------------ | |
STEPS = 10  # solver sub-steps taken per animation frame


@jax.jit
def evolve(u: jnp.ndarray) -> jnp.ndarray:
    """Apply `spectral_step` STEPS times in succession.

    The repetition is expressed with `jax.lax.scan` so the whole loop is
    part of one jitted computation rather than a Python-level loop.

    Args:
        u: Field to advance.

    Returns:
        The field after STEPS consecutive calls to `spectral_step`.
    """

    def advance(state, _unused):
        # No per-iteration output is collected, hence the `None`.
        return spectral_step(state), None

    final_state, _ = jax.lax.scan(advance, u, xs=None, length=STEPS)
    return final_state
# -- 6. Initial condition & diagnostics ---------------------------------------- | |
# Gaussian white noise of amplitude 0.2 drawn from a fixed seed, so every run
# starts from the same field. The sample mean is only approximately zero
# (⇒ roughly equal phase fractions).
key = jax.random.PRNGKey(0)
u = jax.random.normal(key, shape=(N, N)) * 0.2
# Spatial mean of the initial field: the reference value for the
# mass-conservation check performed on every frame.
mass0 = float(jnp.mean(u))
# -- 7. Visualization & animation ---------------------------------------------- | |
# Square figure with one axes; the image artist is created once here and
# only has its data replaced on each animation frame.
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(
    np.array(u),
    origin="lower",
    cmap="viridis",
    interpolation="bilinear",
    # Fixed colour scale so frames are directly comparable.
    vmin=-1,
    vmax=1,
)
ax.set_axis_off()
def frame(i: int):
    """Animation callback: advance the simulation and redraw the image.

    Replaces the module-level field `u` with `evolve(u)`, prints the field
    variance and the absolute deviation of its mean from `mass0`, then
    updates the image data and the axes title.

    Args:
        i: Zero-based frame index supplied by the animation driver.

    Returns:
        A one-element list containing the updated image artist.
    """
    global u
    u = evolve(u)
    field = np.asarray(u)

    # Diagnostics: variance tracks domain growth; the mean should stay at
    # its initial value if mass is conserved.
    variance = field.var()
    mass_drift = abs(field.mean() - mass0)
    frame_no = i + 1
    print(f"[Frame {frame_no}] var={variance:.4f}, mass_err={mass_drift:.2e}")

    im.set_data(field)
    # Each frame covers STEPS solver steps of size Δt.
    ax.set_title(f"t ≃ {frame_no*STEPS*Δt:.2f}, var={variance:.3f}")
    return [im]
# Build a 240-frame, non-repeating animation driven by `frame`.
ani = FuncAnimation(
    fig,
    frame,
    frames=240,
    interval=30,
    blit=True,
    repeat=False,
)
# Write the animation as an H.264 MP4 at 30 fps; the path is resolved to an
# absolute one so the final message shows exactly where the file landed.
output_filename = "cahn_hilliard_simulation.mp4"
output_path = Path(output_filename).resolve()
ani.save(
    output_path,
    fps=30,
    extra_args=["-vcodec", "libx264"],
)
print(f"Animation saved to: {output_path}")
This file has been truncated, but you can view the full file.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment