Last active
September 28, 2025 20:41
-
-
Save ricardoV94/a51785d4199b8c14ce0d337f2b437772 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import wraps | |
| import numpy as np | |
| import pytensor.tensor as pt | |
| from pytensor import function | |
| from pytensor.tensor import tensor | |
| from pytensor.compile.function.types import Function | |
def pytensor_jit(_func=None, **compile_kwargs):
    """Compile a pytensor function on demand.

    This decorator wraps a python function that takes pytensor inputs and returns
    pytensor expressions, and compiles it on demand when inputs are provided, by
    introspecting their types (rank, broadcastable pattern, and dtype).
    Whenever new input types are provided, a new function is compiled and cached,
    and used subsequently for compatible types. A message is printed every time a
    new function is compiled.

    Limitations:

    - Inputs are assumed to be numpy-like objects that can be converted with
      `np.asarray` without changing the semantics of the function.
    - Wrapped functions cannot be composed with other functions, unlike jitted
      functions in numba/jax.
    - Arguments by name (kwargs) are not supported.

    Parameters
    ----------
    _func : callable, optional
        The function being decorated, when used as ``@pytensor_jit`` without
        parentheses.
    **compile_kwargs
        Forwarded to ``pytensor.function`` for every compiled specialization.
        ``trust_input`` defaults to True.

    Returns
    -------
    callable
        The wrapped function, or a decorator when called with keyword arguments only.

    Examples
    --------

    .. code-block:: python

        @pytensor_jit
        def pt_cdf(x, mu, sigma):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5, 1.0, [1, 2, 3])  # Compiles new function
        pt_cdf(0.5, 1.0, [1, 2, 3])  # Reuses compiled function
        pt_cdf([0.5, 1, 1.5], 1.0, 1.0)  # Compiles new function

    Inputs with default values will be constant folded when calling the function
    without specifying this argument.

    .. code-block:: python

        @pytensor_jit
        def pt_cdf(x, mu, sigma=1):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5, 1.0)  # Compiles function with constant-folded sigma
        pt_cdf(0.5, 1.0)  # Reuses compiled function
        pt_cdf(0.5, 1.0, 1.0)  # Compiles new function with variable sigma

    Compilation kwargs can be passed to the decorator.

    .. code-block:: python

        @pytensor_jit(mode="NUMBA")
        def pt_cdf(x, mu=0, sigma=1):
            return 0.5 * (1 + pt.erf((x - mu) / (sigma * 2 ** 0.5)))

        pt_cdf(0.5)
    """

    def decorator(func):
        # trust_input can be a problem if the user passes aliased inputs (including
        # the same values twice). But it really reduces the overhead! Removing
        # inplace rewrites would make this safe always (I think).
        compile_kwargs.setdefault("trust_input", True)
        # A default mode could also be set here, e.g.
        # compile_kwargs.setdefault("mode", "NUMBA")
        signature_to_function: dict[tuple, Function] = {}

        @wraps(func)
        def inner_func(*args, signature_to_function=signature_to_function):
            # `signature_to_function` is bound as a default argument (presumably so
            # the cache is a fast local lookup); callers should not pass it.
            args = [np.asarray(a) for a in args]
            # Signature is specified by number of inputs, their broadcastable
            # pattern, and dtypes.
            signature = tuple((tuple(s == 1 for s in a.shape), a.dtype) for a in args)

            # Plain lookup instead of try/except around the call: a KeyError raised
            # while *running* the compiled function must not be mistaken for a
            # cache miss (which would silently recompile and run it again).
            compiled_func = signature_to_function.get(signature)
            if compiled_func is None:
                print(f"Compiling new function for signature: {signature}")
                # Static shape 1 marks a broadcastable dimension, None an unknown one.
                symbolic_args = [
                    tensor(
                        shape=tuple(1 if bcast else None for bcast in bcast_pattern),
                        dtype=dtype,
                    )
                    for (bcast_pattern, dtype) in signature
                ]
                symbolic_out = func(*symbolic_args)
                signature_to_function[signature] = compiled_func = function(
                    symbolic_args, symbolic_out, **compile_kwargs
                )
            return compiled_func(*args)

        return inner_func

    if _func is None:
        # Case: @pytensor_jit(key=value)
        return decorator
    # Case: @pytensor_jit
    return decorator(_func)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from functools import wraps | |
| import numpy as np | |
| import pytensor.tensor as pt | |
| from pytensor import function | |
| from pytensor.tensor import tensor | |
| from pytensor.compile.function.types import Function, In | |
| from pytensor.tensor.random.type import random_generator_type | |
def pytensor_rng_jit(_func=None, **compile_kwargs):
    """Compile a pytensor function with RNG on demand.

    The wrapped function is called with symbolic positional inputs plus the
    keyword-only arguments ``size`` and ``rng``, and must return pytensor
    expression(s). A new function is compiled and cached for every distinct
    combination of:

    - the broadcastable pattern and dtype of each positional input (after
      ``np.asarray``), and
    - ``size``: ``None``, or the shape of ``np.asarray(size, dtype="int64")``.
      The length of ``size`` fixes the rank of the draws, so ``size=(2,)`` and
      ``size=(2, 3)`` need different compiled functions.

    ``rng`` is declared as a mutable input (``In(..., mutable=True)``), so PyTensor
    is allowed to modify it in place. The compiled function returns only the
    outputs of the wrapped function, not an updated RNG.
    NOTE(review): whether ``rng`` is actually advanced in place depends on the
    rewrites applied by the compilation mode. Confirm for non-default modes.

    Parameters
    ----------
    _func : callable, optional
        The function being decorated, when used as ``@pytensor_rng_jit`` without
        parentheses.
    **compile_kwargs
        Forwarded to ``pytensor.function`` for every compiled specialization.
        ``trust_input`` defaults to True.

    Returns
    -------
    callable
        The wrapped function, or a decorator when called with keyword arguments only.

    Examples
    --------

    .. code-block:: python

        @pytensor_rng_jit
        def normal_rvs(loc, scale, *, size, rng):
            return pt.random.normal(loc, scale, size=size, rng=rng)

        rng = np.random.default_rng(1)
        normal_rvs(0, 1, size=(2,), rng=rng)  # Compiles new function
        normal_rvs(0, 1, size=(3,), rng=rng)  # Reuses it: same length of size
        normal_rvs(0, 1, size=(2, 3), rng=rng)  # Compiles new function
    """

    def decorator(func):
        # trust_input can be a problem if the user passes aliased inputs (including
        # the same values twice). But it really reduces the overhead! Removing
        # inplace rewrites would make this safe always (I think).
        compile_kwargs.setdefault("trust_input", True)
        # A default mode could also be set here, e.g.
        # compile_kwargs.setdefault("mode", "NUMBA")
        signature_to_function: dict[tuple, Function] = {}

        @wraps(func)
        def inner_func(*args, size, rng, signature_to_function=signature_to_function):
            # `signature_to_function` is bound as a default argument (presumably so
            # the cache is a fast local lookup); callers should not pass it.
            args = tuple(np.asarray(a) for a in args)
            if size is not None:
                size = np.asarray(size, dtype="int64")
            # Signature: broadcastable pattern and dtype of each input, plus the
            # *shape* of `size` (or None). The ndim alone is not enough: it is 1
            # for both (2,) and (2, 3), which produce draws of different rank.
            signature = (
                *((tuple(s == 1 for s in a.shape), a.dtype) for a in args),
                None if size is None else size.shape,
            )
            # Functions compiled without a size take no `size` input.
            extra_inputs = (rng,) if size is None else (size, rng)

            # Plain lookup instead of try/except around the call: a KeyError raised
            # while *running* the compiled function must not be mistaken for a
            # cache miss (which would silently recompile and run it again).
            compiled_func = signature_to_function.get(signature)
            if compiled_func is None:
                print(f"Compiling new function for signature: {signature}")
                # Static shape 1 marks a broadcastable dimension, None an unknown one.
                symbolic_args = [
                    tensor(
                        shape=tuple(1 if bcast else None for bcast in bcast_pattern),
                        dtype=dtype,
                    )
                    for (bcast_pattern, dtype) in signature[:-1]
                ]
                # The static shape of `size` must match the values passed later,
                # since trust_input skips input validation by default.
                symbolic_size = (
                    None
                    if size is None
                    else tensor(shape=size.shape, dtype="int64", name="size")
                )
                symbolic_rng = random_generator_type("rng")
                symbolic_out = func(*symbolic_args, size=symbolic_size, rng=symbolic_rng)
                # We allow PyTensor to modify the RNG
                mutable_rng = In(symbolic_rng, mutable=True)
                if size is None:
                    symbolic_inputs = [*symbolic_args, mutable_rng]
                else:
                    symbolic_inputs = [*symbolic_args, symbolic_size, mutable_rng]
                signature_to_function[signature] = compiled_func = function(
                    symbolic_inputs, symbolic_out, **compile_kwargs,
                )
            return compiled_func(*args, *extra_inputs)

        return inner_func

    if _func is None:
        # Case: @pytensor_rng_jit(key=value)
        return decorator
    # Case: @pytensor_rng_jit
    return decorator(_func)
@pytensor_rng_jit
def normal_rvs(loc, scale, *, size, rng):
    """Build a symbolic normal draw via ``pt.random.normal(loc, scale, ...)``.

    The decorator compiles and caches one specialization per input types and
    ``size``; ``size`` and ``rng`` must always be passed by keyword.
    """
    draws = pt.random.normal(loc, scale, size=size, rng=rng)
    return draws
# Demo: each pass re-seeds `rng` with the same seed. Assuming the compiled
# functions advance `rng` in place, both passes should print the same draws.
# The second pass reuses the functions compiled during the first one.
for _ in range(2):
    rng = np.random.default_rng(1)
    draws = (
        normal_rvs(0, 1, size=(2,), rng=rng),
        normal_rvs(0, 1, size=(2,), rng=rng),
        normal_rvs(0, 1, size=None, rng=rng),
        normal_rvs([0, 1], 1, size=None, rng=rng),
    )
    print(*draws)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment