Last active
April 5, 2025 04:57
-
-
Save kstoneriv3/9b0b9ef39b2c2bb9d3775f9d10a6d8ed to your computer and use it in GitHub Desktop.
Type annotations for a lazy-evaluation (delayed) decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import cache, wraps | |
from typing import Any, Callable, Generic, TypeVar, ParamSpec, cast | |
from beartype import beartype | |
T = TypeVar("T") | |
P = ParamSpec("P") | |
class DelayedOutput(Generic[T]): | |
def __init__( | |
self, func: Callable[..., T], *args: Any, **kwargs: Any | |
) -> None: | |
self.func = func | |
self.args = args | |
self.kwargs = kwargs | |
@cache | |
def eval(self) -> T: | |
evaluated_args = [ | |
a.eval() if isinstance(a, DelayedOutput) else a for a in self.args | |
] | |
evaluated_kwargs = { | |
k: a.eval() if isinstance(a, DelayedOutput) else a | |
for k, a in self.kwargs.items() | |
} | |
return self.func(*evaluated_args, **evaluated_kwargs) | |
def delayed(func: Callable[P, T]) -> Callable[P, T]:
    """Turn ``func`` into a lazy version whose calls return a ``DelayedOutput``.

    Calling the decorated function does not run ``func``. Instead it records
    the call in a ``DelayedOutput``, whose ``eval()`` later runs ``func`` after
    evaluating any delayed arguments.

    Static typing: the wrapper is declared with the *same* signature as
    ``func`` (``Callable[P, T]``). A type checker therefore treats a delayed
    value as if it already were the eventual ``T``.

    Runtime typing: the wrapper's ``__annotations__`` are rewritten so that a
    parameter annotated ``t`` is annotated ``t | DelayedOutput[t]``. A declared
    return type ``r`` becomes ``DelayedOutput[r]``. These rewritten
    annotations are what a runtime checker applied on top (e.g. beartype)
    reads.

    Args:
        func: The function to defer.

    Returns:
        The wrapper, with metadata copied from ``func`` by ``functools.wraps``.
    """

    @wraps(func)
    def delayed_func(*args: P.args, **kwargs: P.kwargs) -> T:
        # At runtime this is a DelayedOutput, not a T; the cast only keeps the
        # static signature identical to ``func``'s.
        return cast(T, DelayedOutput(func, *args, **kwargs))  # pyright: ignore

    # Callables without ``__annotations__`` (e.g. functools.partial objects)
    # are treated as unannotated instead of raising AttributeError.
    source_annotations = getattr(func, "__annotations__", None) or {}
    runtime_annotations: dict[str, Any] = {
        name: annotation | DelayedOutput[annotation]
        for name, annotation in source_annotations.items()
        if name != "return"
    }
    # Only parameterize the return when ``func`` declares one; an
    # unannotated return previously crashed here with KeyError.
    if "return" in source_annotations:
        runtime_annotations["return"] = DelayedOutput[source_annotations["return"]]
    delayed_func.__annotations__ = runtime_annotations
    return delayed_func
@beartype
@delayed
def add(x: int, y: int) -> int:
    """Return ``x + y``.

    Because of ``@delayed``, calling ``add`` returns a ``DelayedOutput``
    instead of computing the sum immediately. ``@beartype`` is applied
    outermost, so it sees the runtime annotations installed by ``delayed``
    (``int | DelayedOutput[int]`` per parameter). It is intended to reject
    arguments that are neither.
    """
    return x + y  # at runtime, add() will be called with evaluated arguments
@delayed
def one() -> int:
    """Return ``1``; because of ``@delayed``, calls yield a ``DelayedOutput``."""
    return 1
@delayed
def a() -> str:
    """Return the string ``"a"``; because of ``@delayed``, calls yield a ``DelayedOutput``.

    The demo below uses it as a delayed value of the wrong type for ``add``.
    """
    return "a"
# Test calls.
#
# Each call is wrapped in a zero-argument lambda so that it runs inside the
# ``try`` below. A failure can surface at two different moments:
#   * call time: the runtime type check on ``add`` is meant to reject
#     ``add(1, "a")``, since "a" is neither ``int`` nor a DelayedOutput;
#   * eval time: a call like ``add(1, a())`` passes a DelayedOutput, so it
#     presumably gets past the call-time check and fails later on 1 + "a".
# Building every call eagerly in one list literal let a call-time failure
# escape before the loop started, which aborted the whole demo.
test_cases = [
    ("add(1, 2)", lambda: add(1, 2)),  # OK
    ("add(one(), 2)", lambda: add(one(), 2)),  # OK: one() will evaluate to int
    ("add(one(), one())", lambda: add(one(), one())),  # OK
    ("add(1, 'a')", lambda: add(1, "a")),  # NG: "a" is not an int
    # NG: a() is a delayed value that will produce str
    ("add(1, a())", lambda: add(1, a())),
    # NG: a() eventually produces str, not int
    ("add(one(), a())", lambda: add(one(), a())),
]
for label, make_call in test_cases:
    try:
        dout = cast(DelayedOutput, make_call())
        print(dout.eval())
    except Exception as err:  # demo script: report the failure and keep going
        print(f"{label} - failed ({type(err).__name__})")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment