Skip to content

Instantly share code, notes, and snippets.

@kstoneriv3
Last active April 5, 2025 04:57
Show Gist options
  • Save kstoneriv3/9b0b9ef39b2c2bb9d3775f9d10a6d8ed to your computer and use it in GitHub Desktop.
Save kstoneriv3/9b0b9ef39b2c2bb9d3775f9d10a6d8ed to your computer and use it in GitHub Desktop.
Type annotations for a lazy-evaluation decorator
from functools import cache, wraps
from typing import Any, Callable, Generic, TypeVar, ParamSpec, cast
from beartype import beartype
T = TypeVar("T")
P = ParamSpec("P")
class DelayedOutput(Generic[T]):
def __init__(
self, func: Callable[..., T], *args: Any, **kwargs: Any
) -> None:
self.func = func
self.args = args
self.kwargs = kwargs
@cache
def eval(self) -> T:
evaluated_args = [
a.eval() if isinstance(a, DelayedOutput) else a for a in self.args
]
evaluated_kwargs = {
k: a.eval() if isinstance(a, DelayedOutput) else a
for k, a in self.kwargs.items()
}
return self.func(*evaluated_args, **evaluated_kwargs)
def delayed(func: Callable[P, T]) -> Callable[P, T]:
    """Turn ``func`` into a lazily evaluated version of itself.

    Calling the returned wrapper does not run ``func``. Instead it returns a
    ``DelayedOutput`` that captures the call and runs ``func`` when its
    ``eval()`` is invoked. Arguments may themselves be ``DelayedOutput``
    objects; they are evaluated before ``func`` runs.

    Static type checkers see the wrapper with exactly ``func``'s signature.
    At runtime the wrapper's ``__annotations__`` are rewritten to the real
    types so a runtime type check applied on top of the wrapper sees them:
    each annotated parameter becomes ``t | DelayedOutput[t]`` and the return
    becomes a ``DelayedOutput``.

    Args:
        func: The function to defer.

    Returns:
        The wrapper described above.
    """

    @wraps(func)
    def delayed_func(*args: P.args, **kwargs: P.kwargs) -> T:
        return cast(T, DelayedOutput(func, *args, **kwargs))  # pyright: ignore

    # Some callables (e.g. builtins) have no ``__annotations__`` at all.
    hints: dict[str, Any] = getattr(func, "__annotations__", {})
    # Build a fresh dict: ``wraps`` copied a reference to ``func``'s own
    # annotation dict, which must not be mutated.
    annotations: dict[str, Any] = {
        name: hint | DelayedOutput[hint]
        for name, hint in hints.items()
        if name != "return"
    }
    # The wrapper always returns a DelayedOutput. When ``func`` has no return
    # annotation, use the unparameterized class instead of raising KeyError.
    if "return" in hints:
        annotations["return"] = DelayedOutput[hints["return"]]
    else:
        annotations["return"] = DelayedOutput
    delayed_func.__annotations__ = annotations
    return delayed_func
@beartype
@delayed
def add(x: int, y: int) -> int:
    """Return the sum of ``x`` and ``y``.

    Because of ``@delayed``, calling ``add`` returns a ``DelayedOutput``. Its
    arguments may be ints or delayed values, and ``@beartype`` checks them
    against the widened runtime annotations set by ``delayed``.
    """
    # This body runs only on evaluation, after any delayed arguments have
    # been replaced by their evaluated values.
    total = x + y
    return total
@delayed
def one() -> int:
    """Produce the integer ``1``; calling ``one()`` yields a ``DelayedOutput``."""
    value = 1
    return value
@delayed
def a() -> str:
    """Produce the string ``"a"``; calling ``a()`` yields a ``DelayedOutput``."""
    letter = "a"
    return letter
# Demo calls. Each case is wrapped in a lambda so the call itself happens
# inside the ``try``. A failure is then reported for that case alone, whether
# it occurs when the delayed call is made (e.g. a runtime type check on
# ``add`` rejecting an argument) or when it is evaluated (``add`` failing on
# the evaluated arguments). Building the list eagerly would let a call-time
# error abort the whole script.
cases: list[tuple[str, Callable[[], Any]]] = [
    ("add(1, 2)", lambda: add(1, 2)),  # OK
    ("add(one(), 2)", lambda: add(one(), 2)),  # OK: one() evaluates to int
    ("add(one(), one())", lambda: add(one(), one())),  # OK
    ("add(1, 'a')", lambda: add(1, "a")),  # NG: "a" is not an int
    # NG: a() is a delayed value that will produce str
    ("add(1, a())", lambda: add(1, a())),
    # NG: a() eventually produces str, not int
    ("add(one(), a())", lambda: add(one(), a())),
]
for label, make_call in cases:
    try:
        dout = cast(DelayedOutput, make_call())
        print(dout.eval())
    except Exception as err:  # demo boundary: report the failure and go on
        print(f"{label} - failed ({type(err).__name__})")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment