@ricardoV94
Last active September 28, 2025 20:41
from functools import wraps
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor import tensor
from pytensor.compile.function.types import Function


def pytensor_jit(_func=None, **compile_kwargs):
    """Compile a pytensor function on demand.

    This decorator wraps a python function that takes pytensor inputs and returns pytensor expressions,
    and compiles it on demand when inputs are provided, by introspecting their types (rank, broadcastable pattern, and dtype).
    Whenever new input types are provided, a new function is compiled and cached, and reused subsequently for compatible types.

    Limitations:
    - Inputs are assumed to be numpy-like and convertible with `np.asarray` without changing the semantics of the function.
    - Wrapped functions cannot be composed with other functions, unlike jitted functions in numba/jax.
    - Arguments passed by name (kwargs) are not supported.

    Examples
    --------
    .. code-block:: python

        @pytensor_jit
        def pt_cdf(x, mu, sigma):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5, 1.0, [1, 2, 3])  # Compiles new function
        pt_cdf(0.5, 1.0, [1, 2, 3])  # Reuses compiled function
        pt_cdf([0.5, 1, 1.5], 1.0, 1.0)  # Compiles new function

    Inputs with default values will be constant-folded when the function is called without that argument.

    .. code-block:: python

        @pytensor_jit
        def pt_cdf(x, mu, sigma=1):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5, 1.0)  # Compiles function with constant-folded sigma
        pt_cdf(0.5, 1.0)  # Reuses compiled function
        pt_cdf(0.5, 1.0, 1.0)  # Compiles new function with variable sigma

    Compilation kwargs can be passed to the decorator.

    .. code-block:: python

        @pytensor_jit(mode="NUMBA")
        def pt_cdf(x, mu=0, sigma=1):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5)
    """

    def decorator(func):
        # trust_input can be a problem if the user passes aliased inputs (including the same value twice),
        # but it really reduces the overhead! Removing inplace rewrites would (I think) always make this safe.
        compile_kwargs.setdefault("trust_input", True)
        # compile_kwargs.setdefault("mode", "NUMBA")  # Or anything else you may want as a default

        signature_to_function: dict[tuple, Function] = {}

        @wraps(func)
        def inner_func(*args, signature_to_function=signature_to_function):
            args = [np.asarray(a) for a in args]
            # The signature is specified by the number of inputs, their broadcastable patterns, and dtypes
            signature = tuple((tuple(s == 1 for s in a.shape), a.dtype) for a in args)
            try:
                return signature_to_function[signature](*args)
            except KeyError:
                pass

            # Need to compile a new function
            print(f"Compiling new function for signature: {signature}")
            symbolic_args = [
                tensor(broadcastable=bcast_pattern, dtype=dtype)
                for (bcast_pattern, dtype) in signature
            ]
            symbolic_out = func(*symbolic_args)
            signature_to_function[signature] = compiled_func = function(
                symbolic_args, symbolic_out, **compile_kwargs
            )
            return compiled_func(*args)

        return inner_func

    if _func is None:
        # Case: @pytensor_jit(key=value)
        return decorator
    else:
        # Case: @pytensor_jit
        return decorator(_func)
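
As a quick end-to-end check, the docstring's pt_cdf example can be run directly; this is a minimal sketch (pt_cdf is the docstring's own example, and the "Compiling new function" messages come from the print in inner_func):

@pytensor_jit
def pt_cdf(x, mu, sigma):
    return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))


pt_cdf(0.5, 1.0, [1, 2, 3])    # Compiles: first time this signature is seen
pt_cdf(0.5, 1.0, [4, 5, 6])    # Same ranks and dtypes: reuses the cached function
pt_cdf([0.5, 1.0], 1.0, 1.0)   # x is now a float vector: compiles a new function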
from functools import wraps
import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor import tensor
from pytensor.compile.function.types import Function, In
from pytensor.tensor.random.type import random_generator_type

def pytensor_rng_jit(_func=None, **compile_kwargs):
    """Compile a pytensor function with an RNG on demand.

    The wrapped function must accept keyword-only `size` and `rng` arguments,
    which are forwarded to it as a symbolic size vector (or None) and a symbolic
    random generator.
    """

    def decorator(func):
        # trust_input can be a problem if the user passes aliased inputs (including the same value twice),
        # but it really reduces the overhead! Removing inplace rewrites would (I think) always make this safe.
        compile_kwargs.setdefault("trust_input", True)
        # compile_kwargs.setdefault("mode", "NUMBA")  # Or anything else you may want as a default

        signature_to_function: dict[tuple, Function] = {}

        @wraps(func)
        def inner_func(*args, size, rng, signature_to_function=signature_to_function):
            args = tuple(np.asarray(a) for a in args)
            # The signature is specified by the number of inputs, their broadcastable patterns, and dtypes,
            # as well as the static shape of the size vector (or None), so that sizes with a different
            # number of entries compile separate functions
            try:
                if size is None:
                    signature = (
                        *((tuple(s == 1 for s in a.shape), a.dtype) for a in args),
                        None,
                    )
                    return signature_to_function[signature](*args, rng)
                else:
                    size = np.asarray(size, dtype="int64")
                    signature = (
                        *((tuple(s == 1 for s in a.shape), a.dtype) for a in args),
                        size.shape,
                    )
                    return signature_to_function[signature](*args, size, rng)
            except KeyError:
                pass

            # Need to compile a new function
            print(f"Compiling new function for signature: {signature}")
            symbolic_args = [
                tensor(broadcastable=bcast_pattern, dtype=dtype)
                for (bcast_pattern, dtype) in signature[:-1]
            ]
            symbolic_size = None if size is None else tensor(shape=size.shape, dtype="int64", name="size")
            symbolic_rng = random_generator_type("rng")
            symbolic_out = func(*symbolic_args, size=symbolic_size, rng=symbolic_rng)
            # We allow PyTensor to modify the RNG in place
            mutable_rng = In(symbolic_rng, mutable=True)
            if size is None:
                symbolic_inputs = [*symbolic_args, mutable_rng]
            else:
                symbolic_inputs = [*symbolic_args, symbolic_size, mutable_rng]
            signature_to_function[signature] = compiled_func = function(
                symbolic_inputs, symbolic_out, **compile_kwargs,
            )
            return compiled_func(*args, rng) if size is None else compiled_func(*args, size, rng)

        return inner_func

    if _func is None:
        # Case: @pytensor_rng_jit(key=value)
        return decorator
    else:
        # Case: @pytensor_rng_jit
        return decorator(_func)

@pytensor_rng_jit
def normal_rvs(loc, scale, *, size, rng):
    return pt.random.normal(loc, scale, size=size, rng=rng)


for i in range(2):
    rng = np.random.default_rng(1)
    print(
        normal_rvs(0, 1, size=(2,), rng=rng),
        normal_rvs(0, 1, size=(2,), rng=rng),
        normal_rvs(0, 1, size=None, rng=rng),
        normal_rvs([0, 1], 1, size=None, rng=rng),
    )
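
Because the rng input is wrapped in In(..., mutable=True), each call advances the generator in place: the two size=(2,) draws within one print differ, while the two loop iterations print identical lines since each starts from a fresh default_rng(1). A minimal sketch, assuming copy.deepcopy is an acceptable way to snapshot a NumPy Generator, for getting the same draw twice without re-seeding:

import copy

rng = np.random.default_rng(1)
# Each call consumes its own copy of the generator state, so the draws match
# and the caller's rng is left untouched.
draw_a = normal_rvs(0, 1, size=(2,), rng=copy.deepcopy(rng))
draw_b = normal_rvs(0, 1, size=(2,), rng=copy.deepcopy(rng))
assert np.array_equal(draw_a, draw_b)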