@yberreby
Last active June 21, 2025 15:39
End-to-end differentiable maze example in JAX.
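The script is self-contained: its dependencies are declared inline (PEP 723 script metadata), so it should run directly via the `uv` shebang or with `uv run <script>.py`, assuming `uv` is installed. The `jax[cuda12]` pin can be replaced with plain `jax` on CPU-only machines, as noted in the header.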
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyqt6", # For matplotlib backend
# "numpy",
# "optax>=0.2.5",
# "matplotlib",
# "jax[cuda12]==0.5.2", # Change for CPU.
# "jaxtyping>=0.3.2",
# "beartype>=0.21.0"
# ]
# ///
import jax
import jax.numpy as jnp
import numpy as np
import optax
import matplotlib.pyplot as plt
from jaxtyping import Array, Float, Bool, jaxtyped
from beartype import beartype as typechecker
from typing import Tuple
import time
#--- Indexing helpers ----------------------------------------------------------
def coords_to_index(i: int, j: int, W: int) -> int:
    """(i, j) -> flat index in row-major order."""
    return i * W + j

def index_to_coords(k: int, W: int) -> Tuple[int, int]:
    """Inverse of coords_to_index."""
    return divmod(k, W)
#--- Maze logic: walls and transitions -----------------------------------------
@jaxtyped(typechecker=typechecker)
def make_wall_tensor(raw_maze: Bool[Array, "H W"]) -> Bool[Array, "H W 4"]:
    """
    Returns walls[H,W,4]: walls[i,j,a] is True if action a from (i,j) is blocked.
    a=0:↑, 1:→, 2:↓, 3:←. Blocked by boundary or obstacle.
    """
    H, W = raw_maze.shape
    walls = np.zeros((H, W, 4), dtype=bool)
    walls[0, :, 0] = True      # top row: can't move up
    walls[:, W-1, 1] = True    # rightmost column: can't move right
    walls[H-1, :, 2] = True    # bottom row: can't move down
    walls[:, 0, 3] = True      # leftmost column: can't move left
    for i in range(H):
        for j in range(W):
            if raw_maze[i, j]:
                walls[i, j, :] = True                 # obstacle cell: all moves blocked
                if i > 0: walls[i-1, j, 2] = True     # cell above can't move down into it
                if i < H-1: walls[i+1, j, 0] = True   # cell below can't move up into it
                if j > 0: walls[i, j-1, 1] = True     # cell to the left can't move right
                if j < W-1: walls[i, j+1, 3] = True   # cell to the right can't move left
    return jnp.asarray(walls)
@jaxtyped(typechecker=typechecker)
def transition_matrices(walls: Bool[Array, "H W 4"]) -> Float[Array, "4 N N"]:
    """
    Returns dense transition matrices T[a, from, to].
    Each row is one-hot: if action a from cell n is blocked, stays in place.
    """
    H, W, _ = walls.shape
    N = H * W
    T = np.zeros((4, N, N), dtype=np.float32)
    for i in range(H):
        for j in range(W):
            k = coords_to_index(i, j, W)
            neigh = [(i-1, j), (i, j+1), (i+1, j), (i, j-1)]
            for a, (ni, nj) in enumerate(neigh):
                blocked = (ni < 0 or nj < 0 or ni >= H or nj >= W) or walls[i, j, a]
                k_to = k if blocked else coords_to_index(ni, nj, W)
                T[a, k, k_to] = 1.0
    return jnp.asarray(T)
#--- Policy / step logic -------------------------------------------------------
@jaxtyped(typechecker=typechecker)
def feasibility_mask(state: Float[Array, "N"], walls: Bool[Array, "H W 4"]) -> Bool[Array, "H W 4"]:
    """
    Mask[i,j,a]: True iff agent *could* be at (i,j) and a is allowed there.
    """
    H, W, _ = walls.shape
    occupied = (state > 0).reshape(H, W)
    return jnp.logical_and(~walls, occupied[:, :, None])
@jaxtyped(typechecker=typechecker)
def masked_softmax(logits: Float[Array, "4"], mask: Bool[Array, "H W 4"]) -> Float[Array, "H W 4"]:
    """Softmax along actions; masked-out moves get a large negative logit."""
    logits_b = logits.reshape(1, 1, 4)
    logits_m = jnp.where(mask, logits_b, -1e30)
    return jax.nn.softmax(logits_m, axis=2)
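# Note: a finite -1e30 is used instead of -inf so that cells whose mask is all
# False (cells the agent cannot occupy) still softmax to finite values. With
# -inf everywhere, the softmax would produce NaNs, and NaN * 0 in state_step
# would poison the whole state vector and its gradients.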
@jaxtyped(typechecker=typechecker)
@jax.jit
def state_step(
    state: Float[Array, "N"],
    Pi: Float[Array, "H W 4"],
    T: Float[Array, "4 N N"]
) -> Float[Array, "N"]:
    """
    One step of differentiable state update:
    - Pi: per-cell action probs (must be masked+normalized).
    - T: transition matrices.
    """
    H, W, _ = Pi.shape
    N = H * W
    weighted = state[:, None] * Pi.reshape(N, 4)
    # einsum: ank, na -> k : for each possible dest k, sum over (action, from)
    return jnp.einsum('ank,na->k', T, weighted)
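# Written out, next_state[k] = Σ_a Σ_n T[a, n, k] * state[n] * Pi[n, a] (with Pi
# flattened to (N, 4)): each cell's probability mass is split over actions by Pi
# and routed to the destination cells encoded one-hot in T. Both factors are
# smooth functions of the logits, so the update is differentiable end to end.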
@jax.jit
def rollout(
    init_state: Float[Array, "N"],
    logits_seq: Float[Array, "T 4"],
    *,
    T: Float[Array, "4 N N"],
    walls: Bool[Array, "H W 4"]
) -> Float[Array, "T1 N"]:
    """
    Unroll the state for T steps. At each step, mask softmax ensures feasibility.
    Returns full state trajectory.
    """
    def body(state, logits_t):
        Pi = masked_softmax(logits_t, feasibility_mask(state, walls))
        next_state = state_step(state, Pi, T)
        return next_state, next_state
    _, hist = jax.lax.scan(body, init_state, logits_seq)
    return jnp.vstack([init_state, hist])  # shape (T+1, N)
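# The rollout is fully differentiable: jax.lax.scan threads the soft state
# distribution through every step, so the terminal loss below can be
# backpropagated through the whole horizon directly into logits_seq.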
@jaxtyped(typechecker=typechecker)
def terminal_loss(states: Float[Array, "T1 N"], goal_idx: int) -> Float[Array, ""]:
    """Loss = negative probability at goal at final step."""
    return -states[-1, goal_idx]
@jaxtyped(typechecker=typechecker)
@jax.jit
def greedy_path(
    logits_seq: Float[Array, "T 4"],
    *,
    start_idx: int,
    T: Float[Array, "4 N N"],
    walls: Bool[Array, "H W 4"],
    W: int
) -> Array:
    """
    Simulates a deterministic agent, always picking argmax action at each cell.
    Returns sequence of visited indices (T+1,).
    """
    N = T.shape[1]
    def step(k, logits_t):
        mask = feasibility_mask(jnp.zeros(N).at[k].set(1.0), walls)
        Pi = masked_softmax(logits_t, mask)
        i, j = index_to_coords(k, W)
        a = jnp.argmax(Pi[i, j])
        k_next = jnp.argmax(T[a, k])
        return k_next, k_next
    _, traj = jax.lax.scan(step, start_idx, logits_seq)
    return jnp.hstack([jnp.array([start_idx]), traj])
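# Unlike rollout(), this decoder commits to the single most likely action at the
# agent's current cell, tracing one concrete route for plotting rather than a
# probability distribution over cells.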
#--- Visualisation -------------------------------------------------------------
class Visualiser:
    """
    Subplot-0: maze (walls, start/goal, greedy path, live state-prob heatmap)
    Subplot-1: policy (π(a|t) vs t)
    Subplot-2: loss curve
    """
    def __init__(self,
                 maze: np.ndarray,
                 start: Tuple[int, int],
                 goal: Tuple[int, int],
                 horizon: int):
        import matplotlib.gridspec as gridspec
        self.H, self.W = maze.shape
        self.losses: list[float] = []
        plt.ion()
        self.fig = plt.figure(figsize=(14, 8))
        gs = gridspec.GridSpec(
            2, 2,
            height_ratios=[2, 1],
            width_ratios=[1, 1],
            hspace=0.3,
            wspace=0.25,
        )
        # ----------------- MAZE PANEL -----------------
        self.ax_maze = self.fig.add_subplot(gs[0, 0])
        # background: free = white, wall = black
        self.ax_maze.imshow(1 - maze, cmap="gray", vmin=0, vmax=1)
        self.ax_maze.set_xticks([]); self.ax_maze.set_yticks([])
        # start / goal marks
        self.ax_maze.text(start[1], start[0], "S",
                          c="lime", ha="center", va="center",
                          fontsize=14, weight="bold")
        self.ax_maze.text(goal[1], goal[0], "G",
                          c="red", ha="center", va="center",
                          fontsize=14, weight="bold")
        # probability-mass overlay (updated every call)
        self.state_img = self.ax_maze.imshow(np.zeros_like(maze, dtype=float),
                                             cmap="plasma",
                                             alpha=0.55,
                                             vmin=0.0, vmax=1.0,
                                             interpolation="nearest")
        # greedy-path line handle
        self.path_line = None
        # ----------------- POLICY PANEL -----------------
        self.ax_policy = self.fig.add_subplot(gs[0, 1])
        self.policy_img = self.ax_policy.imshow(
            np.zeros((4, horizon)),
            cmap="viridis",
            vmin=0, vmax=1,
            aspect="auto",
        )
        self.ax_policy.set_yticks(range(4))
        self.ax_policy.set_yticklabels(["↑", "→", "↓", "←"])
        self.ax_policy.set_xlabel("Timestep")
        self.fig.colorbar(self.policy_img,
                          ax=self.ax_policy,
                          label="Action prob")
        # ----------------- LOSS PANEL -----------------
        self.ax_loss = self.fig.add_subplot(gs[1, :])
        self.loss_line, = self.ax_loss.plot([], [])
        self.ax_loss.set_xlabel("Optimiser step")
        self.ax_loss.set_ylabel("Loss")
        self.ax_loss.grid(True, ls=":")

    # --------------------------------------------------------------------------
    def update(self,
               logits_seq: jnp.ndarray,
               loss: float,
               path: Array,
               state_probs: Float[Array, "H W"]):
        """Redraw all live figures."""
        # policy heatmap
        self.policy_img.set_data(jax.nn.softmax(logits_seq, axis=1).T)
        # probability-mass overlay
        self.state_img.set_data(state_probs)
        self.state_img.set_clim(vmin=0.0, vmax=float(state_probs.max() + 1e-9))
        # greedy path
        path_coords = np.array([index_to_coords(k, self.W) for k in path])
        y, x = path_coords[:, 0], path_coords[:, 1]
        if self.path_line is not None:
            self.path_line.remove()
        self.path_line, = self.ax_maze.plot(x, y, "r-o", lw=2, ms=4)
        self.ax_maze.set_title(f"Greedy path (unique cells = {len(np.unique(path))})")
        # loss curve
        self.losses.append(loss)
        self.loss_line.set_data(range(len(self.losses)), self.losses)
        self.ax_loss.relim(); self.ax_loss.autoscale_view()
        # refresh
        self.fig.canvas.draw_idle()
        plt.pause(0.001)
#--- Main ----------------------------------------------------------------------
RAW_MAZE = np.array([
    [0,0,0,0,1,0,0],
    [0,1,1,0,1,0,0],
    [0,0,0,0,0,0,0],
    [1,1,0,1,1,1,0],
    [0,0,0,0,0,1,0],
    [0,1,1,1,0,0,0],
    [0,0,0,0,0,1,0],
], dtype=bool)
START = (0,0)
GOAL = (6,6)
T_HORIZON = 16
LR = 0.01
STEPS = 200
STEP_INTERVAL = 0.05 # Set to >0 (seconds) to slow down optimisation steps
walls = make_wall_tensor(jnp.asarray(RAW_MAZE))
T = transition_matrices(walls)
H, W = RAW_MAZE.shape; N = H * W
start_idx = coords_to_index(*START, W)
goal_idx = coords_to_index(*GOAL, W)
init_state = jnp.zeros(N).at[start_idx].set(1.0)
key = jax.random.PRNGKey(0)
logits_seq = jax.random.normal(key, (T_HORIZON, 4))
opt = optax.adam(LR)
opt_state = opt.init(logits_seq)
vis = Visualiser(RAW_MAZE, START, GOAL, T_HORIZON)
@jax.jit
def loss_fn(params):
    states = rollout(init_state, params, T=T, walls=walls)
    return terminal_loss(states, goal_idx)
@jax.jit
def train_step(params, opt_state):
    (loss_val, path_idx), grads = jax.value_and_grad(
        lambda p: (loss_fn(p), greedy_path(p, start_idx=start_idx, T=T, walls=walls, W=W)),
        has_aux=True)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val, path_idx
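# value_and_grad(..., has_aux=True) differentiates only the scalar loss; the
# greedy path is carried along as an auxiliary output purely for visualisation
# (argmax is not differentiable in any case).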
print("Optimising… (close window to abort)")
for step in range(STEPS):
    start_time = time.time()
    logits_seq, opt_state, loss_value, path = train_step(logits_seq, opt_state)
    # Abort cleanly if the window was closed.
    if not plt.fignum_exists(vis.fig.number):
        break
    states = rollout(init_state, logits_seq, T=T, walls=walls)
    final_probs = states[-1].reshape(H, W)  # (H, W)
    vis.update(logits_seq, float(loss_value), path, final_probs)
    if STEP_INTERVAL > 0.0:
        elapsed = time.time() - start_time
        to_sleep = STEP_INTERVAL - elapsed
        if to_sleep > 0:
            time.sleep(to_sleep)
print("Done.")
plt.ioff(); plt.show()