@etjones
Last active February 4, 2025 21:53
Minimal example of using jaxtyping and beartype to enforce torch.Tensor dimensions at runtime.
#!/usr/bin/env python3
# `uv` makes this easier. Run this with `uv run jaxtyping_torch_demo.py`.
# /// script
# dependencies = [
# "beartype>=0.19.0",
# "jaxtyping>=0.2.36",
# "numpy>=2.1.3",
# "torch>=2.5.1",
# ]
# ///
# Tell ruff not to worry about the F722 error
# (See: https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error)
# ruff: noqa: F722
import torch
from jaxtyping import Int, Float, jaxtyped
from beartype import beartype as typechecker
# Also possible:
# from typeguard import typechecked as typechecker
#
# (But beartype has been doing a lot of work to integrate with JaxTyping, so
# that worked well for me.)


@jaxtyped(typechecker=typechecker)
def accepts_torch(
    batch: int,
    # Accept a 4d float tensor, with the batch dimension interpolated from the `batch` argument
    x: Float[torch.Tensor, "{batch} channel height width"],
) -> None:
    print(f"accepts_torch() received Tensor of shape {x.shape}")


@jaxtyped(typechecker=typechecker)
def returns_torch(
    # Accept a 3d float tensor and return it unchanged; the return shape is checked too
    x: Float[torch.Tensor, "channel height width"],
) -> Float[torch.Tensor, "channel height width"]:
    return x


def main():
    print("Hello from jaxtyping-demo!")
    # Succeeds, and enforces the correct shape. Note that if you change
    # the value of `batch` or the first dimension of `x`, you'll get a
    # runtime error. (Yay!)
    batch = 2
    x = torch.randn(2, 3, 4, 5)
    accepts_torch(batch, x)

    y = returns_torch(torch.randn(3, 4, 5))
    print(f"returns_torch() returned a Tensor of dimension: {y.shape}")

    # Fails; accepts_torch() needs a 4d tensor, but we only supply a 3d one.
    # (A sketch of catching this error follows the script.)
    y = torch.randn(batch, 3, 4)
    print("The following call to accepts_torch() will fail:")
    accepts_torch(batch, y)


if __name__ == "__main__":
    main()
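
A hedged addendum, not part of the original gist: if you want the failing call at the end of main() to be caught instead of crashing the script, wrap it in a try/except. The exact exception class depends on the typechecker in use (beartype raises its own call-hint violation error, which jaxtyping may wrap), so this sketch just catches Exception and reports its type; it assumes the accepts_torch() definition above is in scope.

try:
    # 3d tensor, but accepts_torch() is annotated for a 4d "{batch} channel height width" tensor
    accepts_torch(2, torch.randn(2, 3, 4))
except Exception as err:
    print(f"Shape check failed as expected: {type(err).__name__}")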