Skip to content

Instantly share code, notes, and snippets.

@yberreby
Created June 17, 2025 07:34
Show Gist options
  • Save yberreby/83064d8ae6f68af7294abdc37ee84fea to your computer and use it in GitHub Desktop.
Save yberreby/83064d8ae6f68af7294abdc37ee84fea to your computer and use it in GitHub Desktop.
Reasonably fast (several hundred million board updates per second) vectorized JAX brute-force search for 3x3 'Lights Out!' boards. DISCLAIMER: only loosely tested and reviewed.
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# %%
import time, matplotlib.pyplot as plt, numpy as np
import jax, jax.numpy as jnp
from jax import random, lax
# ─────────────────────────── Lights-Out primitives ──────────────────────────
def flip_kernel(side: int) -> jnp.ndarray:
    """Build the press-to-toggle lookup tensor for a ``side`` x ``side`` board.

    The result has shape ``(side, side, side, side)`` and dtype int8.
    Entry ``[i, j, k, l]`` is 1 when pressing button ``(i, j)`` toggles the
    light at ``(k, l)`` -- i.e. ``(k, l)`` is the pressed cell itself or one
    of its four orthogonal neighbours that lies on the board -- and 0
    otherwise.
    """
    # The pressed cell plus its up/down/left/right neighbours.
    offsets = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
    on_board = range(side)

    kernel = np.zeros((side,) * 4, dtype=np.int8)
    for row, col in np.ndindex(side, side):
        for d_row, d_col in offsets:
            tgt_row = row + d_row
            tgt_col = col + d_col
            # Neighbours that fall off the edge are simply skipped.
            if tgt_row in on_board and tgt_col in on_board:
                kernel[row, col, tgt_row, tgt_col] = 1
    return jnp.array(kernel)
@jax.jit
def step_batch(boards, acts, eff):
    """Apply one set of button presses to every board in a batch.

    ``acts`` is contracted with ``eff`` over its two trailing axes
    (``'bij,ijkl->bkl'``), giving for each board the number of times every
    light is toggled.  Only the parity of that count matters, so it is
    reduced to its low bit, added to ``boards`` and masked to the low bit
    again (addition modulo 2).
    """
    toggle_counts = jnp.einsum('bij,ijkl->bkl', acts, eff)
    toggle_parity = jnp.bitwise_and(toggle_counts, 1)
    # Adding the parity and keeping the low bit acts as XOR on 0/1 boards.
    return jnp.bitwise_and(boards + toggle_parity, 1)
@jax.jit
def solved(boards):
    """Return, per board, whether every cell is zero.

    Reduces over axes 1 and 2, so the result has one boolean per entry of
    the leading (batch) axis.
    """
    any_lit = jnp.any(boards != 0, axis=(1, 2))
    return jnp.logical_not(any_lit)
# ───────────────────────── Jitted scan body ────────────────────────────────
def make_body(B, side, H, eff):
    """Build the per-step function passed to ``lax.scan`` by the driver.

    Args:
        B: Batch size. Not used here (shapes are taken from the carried
            arrays); kept so existing callers keep working.
        side: Board side length. Not used here, kept for the same reason.
        H: Number of histogram bins. Solve events that need ``H - 1`` or
            more attempts are all accumulated in the last bin.
        eff: Press-to-toggle tensor forwarded to ``step_batch``.

    Returns:
        A jitted ``body(carry, _) -> (carry, None)`` where ``carry`` is the
        tuple ``(boards, attempts, hist, key)``:

        * ``boards``   -- batch of boards; indexed below as a 3-D array whose
          leading axis is the batch,
        * ``attempts`` -- per-board count of unsuccessful steps so far,
        * ``hist``     -- 1-D histogram of attempts-to-solve,
        * ``key``      -- PRNG key, split afresh on every step.
    """
    @jax.jit
    def body(carry, _):
        boards, attempts, hist, key = carry
        key, act_key, reset_key = random.split(key, 3)

        # Press a uniformly random subset of buttons on every board.
        acts = random.bernoulli(act_key, 0.5, boards.shape).astype(jnp.int8)
        nxt = step_batch(boards, acts, eff)
        win = solved(nxt)  # one flag per board

        # Histogram update: a board solved on this step took attempts + 1
        # tries.  Clamp to the last bin so long runs stay in range.
        # `.at[...].add` accumulates correctly when several boards land in
        # the same bin, and replaces a hand-written lax.scatter_add.
        idx = jnp.minimum(attempts + 1, H - 1).astype(jnp.int32)
        hist = hist.at[idx].add(win.astype(hist.dtype))

        # Solved boards are swapped for fresh random ones and their attempt
        # counter restarts; every other board keeps its stepped state.
        new_b = random.bernoulli(reset_key, 0.5, boards.shape).astype(jnp.int8)
        boards = jnp.where(win[:, None, None], new_b, nxt)
        attempts = jnp.where(win, 0, attempts + 1)
        return (boards, attempts, hist, key), None

    return body
# ──────────────────────────── Driver ───────────────────────────────────────
def brute_hist(side=3, B=2048, steps=4000, H=400, seed=0):
    """Run the random-press simulation and return the attempts histogram.

    ``B`` random boards are advanced for ``steps`` scan iterations using the
    step function from ``make_body``.  A throughput line is printed to
    stdout.

    Args:
        side: Board side length.
        B: Number of boards simulated in parallel.
        steps: Number of scan iterations.
        H: Number of histogram bins.
        seed: Seed for the JAX PRNG key.

    Returns:
        The final histogram converted to a NumPy array of shape ``(H,)``.
    """
    eff = flip_kernel(side)
    key = random.PRNGKey(seed)
    key, sub = random.split(key)
    boards = random.bernoulli(sub, 0.5, (B, side, side)).astype(jnp.int8)
    attempts = jnp.zeros((B,), jnp.int32)
    hist = jnp.zeros((H,), jnp.int32)
    body = make_body(B, side, H, eff)
    carry = (boards, attempts, hist, key)

    t0 = time.perf_counter()
    carry, _ = lax.scan(body, carry, xs=None, length=steps)
    # JAX dispatches device work asynchronously: without this barrier the
    # clock would stop before the scan has actually finished and the
    # reported throughput would be inflated.  The measured time still
    # includes tracing/compilation of the scan.
    carry = jax.block_until_ready(carry)
    elapsed = time.perf_counter() - t0

    _, _, hist, _ = carry
    # Guard against a zero reading from the timer on degenerate runs.
    ups = B * steps / elapsed if elapsed > 0 else float("inf")
    print(f"{B*steps:,} board-updates in {elapsed:.2f}s → {ups:,.0f} updates/s")
    return np.asarray(hist)
# ───────────────────────── Demo run & plot ─────────────────────────────────
if __name__ == "__main__":
    # Demo: simulate, then plot how many attempts each solve event needed.
    side, B, steps, H = 3, 8192, 40000, 4000
    hist = brute_hist(side, B, steps, H)

    plt.figure()
    # Bin 0 is skipped (xs starts at 1); with the clamping done in
    # make_body, the last bin also collects every longer run.
    xs = np.arange(1, H)
    plt.bar(xs, hist[1:], width=0.9)
    plt.xlabel("Attempts to solve")
    # The y-axis is linear, so the label must not advertise a log scale.
    # If plt.yscale("log") is ever enabled, change the label to match.
    plt.ylabel("count")
    plt.title(f"{side}×{side} Lights-Out • {hist[1:].sum():,} solve events")
    plt.tight_layout()
    plt.show()
# %%
# =========================== VERIFICATION CODE =============================
def verify_game_mechanics():
    """Print a check of the flip kernel for a centre press and a corner press.

    For each case a single button is pressed on a 3x3 board, the resulting
    toggle pattern is computed from the kernel, and the pattern, the
    expected pattern and whether they match are printed.  Nothing is
    returned and no exception is raised on a mismatch.
    """
    side = 3
    eff = flip_kernel(side)

    # (label, pressed cell, lights expected to toggle)
    cases = (
        ("Center", (1, 1), [[0, 1, 0],
                            [1, 1, 1],
                            [0, 1, 0]]),
        ("Corner", (0, 0), [[1, 1, 0],
                            [1, 0, 0],
                            [0, 0, 0]]),
    )

    print("=== Verification 1: Flip Kernel ===")
    for label, cell, pattern in cases:
        press = np.zeros((side, side), dtype=np.int8)
        press[cell] = 1
        effect = jnp.einsum('ij,ijkl->kl', press, eff)
        expected = np.array(pattern, dtype=np.int8)
        print(f"{label} press effect:\n{effect}")
        print(f"Expected:\n{expected}")
        print(f"Correct: {np.array_equal(effect, expected)}\n")
def _decorate_grid_axes(ax, side):
    """Put one tick per board cell on both axes and overlay a faint grid."""
    ax.set_xticks(range(side))
    ax.set_yticks(range(side))
    ax.grid(True, alpha=0.3)


def solve_specific_boards():
    """Randomly search for a one-shot solution of a few boards and plot it.

    For every test board, up to 1000 random press patterns are drawn; the
    first one that clears the board in a single ``step_batch`` call is kept.
    One figure row is drawn per board: the initial board, the press pattern
    that solved it, and the resulting board.  Because the search is random
    and bounded, "No solution found" only means none of the draws worked,
    not that the board is unsolvable.
    """
    side = 3
    eff = flip_kernel(side)
    max_attempts = 1000

    boards = [
        # Board 1: all lights on.  Pressing every button does NOT clear a
        # 3x3 board (each edge cell would be toggled 4 times and stay lit);
        # the clearing pattern is the four corners plus the centre.
        np.ones((side, side), dtype=np.int8),
        # Board 2: simple pattern.
        np.array([[1, 0, 1],
                  [0, 1, 0],
                  [1, 0, 1]], dtype=np.int8),
        # Board 3: random board with a fixed seed.
        random.bernoulli(random.PRNGKey(42), 0.5, (side, side)).astype(np.int8),
    ]

    # One row per board; squeeze=False keeps `axes` 2-D for any row count.
    _, axes = plt.subplots(len(boards), 3, figsize=(10, 10), squeeze=False)
    # Axis limits that frame a side x side grid of unit cells.
    lo, hi = -0.5, side - 0.5

    for i, board in enumerate(boards):
        # Column 0: the initial board.
        ax = axes[i, 0]
        ax.imshow(board, cmap='gray_r', vmin=0, vmax=1)
        ax.set_title(f'Board {i+1}: Initial')
        _decorate_grid_axes(ax, side)

        # Random search for a press pattern that clears the board.
        test_board = board[None, :, :]  # add batch dimension
        found_solution = False
        key = random.PRNGKey(i)
        for attempt in range(max_attempts):
            key, subkey = random.split(key)
            actions = random.bernoulli(subkey, 0.5, (1, side, side)).astype(np.int8)
            result = step_batch(test_board, actions, eff)
            if solved(result)[0]:
                found_solution = True
                break

        # Column 1: the press pattern that worked, if any.
        ax = axes[i, 1]
        if found_solution:
            ax.imshow(actions[0], cmap='RdBu_r', vmin=0, vmax=1)
            ax.set_title(f'Actions (attempt {attempt+1})')
            # Mark every pressed button.
            for y in range(side):
                for x in range(side):
                    if actions[0, y, x]:
                        ax.text(x, y, '✓', ha='center', va='center',
                                color='white', fontsize=16, weight='bold')
        else:
            ax.text(0.5, 0.5, 'No solution\nfound', ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            # No image sets the extent here, so frame the grid by hand.
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
        _decorate_grid_axes(ax, side)

        # Column 2: the board after applying the pattern.
        ax = axes[i, 2]
        if found_solution:
            ax.imshow(result[0], cmap='gray_r', vmin=0, vmax=1)
            ax.set_title(f'Result: {"SOLVED ✓" if solved(result)[0] else "Not solved ✗"}')
        else:
            ax.text(0.5, 0.5, 'N/A', ha='center', va='center',
                    transform=ax.transAxes, fontsize=12)
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
        _decorate_grid_axes(ax, side)

    plt.suptitle('Lights-Out Solver Verification', fontsize=16)
    plt.tight_layout()
    plt.show()
# Entry point for the verification cell: run each check in order.
if __name__ == "__main__":
    for run_check in (verify_game_mechanics, solve_specific_boards):
        run_check()
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment