End-to-end differentiable maze example in JAX.
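The script declares its dependencies inline (PEP 723 metadata block), so it can be run directly with uv. Assuming it is saved as `maze.py` (the gist does not specify a filename), something like `uv run maze.py`, or `chmod +x maze.py && ./maze.py` via the shebang, should work; swap the `jax[cuda12]` dependency for plain `jax` on machines without a CUDA 12 GPU.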
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pyqt6",  # For the matplotlib backend
#     "numpy",
#     "optax>=0.2.5",
#     "matplotlib",
#     "jax[cuda12]==0.5.2",  # Replace with plain "jax" for a CPU-only install.
#     "jaxtyping>=0.3.2",
#     "beartype>=0.21.0"
# ]
# ///
import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
from jaxtyping import Array, Float, Bool, jaxtyped
from beartype import beartype as typechecker
from typing import Tuple
import time
#--- Indexing helpers ----------------------------------------------------------
def coords_to_index(i: int, j: int, W: int) -> int:
    """(i, j) -> flat index in row-major order."""
    return i * W + j
def index_to_coords(k: int, W: int) -> Tuple[int, int]:
    """Inverse of coords_to_index."""
    return divmod(k, W)
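# Example (row-major, W = 7 as in the maze below): coords_to_index(6, 6, 7) == 48
# and index_to_coords(48, 7) == (6, 6), so the two helpers are mutual inverses.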
#--- Maze logic: walls and transitions -----------------------------------------
@jaxtyped(typechecker=typechecker)
def make_wall_tensor(raw_maze: Bool[Array, "H W"]) -> Bool[Array, "H W 4"]:
    """
    Returns walls[H,W,4]: walls[i,j,a] is True if action a from (i,j) is blocked.
    a=0:↑, 1:→, 2:↓, 3:←. Blocked by boundary or obstacle.
    """
    H, W = raw_maze.shape
    walls = np.zeros((H, W, 4), dtype=bool)
    # The grid boundary blocks the corresponding action along each edge.
    walls[0, :, 0] = True
    walls[:, W-1, 1] = True
    walls[H-1, :, 2] = True
    walls[:, 0, 3] = True
    for i in range(H):
        for j in range(W):
            if raw_maze[i, j]:
                # An obstacle blocks every action from its own cell...
                walls[i, j, :] = True
                # ...and blocks its neighbours from stepping into it.
                if i > 0: walls[i-1, j, 2] = True
                if i < H-1: walls[i+1, j, 0] = True
                if j > 0: walls[i, j-1, 1] = True
                if j < W-1: walls[i, j+1, 3] = True
    return jnp.asarray(walls)
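# For the RAW_MAZE defined near the bottom of this file, walls[0, 0] comes out as
# [True, False, False, True]: up and left are blocked by the boundary, while
# right (0, 1) and down (1, 0) are free cells.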
@jaxtyped(typechecker=typechecker)
def transition_matrices(walls: Bool[Array, "H W 4"]) -> Float[Array, "4 N N"]:
    """
    Returns dense transition matrices T[a, from, to].
    Each row is one-hot: if action a from cell n is blocked, stays in place.
    """
    H, W, _ = walls.shape
    N = H * W
    T = np.zeros((4, N, N), dtype=np.float32)
    for i in range(H):
        for j in range(W):
            k = coords_to_index(i, j, W)
            # Neighbours listed in action order: up, right, down, left.
            neigh = [(i-1, j), (i, j+1), (i+1, j), (i, j-1)]
            for a, (ni, nj) in enumerate(neigh):
                blocked = (ni < 0 or nj < 0 or ni >= H or nj >= W) or walls[i, j, a]
                k_to = k if blocked else coords_to_index(ni, nj, W)
                T[a, k, k_to] = 1.0
    return jnp.asarray(T)
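# Each T[a] is row-stochastic with exactly one 1.0 per row, so multiplying a
# probability vector by T[a] moves mass deterministically (blocked moves keep
# the mass in place) without creating or destroying any of it.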
#--- Policy / step logic -------------------------------------------------------
@jaxtyped(typechecker=typechecker)
def feasibility_mask(state: Float[Array, "N"], walls: Bool[Array, "H W 4"]) -> Bool[Array, "H W 4"]:
    """
    Mask[i,j,a]: True iff agent *could* be at (i,j) and a is allowed there.
    """
    H, W, _ = walls.shape
    occupied = (state > 0).reshape(H, W)
    return jnp.logical_and(~walls, occupied[:, :, None])
@jaxtyped(typechecker=typechecker)
def masked_softmax(logits: Float[Array, "4"], mask: Bool[Array, "H W 4"]) -> Float[Array, "H W 4"]:
    """Softmax along actions; masked-out moves get a very large negative logit."""
    logits_b = logits.reshape(1, 1, 4)
    logits_m = jnp.where(mask, logits_b, -1e30)
    return jax.nn.softmax(logits_m, axis=2)
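# The mask uses -1e30 rather than -inf: if every action at a cell is masked out
# (e.g. an obstacle cell), the softmax still yields a finite uniform distribution
# instead of NaNs. Such cells never receive probability mass, so the uniform
# distribution there does not affect the state update.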
@jaxtyped(typechecker=typechecker)
@jax.jit
def state_step(
    state: Float[Array, "N"],
    Pi: Float[Array, "H W 4"],
    T: Float[Array, "4 N N"]
) -> Float[Array, "N"]:
    """
    One step of differentiable state update:
      - Pi: per-cell action probs (must be masked+normalized).
      - T: transition matrices.
    """
    H, W, _ = Pi.shape
    N = H * W
    weighted = state[:, None] * Pi.reshape(N, 4)
    # einsum 'ank,na->k': next_state[k] = sum over actions a and source cells n
    # of T[a, n, k] * state[n] * Pi[n, a].
    return jnp.einsum('ank,na->k', T, weighted)
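# Because each T[a] is row-stochastic and the per-cell action probabilities sum
# to one, state_step conserves total probability mass: the output sums to
# whatever the input summed to (here always 1.0).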
@jax.jit
def rollout(
    init_state: Float[Array, "N"],
    logits_seq: Float[Array, "T 4"],
    *,
    T: Float[Array, "4 N N"],
    walls: Bool[Array, "H W 4"]
) -> Float[Array, "T1 N"]:
    """
    Unroll the state for T steps. At each step, the masked softmax ensures feasibility.
    Returns the full state trajectory, including the initial state.
    """
    def body(state, logits_t):
        Pi = masked_softmax(logits_t, feasibility_mask(state, walls))
        next_state = state_step(state, Pi, T)
        return next_state, next_state
    _, hist = jax.lax.scan(body, init_state, logits_seq)
    return jnp.vstack([init_state, hist])  # shape (T+1, N)
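# With the configuration below (T_HORIZON = 16, 7x7 maze), rollout returns a
# (17, 49) trajectory: row 0 is the one-hot start state, row t the distribution
# over cells after t policy steps.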
@jaxtyped(typechecker=typechecker)
def terminal_loss(states: Float[Array, "T1 N"], goal_idx: int) -> Float[Array, ""]:
    """Loss = negative probability at goal at final step."""
    return -states[-1, goal_idx]
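# The loss lies in [-1, 0]; minimising it maximises the probability mass sitting
# on the goal cell at the final timestep, which is what makes the maze traversal
# end-to-end differentiable with respect to the logits.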
@jaxtyped(typechecker=typechecker)
@jax.jit
def greedy_path(
    logits_seq: Float[Array, "T 4"],
    *,
    start_idx: int,
    T: Float[Array, "4 N N"],
    walls: Bool[Array, "H W 4"],
    W: int
) -> Array:
    """
    Simulates a deterministic agent, always picking the argmax action at each cell.
    Returns the sequence of visited indices (T+1,).
    """
    N = T.shape[1]
    def step(k, logits_t):
        # Place a one-hot "agent" at cell k, mask infeasible moves, pick argmax.
        mask = feasibility_mask(jnp.zeros(N).at[k].set(1.0), walls)
        Pi = masked_softmax(logits_t, mask)
        i, j = index_to_coords(k, W)
        a = jnp.argmax(Pi[i, j])
        # The transition row T[a, k] is one-hot, so argmax recovers the successor.
        k_next = jnp.argmax(T[a, k])
        return k_next, k_next
    _, traj = jax.lax.scan(step, start_idx, logits_seq)
    return jnp.hstack([jnp.array([start_idx]), traj])
#--- Visualisation ---------------------------------------------------------------
class Visualiser:
    """
    Subplot-0: maze (walls, start/goal, greedy path, live state-prob heatmap)
    Subplot-1: policy (π(a|t) vs t)
    Subplot-2: loss curve
    """
    def __init__(self,
                 maze: np.ndarray,
                 start: Tuple[int, int],
                 goal: Tuple[int, int],
                 horizon: int):
        import matplotlib.gridspec as gridspec
        self.H, self.W = maze.shape
        self.losses: list[float] = []
        plt.ion()
        self.fig = plt.figure(figsize=(14, 8))
        gs = gridspec.GridSpec(
            2, 2,
            height_ratios=[2, 1],
            width_ratios=[1, 1],
            hspace=0.3,
            wspace=0.25,
        )
        # ----------------- MAZE PANEL -----------------
        self.ax_maze = self.fig.add_subplot(gs[0, 0])
        # background: free = white, wall = black
        self.ax_maze.imshow(1 - maze, cmap="gray", vmin=0, vmax=1)
        self.ax_maze.set_xticks([]); self.ax_maze.set_yticks([])
        # start / goal marks
        self.ax_maze.text(start[1], start[0], "S",
                          c="lime", ha="center", va="center",
                          fontsize=14, weight="bold")
        self.ax_maze.text(goal[1], goal[0], "G",
                          c="red", ha="center", va="center",
                          fontsize=14, weight="bold")
        # probability-mass overlay (updated every call)
        self.state_img = self.ax_maze.imshow(np.zeros_like(maze, dtype=float),
                                             cmap="plasma",
                                             alpha=0.55,
                                             vmin=0.0, vmax=1.0,
                                             interpolation="nearest")
        # greedy-path line handle
        self.path_line = None
        # ----------------- POLICY PANEL -----------------
        self.ax_policy = self.fig.add_subplot(gs[0, 1])
        self.policy_img = self.ax_policy.imshow(
            np.zeros((4, horizon)),
            cmap="viridis",
            vmin=0, vmax=1,
            aspect="auto",
        )
        self.ax_policy.set_yticks(range(4))
        self.ax_policy.set_yticklabels(["↑", "→", "↓", "←"])
        self.ax_policy.set_xlabel("Timestep")
        self.fig.colorbar(self.policy_img,
                          ax=self.ax_policy,
                          label="Action prob")
        # ----------------- LOSS PANEL -----------------
        self.ax_loss = self.fig.add_subplot(gs[1, :])
        self.loss_line, = self.ax_loss.plot([], [])
        self.ax_loss.set_xlabel("Optimiser step")
        self.ax_loss.set_ylabel("Loss")
        self.ax_loss.grid(True, ls=":")
    # --------------------------------------------------------------------------
    def update(self,
               logits_seq: jnp.ndarray,
               loss: float,
               path: Array,
               state_probs: Float[Array, "H W"]):
        """Redraw all live figures."""
        # policy heatmap: softmax over actions, transposed to (4, T)
        self.policy_img.set_data(jax.nn.softmax(logits_seq, axis=1).T)
        # probability-mass overlay
        self.state_img.set_data(state_probs)
        self.state_img.set_clim(vmin=0.0, vmax=float(state_probs.max() + 1e-9))
        # greedy path
        path_coords = np.array([index_to_coords(k, self.W) for k in path])
        y, x = path_coords[:, 0], path_coords[:, 1]
        if self.path_line is not None:
            self.path_line.remove()
        self.path_line, = self.ax_maze.plot(x, y, "r-o", lw=2, ms=4)
        self.ax_maze.set_title(f"Greedy path (unique cells = {len(np.unique(path))})")
        # loss curve
        self.losses.append(loss)
        self.loss_line.set_data(range(len(self.losses)), self.losses)
        self.ax_loss.relim(); self.ax_loss.autoscale_view()
        # refresh
        self.fig.canvas.draw_idle()
        plt.pause(0.001)
#--- Main ------------------------------------------------------------------------
RAW_MAZE = np.array([
    [0,0,0,0,1,0,0],
    [0,1,1,0,1,0,0],
    [0,0,0,0,0,0,0],
    [1,1,0,1,1,1,0],
    [0,0,0,0,0,1,0],
    [0,1,1,1,0,0,0],
    [0,0,0,0,0,1,0],
], dtype=bool)
START = (0, 0)
GOAL = (6, 6)
T_HORIZON = 16
LR = 0.01
STEPS = 200
STEP_INTERVAL = 0.05  # Set to >0 (seconds) to slow down optimisation steps
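# The Manhattan distance from START to GOAL is 12, and an unobstructed 12-move
# route exists (along row 0 to column 3, down to row 2, across to column 6,
# then down to the goal), so T_HORIZON = 16 leaves a few steps of slack.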
walls = make_wall_tensor(jnp.asarray(RAW_MAZE))
T = transition_matrices(walls)
H, W = RAW_MAZE.shape; N = H * W
start_idx = coords_to_index(*START, W)
goal_idx = coords_to_index(*GOAL, W)
init_state = jnp.zeros(N).at[start_idx].set(1.0)
key = jax.random.PRNGKey(0)
logits_seq = jax.random.normal(key, (T_HORIZON, 4))
opt = optax.adam(LR)
opt_state = opt.init(logits_seq)
vis = Visualiser(RAW_MAZE, START, GOAL, T_HORIZON)
@jax.jit
def loss_fn(params):
    """Terminal loss of a full rollout under the current logits."""
    states = rollout(init_state, params, T=T, walls=walls)
    return terminal_loss(states, goal_idx)
@jax.jit
def train_step(params, opt_state):
    # The greedy path is returned as an auxiliary output for visualisation only;
    # gradients flow solely through the differentiable terminal loss.
    (loss_val, path_idx), grads = jax.value_and_grad(
        lambda p: (loss_fn(p), greedy_path(p, start_idx=start_idx, T=T, walls=walls, W=W)),
        has_aux=True)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val, path_idx
print("Optimising… (close window to abort)") | |
for step in range(STEPS): | |
start_time = time.time() | |
logits_seq, opt_state, loss_value, path = train_step(logits_seq, opt_state) | |
# Abort cleanly if the window was closed. | |
if not plt.fignum_exists(vis.fig.number): | |
break | |
states = rollout(init_state, logits_seq, T=T, walls=walls) | |
final_probs = states[-1].reshape(H, W) # (H, W) | |
vis.update(logits_seq, float(loss_value), path, final_probs) | |
if STEP_INTERVAL > 0.0: | |
elapsed = time.time() - start_time | |
to_sleep = STEP_INTERVAL - elapsed | |
if to_sleep > 0: | |
time.sleep(to_sleep) | |
print("Done.") | |
plt.ioff(); plt.show() |