
Yohaï-Eliel Berreby yberreby

yberreby / CLAUDE.md
Created June 25, 2025 16:31
A generic CLAUDE.md with coding methodology.

Methodology

Code like a hacker: concisely, with self-doubt, without fluff, without repeating yourself, keeping code as orthogonal as possible.

DRY

  • Repeating oneself is unacceptable.
  • If your LOCs look suspiciously similar, consider a loop or a lambda.
  • If you can refactor to follow a "data-driven" approach (e.g. list of dicts instead of ad-hoc code), consider doing so.
  • Don't be afraid of using tiny, local abstractions.
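The data-driven and tiny-abstraction bullets can be illustrated with a small sketch. The `train` function and the run table are hypothetical stand-ins. The variation between runs lives in a list of dicts, the logic is written once, and a local lambda handles the one formatting rule every line shares.

```python
# Hypothetical stand-in for a real routine; only the call pattern matters here.
def train(lr, epochs):
    return round(lr * epochs, 6)

# Data-driven: each run is a row of data instead of a copy-pasted block of code.
RUNS = [
    {"name": "baseline", "lr": 0.1, "epochs": 10},
    {"name": "fast", "lr": 0.5, "epochs": 4},
    {"name": "long", "lr": 0.01, "epochs": 300},
]

results = {run["name"]: train(run["lr"], run["epochs"]) for run in RUNS}

# A tiny, local abstraction for the formatting every line has in common.
fmt = lambda name, value: f"{name}: {value:g}"
lines = [fmt(name, value) for name, value in results.items()]
print("\n".join(lines))
```

Adding a fourth run is then a one-line change to `RUNS`, with no new code path to keep in sync.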
yberreby / direct_velocity_field_optimization.py
Created June 23, 2025 22:49
Optimization of a stationary velocity field with Adam to morph one shape into another. JAX-based interactive notebook, shared in `jupytext` format.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
yberreby / jax_lessons_learned.md
Created June 23, 2025 22:06
JAX Lessons Learned
  • Use jax.tree_util.Partial to pass partially applied functions to JIT-compiled code
  • Use static_argnames, not static_argnums, whenever possible
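Both lessons fit in a few lines. `affine`, `apply_twice` and `repeat_double` are hypothetical examples. A `jax.tree_util.Partial` is a pytree whose bound arguments are leaves, so it can cross a `jax.jit` boundary; `static_argnames` marks an argument as a compile-time constant by name.

```python
import functools

import jax
from jax.tree_util import Partial

def affine(a, b, x):
    return a * x + b

@jax.jit
def apply_twice(f, x):
    # `f` arrives as a pytree: the wrapped function is static, `a` and `b` are traced leaves.
    # A plain functools.partial or lambda here is rejected, as it is not a valid JAX type.
    return f(f(x))

f = Partial(affine, 2.0, 1.0)
g = Partial(affine, 3.0, 0.0)  # same function, new leaf values: no recompilation expected

@functools.partial(jax.jit, static_argnames=("n_steps",))
def repeat_double(x, n_steps):
    # n_steps drives a Python loop, so it must be concrete (static) at trace time
    for _ in range(n_steps):
        x = x * 2
    return x

print(apply_twice(f, 3.0), apply_twice(g, 3.0), repeat_double(1.0, n_steps=3))
```

One plausible reason to prefer `static_argnames`: a name stays correct if parameters are reordered, and the call site (`n_steps=3`) shows which argument is static, whereas `static_argnums` tracks positions.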
# %% [markdown]
# ## Diffeomorphic T → Y morph with OT loss (ott-jax)
# %%
# %matplotlib widget
# %% 0 · Imports ---------------------------------------------------------------
import numpy as np, matplotlib.pyplot as plt
from skimage.draw import line
from skimage.measure import find_contours
# %%
# %matplotlib widget
# %%
# --- Imports ---------------------------------------------------------------
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import diags
yberreby / jax_masked_maze_backprop.py
Last active June 21, 2025 15:39
End-to-end differentiable maze example in JAX.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyqt6", # For matplotlib backend
# "numpy",
# "optax>=0.2.5",
# "matplotlib",
# "jax[cuda12]==0.5.2", # Change for CPU.
# "jaxtyping>=0.3.2",
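The gist's code is not shown here, so the following is only a hedged illustration of what "end-to-end differentiable" can mean for a maze, not the gist's method: a soft (log-sum-exp) shortest-path cost on a 4-connected grid, with walls masked to a large hypothetical cost `BIG`. Because `jnp.where` routes no gradient to masked cells, `jax.grad` returns exactly zero for walls and, for small `tau`, roughly the expected number of visits for open cells.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

BIG = 1e3  # hypothetical prohibitive cost for walls and off-grid moves

def soft_path_cost(cost, wall, start, goal, tau=0.05):
    """Soft shortest-path cost from start to goal; entering cell y costs cost[y]."""
    H, W = cost.shape
    c = jnp.where(wall, BIG, cost)  # masked cells: zero gradient w.r.t. their cost
    Pc = jnp.pad(c, 1, constant_values=BIG)
    # Cost of entering each neighbor (up, down, left, right) of every cell
    nbr_c = jnp.stack([Pc[:-2, 1:-1], Pc[2:, 1:-1], Pc[1:-1, :-2], Pc[1:-1, 2:]])

    def sweep(_, D):
        P = jnp.pad(D, 1, constant_values=BIG)
        nbr_D = jnp.stack([P[:-2, 1:-1], P[2:, 1:-1], P[1:-1, :-2], P[1:-1, 2:]])
        # Soft-min Bellman update: D[x] = -tau * log sum_y exp(-(c[y] + D[y]) / tau)
        D = -tau * logsumexp(-(nbr_c + nbr_D) / tau, axis=0)
        return D.at[goal].set(0.0)

    D0 = jnp.full((H, W), BIG).at[goal].set(0.0)
    D = jax.lax.fori_loop(0, H * W, sweep, D0)  # enough sweeps for any simple path
    return D[start]

cost = jnp.ones((3, 3))
wall = jnp.zeros((3, 3), dtype=bool).at[1, 1].set(True)  # block the center
value = soft_path_cost(cost, wall, (0, 0), (2, 2))
dcost = jax.grad(soft_path_cost)(cost, wall, (0, 0), (2, 2))
print(value)
print(dcost)
```

With the center blocked, the two detours are equally cheap, so each detour cell gets a gradient near 0.5, the goal one near 1, and the wall exactly 0.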
yberreby / bruteforce_3x3_lightsout_jax.py
Created June 17, 2025 07:34
Reasonably fast (several hundred million updates per second) vectorized JAX brute force for 3x3 'Lights Out!' boards. DISCLAIMER: only loosely tested and reviewed.
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.2
#   kernelspec:
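A compact sketch of vectorized brute force for 3x3 boards under the usual rules (a press flips the cell and its orthogonal neighbors). It is an illustration, not the gist's implementation. Boards and press patterns are 0/1 vectors of length 9, so the effect of all 2^9 = 512 patterns is one integer matmul mod 2, and `jax.vmap` solves many boards at once. For 3x3 the 512 effects are all distinct, so every board has exactly one solution.

```python
import jax
import jax.numpy as jnp
import numpy as np

N = 3  # board side; 2**(N*N) press patterns are enumerated, so keep N small

def toggle_matrix(n=N):
    """M[i, j] = 1 iff pressing cell i flips cell j (the cell and its orthogonal neighbors)."""
    M = np.zeros((n * n, n * n), dtype=np.int32)
    for r in range(n):
        for c in range(n):
            for dr, dc in [(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]:
                rr, cc = r + dr, c + dc
                if 0 <= rr < n and 0 <= cc < n:
                    M[r * n + c, rr * n + cc] = 1
    return jnp.asarray(M)

M = toggle_matrix()
# Row p holds the bits of integer p: one row per press pattern (512 x 9)
PRESSES = jnp.asarray((np.arange(2 ** (N * N))[:, None] >> np.arange(N * N)) & 1, dtype=jnp.int32)
# Board flipped by each press pattern; XOR of toggles is addition mod 2
EFFECTS = (PRESSES @ M) % 2

@jax.jit
def solve(board):
    """Return the press pattern whose effect equals `board` (a 0/1 vector of length 9)."""
    hit = jnp.all(EFFECTS == board[None, :], axis=1)
    return PRESSES[jnp.argmax(hit)]

solve_batch = jax.jit(jax.vmap(solve))
sols = solve_batch(PRESSES)  # treat every press pattern's bits as a board to solve
```

Pressing the returned pattern on its board flips every lit cell an odd number of times and every dark cell an even number, which clears the board.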
yberreby / bench_jax.py
Last active April 20, 2025 01:16
Quick JAX vs Triton comparison on a toy kernel. Outputs are from runs on an RTX 4060 Mobile.
import functools, time, jax, jax.numpy as jnp
jax.config.update("jax_default_matmul_precision", "tensorfloat32")
SQRT2_OVER_PI = 0.7978845608028654
# ----------------------------------------------------------------------
def gelu_fast(x):
    u = SQRT2_OVER_PI * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1. + jnp.tanh(u))
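A hedged sketch of how the toy kernel might be checked and timed in JAX alone (the Triton side is out of scope). The tanh formula in `gelu_fast` is the same approximation `jax.nn.gelu(x, approximate=True)` computes, which gives a cheap correctness check. `bench` is a hypothetical helper: it compiles once outside the timed loop and calls `block_until_ready()`, since JAX dispatches asynchronously and timing without it mostly measures dispatch.

```python
import time

import jax
import jax.numpy as jnp

SQRT2_OVER_PI = 0.7978845608028654  # sqrt(2 / pi)

def gelu_fast(x):
    # tanh approximation of GELU, as in the gist
    u = SQRT2_OVER_PI * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1.0 + jnp.tanh(u))

def bench(fn, x, iters=50):
    """Hypothetical helper: mean seconds per call of jit(fn) on x."""
    fn = jax.jit(fn)
    fn(x).block_until_ready()  # compile and warm up outside the timed region
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(x)
    out.block_until_ready()  # wait for queued async work before stopping the clock
    return (time.perf_counter() - t0) / iters

x = jax.random.normal(jax.random.PRNGKey(0), (1 << 20,), dtype=jnp.float32)
ref = lambda v: jax.nn.gelu(v, approximate=True)
max_err = float(jnp.max(jnp.abs(gelu_fast(x) - ref(x))))
print(f"max |err| vs jax.nn.gelu: {max_err:.2e}")
print(f"gelu_fast:   {bench(gelu_fast, x) * 1e6:.1f} us/call")
print(f"jax.nn.gelu: {bench(ref, x) * 1e6:.1f} us/call")
```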
yberreby / 0_geom_opt_cmp.py
Last active April 20, 2025 01:16
Quick comparison of a few optimizers on the 2-simplex: GD, Adam, Mirror Descent, Adam in mirror space, and L-BFGS in mirror space.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "jaxopt>=0.8.5",
# "optax>=0.2.4",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
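Of the listed methods, mirror descent is the one most specific to the simplex. A minimal JAX sketch with the entropic mirror map (exponentiated gradient), on a hypothetical quadratic objective: each step multiplies by `exp(-lr * grad)` and renormalizes, so iterates stay strictly positive and sum to one without an explicit Euclidean projection.

```python
import jax
import jax.numpy as jnp

def mirror_descent(f, x0, lr=0.5, steps=200):
    """Entropic mirror descent (exponentiated gradient) on the probability simplex."""
    grad_f = jax.grad(f)

    def step(x, _):
        x = x * jnp.exp(-lr * grad_f(x))  # gradient step on log(x), the mirror (dual) coordinates
        return x / jnp.sum(x), None       # normalizing is the KL (Bregman) projection onto the simplex

    x, _ = jax.lax.scan(step, x0, None, length=steps)
    return x

target = jnp.array([0.6, 0.3, 0.1])  # hypothetical optimum inside the simplex
f = lambda x: jnp.sum((x - target) ** 2)
x = mirror_descent(f, jnp.full(3, 1.0 / 3.0))
print(x)
```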
yberreby / cahn_hilliard_literate.py
Last active April 19, 2025 00:06
Cahn-Hilliard: run this for visually interesting non-linear PDE behavior in an animated plot.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "jax==0.5.0",
# "matplotlib>=3.10.1",
# "pyqt6>=6.9.0", # for matplotlib gui
# ]
# ///
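A minimal sketch of one common way to integrate the Cahn-Hilliard equation, ∂c/∂t = ∇²(c³ - c - γ∇²c), on a periodic grid: a semi-implicit Fourier spectral scheme that treats the nonlinear term explicitly and the stiff fourth-order term implicitly. The parameters (`dt`, `gamma`, unit grid spacing) are assumptions, not the gist's values. The update leaves the k = 0 Fourier mode untouched, so the mean concentration is conserved, as the equation requires.

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=("steps",))
def cahn_hilliard(c0, steps=500, dt=0.1, gamma=1.0):
    """Integrate dc/dt = lap(c**3 - c - gamma * lap(c)) on a periodic grid with dx = 1."""
    n = c0.shape[0]
    k = 2 * jnp.pi * jnp.fft.fftfreq(n)     # angular wavenumbers for unit spacing
    k2 = k[:, None] ** 2 + k[None, :] ** 2  # |k|^2, so lap -> -k2 in Fourier space

    def step(c, _):
        nonlinear_hat = jnp.fft.fft2(c**3 - c)
        # Semi-implicit: explicit nonlinear term, implicit stiff 4th-order term
        c_hat = (jnp.fft.fft2(c) - dt * k2 * nonlinear_hat) / (1 + dt * gamma * k2**2)
        return jnp.fft.ifft2(c_hat).real, None

    c, _ = jax.lax.scan(step, c0, None, length=steps)
    return c

# Small random perturbation around c = 0, inside the spinodal region
c0 = 0.1 * jax.random.uniform(jax.random.PRNGKey(0), (64, 64), minval=-1.0, maxval=1.0)
c = cahn_hilliard(c0)
print(float(c0.mean()), float(c.mean()), float(c.std()))
```

Displaying `c` with `plt.imshow` after successive calls shows noise separating into domains near ±1 that coarsen over time.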