Last active
February 4, 2025 21:53
-
-
Save etjones/ffd823401c6aaf6e1bd2496a6feb96b4 to your computer and use it in GitHub Desktop.
Minimal example of using jaxtyping and beartype to enforce `torch.Tensor` dimensions at runtime.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# `uv` makes this easier. Run this with `uv run jaxtyping_torch_demo.py`.
# /// script
# dependencies = [
# "beartype>=0.19.0",
# "jaxtyping>=0.2.36",
# "numpy>=2.1.3",
# "torch>=2.5.1",
# ]
# ///
# Tell ruff not to worry about the F722 error ("syntax error in forward
# annotation"), which it reports for jaxtyping shape strings such as
# "channel height width".
# (See: https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error)
# ruff: noqa: F722
import torch | |
from jaxtyping import Int, Float, jaxtyped | |
from beartype import beartype as typechecker | |
# Also possible:
# from typeguard import typechecked as typechecker
#
# (But beartype has been doing a lot of work to integrate with jaxtyping, so
# that worked well for me.)
@jaxtyped(typechecker=typechecker)
def accepts_torch(
    batch: int,
    # 4-D float tensor; its leading size must equal the runtime value of `batch`.
    x: Float[torch.Tensor, "{batch} channel height width"],
) -> None:
    """Print the shape of a 4-D float tensor whose leading dim equals ``batch``.

    The ``{batch}`` token in the shape string is filled in from the value of
    the ``batch`` argument. If ``batch`` and ``x.shape[0]`` disagree, the
    decorator's runtime type check rejects the call.
    """
    received_shape = x.shape
    print(f"accepts_torch() received Tensor of shape {received_shape}")
@jaxtyped(typechecker=typechecker)
def returns_torch(
    # Accept a 3d float tensor; there is no batch dimension here. The names
    # `channel`, `height`, `width` are bound from this argument's shape.
    x: Float[torch.Tensor, "channel height width"],
) -> Float[torch.Tensor, "channel height width"]:
    """Return ``x`` unchanged, demonstrating return-value shape checking.

    The argument and the return annotation use the same dimension names. The
    checker therefore requires the returned tensor to have exactly the sizes
    that were bound when ``x`` was checked on entry.
    """
    return x
def main() -> None:
    """Run the jaxtyping + beartype shape-checking demo.

    The first two calls pass correctly shaped tensors and succeed. The last
    call deliberately passes a 3-D tensor where a 4-D one is required. The
    runtime type checker rejects it, so the script ends with an uncaught
    error.
    """
    print("Hello from jaxtyping-demo!")
    # Succeeds, and enforces the correct shape. Note that if you change
    # the value of `batch` or the first dimension of `x`, you'll get a
    # runtime error. (Yay!)
    batch = 2
    x = torch.randn(2, 3, 4, 5)
    accepts_torch(batch, x)

    y = returns_torch(torch.randn(3, 4, 5))
    print(f"returns_torch() returned a Tensor of dimension: {y.shape}")

    # Fails: accepts_torch() needs a 4d tensor, but we only supply a 3d tensor.
    y = torch.randn(batch, 3, 4)
    # The closing parenthesis was missing here, which was a SyntaxError
    # that prevented the whole script from running.
    print("The following call to accepts_torch() will fail:")
    accepts_torch(batch, y)
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    # main() returns None, so SystemExit(None) exits with status 0, exactly
    # like falling off the end of the script.
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment