etjones / jaxtyping_torch_demo.py
Last active February 4, 2025 21:53
Minimal example of using jaxtyping and beartype to enforce torch.Tensor dimensions at runtime.
#!/usr/bin/env python3
# Easiest with `uv`, which reads the inline script metadata below and installs the dependencies: `uv run jaxtyping_torch_demo.py`.
# /// script
# dependencies = [
# "beartype>=0.19.0",
# "jaxtyping>=0.2.36",
# "numpy>=2.1.3",
# "torch>=2.5.1",