Yohaï-Eliel Berreby yberreby

# /// script
# requires-python = ">=3.11"
# dependencies = [
# "torchvision==0.24.0",
# ]
# ///
"""
Tiny script to export the ordered list of ImageNet-1K classes to a JSON file in the same directory as itself.
Thanks to the dependency specification above, you can run it with:

"""Clean, fast single-layer binary network for n-bit parity.
Fully vectorized and JIT-compiled global k=1 weight flipping.
"""
import jax
import jax.numpy as jnp
from functools import partial
from time import time
from tqdm import tqdm
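The docstring above mentions "global k=1 weight flipping". Here is a minimal JAX sketch of that idea. It rests on my own assumptions rather than the gist's code: a hypothetical layer of sign units with ±1 weights and a majority-vote readout, scored by error rate on all 2^n parity inputs. Every single-weight flip is evaluated at once with `jax.vmap`, and the best flip is kept only if it lowers the error. The names `parity_data`, `error_rate` and `best_single_flip` are illustrative. The sketch makes no claim that this particular architecture can fit parity.

```python
import itertools

import jax
import jax.numpy as jnp


def parity_data(n):
    # All 2**n bit strings as +/-1 inputs; label +1 for an even number of ones, else -1.
    bits = jnp.array(list(itertools.product([0, 1], repeat=n)), dtype=jnp.float32)
    return 2 * bits - 1, 1 - 2 * (bits.sum(axis=1) % 2)


def error_rate(w, x, y):
    # Hypothetical model: sign units with +/-1 weights, majority-vote readout
    # (+0.5 breaks ties toward +1). Returns the fraction of misclassified inputs.
    h = jnp.sign(x @ w)
    pred = jnp.sign(h.sum(axis=1) + 0.5)
    return jnp.mean(pred != y)


@jax.jit
def best_single_flip(w, x, y):
    flat = w.ravel()
    # Row i of `candidates` is `flat` with entry i negated: all k=1 flips at once.
    candidates = flat[None, :] * (1 - 2 * jnp.eye(flat.size, dtype=flat.dtype))
    losses = jax.vmap(lambda c: error_rate(c.reshape(w.shape), x, y))(candidates)
    i = jnp.argmin(losses)
    current = error_rate(w, x, y)
    # Accept the best flip only if it strictly lowers the error.
    improved = losses[i] < current
    new_w = jnp.where(improved, candidates[i].reshape(w.shape), w)
    return new_w, jnp.where(improved, losses[i], current)


x, y = parity_data(4)
key = jax.random.PRNGKey(0)
w = jnp.where(jax.random.bernoulli(key, 0.5, (4, 8)), 1.0, -1.0)
for _ in range(20):
    w, err = best_single_flip(w, x, y)
```

Scoring all flips in one batched call, rather than looping over weights in Python, is what makes a "global" search over k=1 moves cheap under JIT.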
yberreby / 0_mlx_mup_demo.py
Last active October 21, 2025 22:41
An MLX-based foray into muP/µP (maximal update parameterization; cf. Tensor Programs V)
#!/usr/bin/env -S uv run
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "matplotlib",
# "polars>=1.34.0",
# "tqdm>=4.67.1",
# ]
# ///
#
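As a reminder of what muP changes in practice, here is a minimal sketch of per-layer scaling rules for Adam. It is based on my reading of the muP summary table in Tensor Programs V, so treat the exact rules as an assumption and check them against the paper. Under these rules, hidden and output weights get learning rates that shrink as 1/fan_in relative to a base width, and output weights are initialized with variance 1/fan_in² instead of 1/fan_in. The helper `mup_adam_scales`, the MLP shape and the widths are hypothetical. The code is plain Python and independent of MLX.

```python
import math


def mup_adam_scales(kind, fan_in, base_fan_in):
    """Return (init_std, lr_multiplier) for one weight matrix under muP with Adam.

    Assumed rules, relative to a base model of width `base_fan_in`:
      input:  init var 1/fan_in,     Adam LR independent of width
      hidden: init var 1/fan_in,     Adam LR scaled by base_fan_in/fan_in
      output: init var 1/fan_in**2,  Adam LR scaled by base_fan_in/fan_in
    """
    width_ratio = fan_in / base_fan_in
    if kind == "input":
        return 1 / math.sqrt(fan_in), 1.0
    if kind == "hidden":
        return 1 / math.sqrt(fan_in), 1 / width_ratio
    if kind == "output":
        return 1 / fan_in, 1 / width_ratio
    raise ValueError(f"unknown layer kind: {kind!r}")


base_lr = 1e-3
for width in (256, 1024):
    # Hypothetical MLP: 784 inputs -> width -> width -> 10 outputs; base width is 256.
    for kind, fan_in in (("input", 784), ("hidden", width), ("output", width)):
        std, lr_mult = mup_adam_scales(kind, fan_in, base_fan_in=256)
        print(f"width={width:4d} {kind:6s} init_std={std:.4f} lr={base_lr * lr_mult:.1e}")
```

The base model recovers the unscaled learning rate. This is the property that lets hyperparameters tuned at the base width be reused at larger widths.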
yberreby / CLAUDE.md
Created June 25, 2025 16:31
A generic CLAUDE.md with coding methodology.

Methodology

Code like a hacker: concisely, with self-doubt, without fluff, without repeating yourself, keeping code as orthogonal as possible.

DRY

  • Repeating oneself is unacceptable.
  • If your LOCs look suspiciously similar, consider a loop or a lambda.
  • If you can refactor to follow a "data-driven" approach (e.g. list of dicts instead of ad-hoc code), consider doing so.
  • Don't be afraid of using tiny, local abstractions.
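To make the "data-driven" bullet concrete, here is a small hypothetical illustration (names and fields invented for the example). Three near-identical formatting lines collapse into a table plus one comprehension.

```python
stats = {"loss": 0.12345, "acc": 0.9, "lr": 3e-4}


# Before: three near-identical lines that differ only in key and format.
def summarize_repetitive(stats):
    return [
        f"loss: {stats['loss']:.3f}",
        f"acc: {stats['acc']:.1%}",
        f"lr: {stats['lr']:.1e}",
    ]


# After: the varying parts live in a small table; the logic is written once.
FIELDS = [
    {"key": "loss", "fmt": ".3f"},
    {"key": "acc", "fmt": ".1%"},
    {"key": "lr", "fmt": ".1e"},
]


def summarize(stats):
    return [f"{f['key']}: {stats[f['key']]:{f['fmt']}}" for f in FIELDS]


print(summarize(stats))  # ['loss: 0.123', 'acc: 90.0%', 'lr: 3.0e-04']
```

Adding a metric now means adding one row to `FIELDS`, not another copy of the formatting logic.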
yberreby / direct_velocity_field_optimization.py
Created June 23, 2025 22:49
Optimization of a stationary velocity field with Adam to morph one shape into another. JAX-based interactive notebook, shared in `jupytext` format.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
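Here is a minimal sketch of optimizing a stationary velocity field with Adam. It is written for illustration and not taken from the notebook. It makes four assumptions:

- the field is a sum of Gaussian RBFs on a fixed 8×8 grid of centers;
- points are advected with forward Euler over unit time;
- the loss uses known source/target correspondences (an OT loss would remove that requirement);
- Adam is written out by hand to avoid extra dependencies.

The circle and ellipse are placeholder shapes.

```python
import jax
import jax.numpy as jnp

# Hypothetical setup: 2-D points, velocity field = sum of Gaussian RBFs on a fixed grid.
centers = jnp.stack(jnp.meshgrid(jnp.linspace(-1, 1, 8), jnp.linspace(-1, 1, 8)), axis=-1).reshape(-1, 2)
SIGMA, N_STEPS = 0.3, 16


def velocity(coeffs, x):
    # v(x) = sum_j coeffs_j * exp(-|x - c_j|^2 / (2 sigma^2)); no time dependence (stationary).
    w = jnp.exp(-jnp.sum((x[:, None, :] - centers[None]) ** 2, -1) / (2 * SIGMA**2))
    return w @ coeffs


def flow(coeffs, x):
    # Integrate dx/dt = v(x) from t=0 to t=1 with forward Euler.
    dt = 1.0 / N_STEPS
    return jax.lax.fori_loop(0, N_STEPS, lambda _, x: x + dt * velocity(coeffs, x), x)


def loss(coeffs, src, tgt):
    # Assumes point i of `src` should land on point i of `tgt`.
    return jnp.mean(jnp.sum((flow(coeffs, src) - tgt) ** 2, -1))


@jax.jit
def adam_step(coeffs, m, v, t, src, tgt, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    # One bias-corrected Adam update on the field coefficients.
    g = jax.grad(loss)(coeffs, src, tgt)
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    return coeffs - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v


theta = jnp.linspace(0, 2 * jnp.pi, 64, endpoint=False)
src = 0.5 * jnp.stack([jnp.cos(theta), jnp.sin(theta)], -1)  # circle
tgt = src * jnp.array([1.5, 0.6])  # ellipse
coeffs = jnp.zeros((centers.shape[0], 2))
m, v = jnp.zeros_like(coeffs), jnp.zeros_like(coeffs)
initial = loss(coeffs, src, tgt)
for t in range(1, 301):
    coeffs, m, v = adam_step(coeffs, m, v, t, src, tgt)
```

Forward Euler keeps the sketch short. Scaling and squaring is a common alternative for exponentiating a stationary field.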
yberreby / jax_lessons_learned.md
Created June 23, 2025 22:06
JAX Lessons Learned
  • Use jax.tree_util.Partial to pass partially applied functions to JIT-compiled code
  • Use static_argnames, not static_argnums, whenever possible
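Both lessons fit into one small example (function names are hypothetical):

- **`Partial`:** a `jax.tree_util.Partial` is a pytree, so its bound arguments are traced like any other array while the wrapped function stays static. A plain `functools.partial` passed as a jitted argument is neither an array nor a registered pytree, so it is rejected.
- **`static_argnames`:** JAX maps each name to its position through the function signature. The argument is therefore treated as static whether it is passed by keyword or positionally, and the marking survives a reordering of parameters.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scaled_add(scale, x, y):
    return scale * x + y


@jax.jit
def apply_twice(fn, x):
    # `fn` is a pytree: the bound `scale` is a traced leaf, `scaled_add` itself is static.
    return fn(x, x)


doubled = apply_twice(Partial(scaled_add, jnp.float32(2.0)), jnp.arange(3.0))


@partial(jax.jit, static_argnames=("n_steps",))
def repeat_double(x, n_steps):
    # n_steps drives a Python loop at trace time, so it must be static.
    for _ in range(n_steps):
        x = x * 2
    return x


out = repeat_double(jnp.ones(2), n_steps=3)
```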
# %% [markdown]
# ## Diffeomorphic T → Y morph with OT loss (ott-jax)
# %%
# %matplotlib widget
# %% 0 · Imports ---------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
from skimage.draw import line
from skimage.measure import find_contours
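The notebook computes its OT loss with ott-jax. Rather than guess at that library's API, here is a self-contained sketch of the same idea in plain JAX: an entropic OT (Sinkhorn) cost between two uniformly weighted point clouds, iterated in the log domain for numerical stability. `sinkhorn_cost`, `eps` and `n_iters` are illustrative choices, not values from the notebook. Because the loop has a fixed trip count, the cost can be differentiated with `jax.grad` with respect to the point positions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_cost(x, y, eps=0.05, n_iters=200):
    """Entropic OT transport cost between uniform point clouds x (n, d) and y (m, d)."""
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, -1)  # squared Euclidean, (n, m)
    log_a = jnp.full(x.shape[0], -jnp.log(x.shape[0]))  # uniform source weights
    log_b = jnp.full(y.shape[0], -jnp.log(y.shape[0]))  # uniform target weights

    def body(_, fg):
        f, g = fg
        # Alternate dual-potential updates so row sums match a and column sums match b.
        f = -eps * logsumexp((g[None, :] - cost) / eps + log_b[None, :], axis=1)
        g = -eps * logsumexp((f[:, None] - cost) / eps + log_a[:, None], axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, n_iters, body, (jnp.zeros(x.shape[0]), jnp.zeros(y.shape[0])))
    # Plan P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps); return <P, C>.
    log_plan = (f[:, None] + g[None, :] - cost) / eps + log_a[:, None] + log_b[None, :]
    return jnp.sum(jnp.exp(log_plan) * cost)


pts = jax.random.uniform(jax.random.PRNGKey(0), (32, 2))
grad_wrt_pts = jax.grad(sinkhorn_cost)(pts, pts + 1.0)
```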
# %%
# %matplotlib widget
# %%
# --- Imports ---------------------------------------------------------------
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import diags
yberreby / jax_masked_maze_backprop.py
Last active June 21, 2025 15:39
End-to-end differentiable maze example in JAX.
#!/usr/bin/env -S uv run --script
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyqt6", # For matplotlib backend
# "numpy",
# "optax>=0.2.5",
# "matplotlib",
# "jax[cuda12]==0.5.2", # Use "jax==0.5.2" instead for CPU-only.
# "jaxtyping>=0.3.2",
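One way to make a maze end-to-end differentiable is sketched below. It is an assumption of mine, not the gist's method:

1. Treat walls as a mask that assigns a huge cost.
2. Run soft value iteration toward the goal, with min replaced by a temperature-smoothed softmin via `logsumexp`.
3. Differentiate the start cell's cost-to-go with respect to the per-cell costs.

The gradient then acts as a soft indicator of which cells the path uses. `soft_path_cost`, `WALL`, `temp` and the 5×5 maze are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

WALL = 1e3  # large cost standing in for an impassable cell


def soft_path_cost(cell_cost, walls, goal, temp=0.1, n_iters=64):
    """Soft cost-to-go from every cell to `goal` on a 4-connected grid.

    Bellman update: V(s) = c(s) + softmin over neighbours V(s'),
    with softmin(z) = -temp * logsumexp(-z / temp).
    """
    cost = jnp.where(walls, WALL, cell_cost)  # mask: walls get a huge cost
    H, W = cost.shape
    far = WALL * H * W  # value used for "not reachable yet" and off-grid cells

    def step(_, V):
        P = jnp.pad(V, 1, constant_values=far)
        neigh = jnp.stack([P[:-2, 1:-1], P[2:, 1:-1], P[1:-1, :-2], P[1:-1, 2:]])  # up, down, left, right
        V_new = cost - temp * logsumexp(-neigh / temp, axis=0)
        return V_new.at[goal].set(0.0)

    V0 = jnp.full((H, W), far).at[goal].set(0.0)
    return jax.lax.fori_loop(0, n_iters, step, V0)


walls = jnp.zeros((5, 5), dtype=bool).at[1:4, 2].set(True)  # a wall in column 2
cell_cost = jnp.ones((5, 5))
start, goal = (2, 0), (2, 4)
V = soft_path_cost(cell_cost, walls, goal)
# dV[start]/dcell_cost: soft count of how often the path leaves each cell.
occupancy = jax.grad(lambda c: soft_path_cost(c, walls, goal)[start])(cell_cost)
```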
yberreby / bruteforce_3x3_lightsout_jax.py
Created June 17, 2025 07:34
Reasonably fast (several hundred million updates per second) vectorized JAX brute force for 3x3 'Lights Out!' boards. DISCLAIMER: loosely tested and reviewed.
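A 3×3 board has only 2^9 = 512 press sets, and pressing a cell twice cancels out. A brute force can therefore enumerate every set at once. Here is a minimal vectorized sketch with `jax.numpy`, using an encoding of my own that may differ from the gist's. Boards and press sets are 9-bit integers, and a press set's effect is the XOR of its cells' toggle masks. `press_masks`, `all_effects` and `solve` are hypothetical names.

```python
import jax.numpy as jnp

N = 3


def press_masks(n=N):
    # Bit r*n + c is set for the pressed cell and its orthogonal neighbours.
    masks = []
    for r in range(n):
        for c in range(n):
            m = 0
            for dr, dc in ((0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)):
                rr, cc = r + dr, c + dc
                if 0 <= rr < n and 0 <= cc < n:
                    m |= 1 << (rr * n + cc)
            masks.append(m)
    return jnp.array(masks, dtype=jnp.int32)


MASKS = press_masks()
SUBSETS = jnp.arange(2 ** (N * N), dtype=jnp.int32)  # every press set as a 9-bit int


def all_effects():
    # effect[s] = XOR of the toggle masks of the cells whose bit is set in s.
    effects = jnp.zeros_like(SUBSETS)
    for i in range(N * N):
        pressed = ((SUBSETS >> i) & 1) == 1
        effects = effects ^ jnp.where(pressed, MASKS[i], 0)
    return effects


def solve(board):
    # Press set (9-bit int) with the fewest presses that clears `board`, or -1 if none.
    n_presses = ((SUBSETS[:, None] >> jnp.arange(N * N)) & 1).sum(axis=1)
    cost = jnp.where(all_effects() == board, n_presses, 10**6)
    best = jnp.argmin(cost)
    return jnp.where(cost[best] < 10**6, SUBSETS[best], -1)
```

For 3×3, every board has exactly one solution. Larger sizes such as 4×4 and 5×5 do not share that property.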
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.17.2
# kernelspec: