Last active
August 20, 2024 20:55
-
-
Save CodeByAidan/1401ec77e782c47867879d400589da13 to your computer and use it in GitHub Desktop.
Enforce static type checking for a lambda/callable object (types are not checked at runtime). Automatically adds type annotations and a signature docstring using dunder/magic attributes! (Python 3.12+)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Callable | |
from functools import wraps | |
class typed[T, U]: | |
""" | |
A decorator class that enforces type checking for callable objects. | |
This class takes in a callable and ensures that the arguments passed to the callable | |
match the specified types. It returns the callable with the same type signature, but | |
without any additional runtime type enforcement. | |
Example Usage: | |
-------------- | |
:: | |
>>> add1point5 = typed(int, float, lambda x: x + 1.5) # (int) -> float | |
>>> add1point5(42) | |
43.5 | |
>>> concat = typed(str, str, str, lambda s1, s2: s1 + s2) # (str, str) -> str | |
>>> concat("foo", "bar") | |
'foobar' | |
""" | |
def __new__(cls, f: Callable[..., U], *Ts: type) -> Callable[..., U]: | |
""" | |
Creates a new instance of the typed class. | |
Parameters:: | |
----------- | |
- ``f``: A callable that takes arguments of types `*Ts` and returns a value. | |
- ``Ts``: The types of the arguments that the callable accepts. | |
Returns:: | |
-------- | |
- A callable with the same signature as `f`. | |
Example Usage: | |
-------------- | |
:: | |
>>> add1point5 = typed(int, float, lambda x: x + 1.5) | |
>>> add1point5(42) | |
43.5 | |
""" | |
@wraps(f) | |
def wrapped_function(*args: T) -> U: | |
return f(*args) | |
# construct the function signature and annotations manually | |
arg_annotations: dict[str, str] = { | |
f"arg{i}": t.__name__ for i, t in enumerate(Ts[:-1]) | |
} | |
return_annotation: str = Ts[-1].__name__ | |
wrapped_function.__annotations__ = { | |
**arg_annotations, | |
"return": return_annotation, | |
} | |
signature_str: str = ( | |
f"({', '.join(arg_annotations.values())}) -> {return_annotation}" | |
) | |
wrapped_function.__doc__ = f"def {f.__name__}{signature_str}" | |
return wrapped_function | |
if __name__ == "__main__":
    # Demo: wrap two lambdas with `typed`, then show each wrapper's generated
    # docstring, its generated annotations, and the result of one call.
    add1point5: Callable[[int], float] = typed(lambda x: x + 1.5, int, float)
    concat: Callable[[str, str], str] = typed(
        lambda s1, s2: s1 + s2, str, str, str
    )

    # Each entry: (wrapper, positional call arguments, `end=` for the result line).
    # Expected output:
    #   def <lambda>(int) -> float / {'arg0': 'int', 'return': 'float'} / 43.5
    #   def <lambda>(str, str) -> str / {'arg0': 'str', 'arg1': 'str', 'return': 'str'} / foobar
    demos = (
        (add1point5, (42,), "\n\n"),
        (concat, ("foo", "bar"), "\n"),
    )
    for wrapper, call_args, trailer in demos:
        print(wrapper.__doc__)
        print(wrapper.__annotations__)
        print(wrapper(*call_args), end=trailer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def <lambda>(int) -> float
{'arg0': 'int', 'return': 'float'}
43.5

def <lambda>(str, str) -> str
{'arg0': 'str', 'arg1': 'str', 'return': 'str'}
foobar
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment