Skip to content

Instantly share code, notes, and snippets.

@CodeByAidan
Last active August 20, 2024 20:55
Show Gist options
  • Save CodeByAidan/1401ec77e782c47867879d400589da13 to your computer and use it in GitHub Desktop.
Save CodeByAidan/1401ec77e782c47867879d400589da13 to your computer and use it in GitHub Desktop.
Attach type annotations to a lambda/callable object so static type checkers can verify its use (no runtime enforcement). Automatically adds a function signature and docstring via dunder/magic attributes! (Python 3.12+)
from collections.abc import Callable
from functools import wraps
class typed[T, U]:
"""
A decorator class that enforces type checking for callable objects.
This class takes in a callable and ensures that the arguments passed to the callable
match the specified types. It returns the callable with the same type signature, but
without any additional runtime type enforcement.
Example Usage:
--------------
::
>>> add1point5 = typed(int, float, lambda x: x + 1.5) # (int) -> float
>>> add1point5(42)
43.5
>>> concat = typed(str, str, str, lambda s1, s2: s1 + s2) # (str, str) -> str
>>> concat("foo", "bar")
'foobar'
"""
def __new__(cls, f: Callable[..., U], *Ts: type) -> Callable[..., U]:
"""
Creates a new instance of the typed class.
Parameters::
-----------
- ``f``: A callable that takes arguments of types `*Ts` and returns a value.
- ``Ts``: The types of the arguments that the callable accepts.
Returns::
--------
- A callable with the same signature as `f`.
Example Usage:
--------------
::
>>> add1point5 = typed(int, float, lambda x: x + 1.5)
>>> add1point5(42)
43.5
"""
@wraps(f)
def wrapped_function(*args: T) -> U:
return f(*args)
# construct the function signature and annotations manually
arg_annotations: dict[str, str] = {
f"arg{i}": t.__name__ for i, t in enumerate(Ts[:-1])
}
return_annotation: str = Ts[-1].__name__
wrapped_function.__annotations__ = {
**arg_annotations,
"return": return_annotation,
}
signature_str: str = (
f"({', '.join(arg_annotations.values())}) -> {return_annotation}"
)
wrapped_function.__doc__ = f"def {f.__name__}{signature_str}"
return wrapped_function
if __name__ == "__main__":
    # Demo 1: a one-argument lambda declared as (int) -> float.
    add1point5: Callable[[int], float] = typed(lambda x: x + 1.5, int, float)
    for shown in (
        add1point5.__doc__,  # def <lambda>(int) -> float
        add1point5.__annotations__,  # {'arg0': 'int', 'return': 'float'}
        add1point5(42),  # 43.5
    ):
        print(shown)
    print()  # blank separator line between the two demos

    # Demo 2: a two-argument lambda declared as (str, str) -> str.
    concat: Callable[[str, str], str] = typed(
        lambda s1, s2: s1 + s2,
        str,
        str,
        str,
    )
    for shown in (
        concat.__doc__,  # def <lambda>(str, str) -> str
        concat.__annotations__,  # {'arg0': 'str', 'arg1': 'str', 'return': 'str'}
        concat("foo", "bar"),  # foobar
    ):
        print(shown)
def <lambda>(int) -> float
{'arg0': 'int', 'return': 'float'}
43.5
def <lambda>(str, str) -> str
{'arg0': 'str', 'arg1': 'str', 'return': 'str'}
foobar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment