Created
June 17, 2025 07:34
-
-
Save yberreby/83064d8ae6f68af7294abdc37ee84fea to your computer and use it in GitHub Desktop.
Reasonably fast (several hundred million updates per second) vectorized JAX bruteforce for 3x3 'Lights Out!' boards. DISCLAIMER: Loosely-tested / reviewed.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- | |
# jupyter: | |
# jupytext: | |
# formats: ipynb,py:percent | |
# text_representation: | |
# extension: .py | |
# format_name: percent | |
# format_version: '1.3' | |
# jupytext_version: 1.17.2 | |
# kernelspec: | |
# display_name: Python 3 (ipykernel) | |
# language: python | |
# name: python3 | |
# --- | |
# %% | |
import time, matplotlib.pyplot as plt, numpy as np | |
import jax, jax.numpy as jnp | |
from jax import random, lax | |
# ─────────────────────────── Lights-Out primitives ────────────────────────── | |
def flip_kernel(side: int) -> jnp.ndarray:
    """Return the Lights-Out toggle tensor for a ``side`` x ``side`` board.

    ``kernel[i, j, k, l] == 1`` exactly when pressing cell ``(i, j)`` toggles
    cell ``(k, l)``: the pressed cell itself plus its in-bounds orthogonal
    neighbours. Shape ``(side, side, side, side)``, dtype int8.
    """
    kernel = np.zeros((side,) * 4, dtype=np.int8)
    offsets = ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1))
    for row, col in np.ndindex(side, side):
        for d_row, d_col in offsets:
            tgt_row, tgt_col = row + d_row, col + d_col
            # Neighbours falling off the board are simply ignored.
            if tgt_row in range(side) and tgt_col in range(side):
                kernel[row, col, tgt_row, tgt_col] = 1
    return jnp.array(kernel)
@jax.jit
def step_batch(boards, acts, eff):
    """Apply one press pattern per board and return the updated boards.

    ``acts`` is contracted with the toggle kernel ``eff`` over the press
    coordinates, giving per-cell toggle counts; only their parity matters.
    Boards are assumed to hold 0/1 lights, so the final ``& 1`` makes the
    update an XOR (mod-2 add).
    """
    toggle_counts = jnp.einsum('nab,abxy->nxy', acts, eff)
    toggle_parity = jnp.bitwise_and(toggle_counts, 1)
    return jnp.bitwise_and(boards + toggle_parity, 1)
@jax.jit
def solved(boards):
    """Return a per-board bool mask: True where no light is on.

    Reduces over axes 1 and 2, so ``boards`` is expected to be a batch of
    2-D boards; the result has one entry per leading-axis element.
    """
    any_lit = jnp.any(boards != 0, axis=(1, 2))
    return jnp.logical_not(any_lit)
# ───────────────────────── Jitted scan body ──────────────────────────────── | |
def make_body(B, side, H, eff):
    """Build the ``lax.scan`` step function for the attempt histogram.

    The returned ``body(carry, _)`` operates on
    ``carry = (boards, attempts, hist, key)``:

    * ``boards``   (B, s, s) int8 -- current board in each batch slot,
    * ``attempts`` (B,) int32     -- failed steps since that board was dealt,
    * ``hist``     (H,) int32     -- solve counts indexed by attempts used,
    * ``key``      JAX PRNG key.

    Each step presses a fresh uniformly random pattern on every board. A
    board cleared on this step adds one to ``hist[min(attempts + 1, H - 1)]``
    (so the last bin also collects every solve needing >= H-1 attempts). It
    is then replaced by a fresh random board, and its counter resets to 0.

    ``B`` and ``side`` are accepted for interface compatibility only; the
    shapes are read from the carry at trace time.
    """

    @jax.jit
    def body(carry, _):
        boards, attempts, hist, key = carry
        key, k_press, k_deal = random.split(key, 3)

        presses = random.bernoulli(k_press, 0.5, boards.shape).astype(jnp.int8)
        nxt = step_batch(boards, presses, eff)
        win = solved(nxt)  # (B,) bool

        # ----- histogram update ------------------------------------------
        # Duplicate bins accumulate; boards that did not win add 0.
        bins = jnp.minimum(attempts + 1, H - 1).astype(jnp.int32)
        hist = hist.at[bins].add(win.astype(hist.dtype))

        # ----- hot-swap solved boards & reset their attempt counters ------
        fresh = random.bernoulli(k_deal, 0.5, boards.shape).astype(jnp.int8)
        boards = jnp.where(win[:, None, None], fresh, nxt)
        attempts = jnp.where(win, 0, attempts + 1)
        return (boards, attempts, hist, key), None

    return body
# ──────────────────────────── Driver ─────────────────────────────────────── | |
def brute_hist(side=3, B=2048, steps=4000, H=400, seed=0):
    """Run the batched random-press brute force and return the histogram.

    ``B`` random ``side`` x ``side`` boards are simulated in parallel for
    ``steps`` scan iterations. Returns a NumPy array ``hist`` of length
    ``H``, where ``hist[k]`` counts solves that took ``k`` attempts. The last
    bin also collects every solve needing >= H-1 attempts. The measured
    throughput (board-updates per second, execution only) is printed.
    """
    eff = flip_kernel(side)
    key = random.PRNGKey(seed)
    key, sub = random.split(key)
    boards = random.bernoulli(sub, 0.5, (B, side, side)).astype(jnp.int8)
    attempts = jnp.zeros((B,), jnp.int32)
    hist = jnp.zeros((H,), jnp.int32)
    carry = (boards, attempts, hist, key)

    body = make_body(B, side, H, eff)

    def run(c):
        return lax.scan(body, c, xs=None, length=steps)[0]

    # Compile ahead of time so the timed region measures execution only.
    compiled = jax.jit(run).lower(carry).compile()

    t0 = time.perf_counter()
    # JAX dispatches asynchronously: without blocking here the timer would
    # stop once the work is enqueued, not once it has actually finished.
    carry = jax.block_until_ready(compiled(carry))
    elapsed = time.perf_counter() - t0

    _, _, hist, _ = carry
    n_updates = B * steps
    ups = n_updates / elapsed if elapsed > 0 else float("inf")
    print(f"{n_updates:,} board-updates in {elapsed:.2f}s → {ups:,.0f} updates/s")
    return np.asarray(hist)
# ───────────────────────── Demo run & plot ───────────────────────────────── | |
if __name__ == "__main__":
    side, B, steps, H = 3, 8192, 40000, 4000
    hist = brute_hist(side, B, steps, H)
    plt.figure()
    # Bin k counts boards solved on exactly their k-th attempt, except the
    # last bin (H-1), which also absorbs every solve needing >= H-1 attempts.
    # Bin 0 is never incremented (a solve always takes at least one attempt),
    # so it is skipped.
    xs = np.arange(1, H)
    plt.bar(xs, hist[1:], width=0.9)
    plt.xlabel("Attempts to solve")
    plt.ylabel("count")  # linear axis -- the old "(log)" label was wrong
    plt.title(f"{side}×{side} Lights-Out • {hist[1:].sum():,} solve events")
    plt.tight_layout()
    plt.show()
# %% | |
# =========================== VERIFICATION CODE ============================= | |
def verify_game_mechanics():
    """Check the flip kernel against hand-computed single-press patterns.

    For a centre, corner and edge press on a 3x3 board, the press is
    contracted with ``flip_kernel`` and the toggled cells are compared to the
    expected pattern. Each case is printed. Returns ``True`` iff every case
    matched, so callers can assert on the result instead of reading output.
    """
    side = 3
    eff = flip_kernel(side)
    print("=== Verification 1: Flip Kernel ===")
    cases = [
        # Centre: itself + 4 neighbours -> 5 lights.
        ("Center", (1, 1), [[0, 1, 0],
                            [1, 1, 1],
                            [0, 1, 0]]),
        # Corner: itself + 2 in-bounds neighbours -> 3 lights.
        ("Corner", (0, 0), [[1, 1, 0],
                            [1, 0, 0],
                            [0, 0, 0]]),
        # Edge: itself + 3 in-bounds neighbours -> 4 lights.
        ("Edge", (0, 1), [[1, 1, 1],
                          [0, 1, 0],
                          [0, 0, 0]]),
    ]
    all_ok = True
    for label, (row, col), expected_rows in cases:
        press = np.zeros((side, side), dtype=np.int8)
        press[row, col] = 1
        effect = jnp.einsum('ij,ijkl->kl', press, eff)
        expected = np.array(expected_rows, dtype=np.int8)
        ok = bool(np.array_equal(effect, expected))
        all_ok = all_ok and ok
        print(f"{label} press effect:\n{effect}")
        print(f"Expected:\n{expected}")
        print(f"Correct: {ok}\n")
    return all_ok
def _style_board_axis(ax, side):
    """Give ``ax`` integer ticks, a light grid and imshow-style board limits."""
    # imshow draws row 0 at the top (ylim = (side-0.5, -0.5)). Use the same
    # orientation everywhere so the action/result panels are not flipped
    # relative to the initial-board panel.
    ax.set_xlim(-0.5, side - 0.5)
    ax.set_ylim(side - 0.5, -0.5)
    ax.set_xticks(range(side))
    ax.set_yticks(range(side))
    ax.grid(True, alpha=0.3)


def _random_press_search(board, eff, side, seed, max_attempts):
    """Try up to ``max_attempts`` random press patterns on the initial ``board``.

    Every attempt is applied to the original board (not cumulatively).
    Returns ``(True, attempt, actions, result)`` for the first pattern that
    clears it, where ``attempt`` is 0-based, or ``(False, None, None, None)``.
    """
    test_board = jnp.asarray(board)[None, :, :]  # add batch dimension
    key = random.PRNGKey(seed)
    for attempt in range(max_attempts):
        key, subkey = random.split(key)
        actions = random.bernoulli(subkey, 0.5, (1, side, side)).astype(np.int8)
        result = step_batch(test_board, actions, eff)
        if bool(solved(result)[0]):
            return True, attempt, actions, result
    return False, None, None, None


def solve_specific_boards(max_attempts: int = 1000):
    """Visually check random-press solving on three hand-picked 3x3 boards.

    For each board, up to ``max_attempts`` random press patterns are tried.
    The figure row shows the initial lights, the first pattern that clears
    the board (✓ marks pressed cells), and the resulting board.
    """
    side = 3
    eff = flip_kernel(side)

    boards = [
        # Board 1: all lights on. Pressing the four corners and the centre
        # toggles every cell an odd number of times, so that clears it.
        # (Pressing all nine buttons does NOT: edge cells toggle 4 times.)
        np.ones((side, side), dtype=np.int8),
        # Board 2: the "X" pattern. This is exactly what pressing all nine
        # buttons produces from an all-off board, so pressing all nine
        # clears it.
        np.array([[1, 0, 1],
                  [0, 1, 0],
                  [1, 0, 1]], dtype=np.int8),
        # Board 3: a random board.
        random.bernoulli(random.PRNGKey(42), 0.5, (side, side)).astype(np.int8),
    ]

    fig, axes = plt.subplots(len(boards), 3, figsize=(10, 10))
    for i, board in enumerate(boards):
        ax_init, ax_act, ax_res = axes[i]

        # Initial board
        ax_init.imshow(board, cmap='gray_r', vmin=0, vmax=1)
        ax_init.set_title(f'Board {i+1}: Initial')
        _style_board_axis(ax_init, side)

        found, attempt, actions, result = _random_press_search(
            board, eff, side, seed=i, max_attempts=max_attempts)

        # Winning press pattern
        if found:
            pressed = np.asarray(actions[0])
            ax_act.imshow(pressed, cmap='RdBu_r', vmin=0, vmax=1)
            ax_act.set_title(f'Actions (attempt {attempt+1})')
            for y, x in zip(*np.nonzero(pressed)):
                ax_act.text(int(x), int(y), '✓', ha='center', va='center',
                            color='white', fontsize=16, weight='bold')
        else:
            ax_act.text(0.5, 0.5, 'No solution\nfound', ha='center', va='center',
                        transform=ax_act.transAxes, fontsize=12)
        _style_board_axis(ax_act, side)

        # Resulting board
        if found:
            ax_res.imshow(result[0], cmap='gray_r', vmin=0, vmax=1)
            ax_res.set_title(f'Result: {"SOLVED ✓" if solved(result)[0] else "Not solved ✗"}')
        else:
            ax_res.text(0.5, 0.5, 'N/A', ha='center', va='center',
                        transform=ax_res.transAxes, fontsize=12)
        _style_board_axis(ax_res, side)

    plt.suptitle('Lights-Out Solver Verification', fontsize=16)
    plt.tight_layout()
    plt.show()
# Run all verifications: kernel sanity checks first, then the visual solver check.
if __name__ == "__main__":
    for check in (verify_game_mechanics, solve_specific_boards):
        check()
# %% |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment