Created
September 13, 2024 20:23
-
-
Save dk949/acf1ce2ccea035c86b2a7ed71fbc6b66 to your computer and use it in GitHub Desktop.
Python decorator to do runtime type checking based on function type annotations. See the docstring for details.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def typecheck( | |
func=None, | |
*, | |
check_return=False, | |
): | |
""" | |
Perform a runtime type check of the function arguments and (optionally) its | |
return type. | |
All annotated arguments and keyword arguments will be typechecked. Known | |
container types (list, set, frozenset, dict) will be typechecked recursively | |
if annotated as a generic alias (i.e `list[int]` rather than `list`). | |
If `check_return` is True (default False), the returned value will also be | |
checked. If return value is not annotated and `check_return` is True, a | |
`TypeError` will be raised. | |
Usage (doubles as a doctest 🙂) | |
>>> class Foo: | |
... @typecheck | |
... def foo(self, i: int, j): | |
... return "passed" | |
... | |
>>> @typecheck(check_return=True) | |
... def bar( | |
... first: list[int | str], | |
... /, | |
... x: str | Foo, | |
... y: dict[float, float], | |
... *, | |
... z: set[Foo | int | tuple[str, str]], | |
... ) -> str: | |
... return "passed" | |
... | |
>>> f = Foo() | |
>>> f.foo(1, None) # First arg should be int, second can be anything | |
'passed' | |
>>> bar([1, "2"], "string", y={1.2: 2.3}, z={Foo(), 23, ("a", "b")}) | |
'passed' | |
>>> f.foo("1", "hello") | |
Traceback (most recent call last): | |
... | |
TypeError: Argument i must be of type int, got str | |
>>> bar([1, "2"], 42, y={1.2: 2.3}, z={Foo(), 23, ("a", "b")}) | |
Traceback (most recent call last): | |
... | |
TypeError: Argument x must be of type str | Foo, got int | |
""" | |
from functools import wraps, partial | |
from types import GenericAlias, UnionType | |
from typing import Any, get_args, get_origin | |
from inspect import getfullargspec | |
from itertools import chain | |
# type | UnionType | GenericAlias | |
if func is None: | |
return partial( | |
typecheck, | |
check_return=check_return, | |
) | |
arg_names, *_, annotations = getfullargspec(func) | |
if check_return and "return" not in annotations: | |
raise TypeError( | |
"Cannot specify `check_return=True` without annotated return type" | |
) | |
def check_type(value, expected_type: type | UnionType | GenericAlias): | |
# Recursively checks if the value matches the expected type. | |
if isinstance(expected_type, GenericAlias): | |
origin_type = get_origin(expected_type) | |
if not isinstance(value, origin_type): | |
return False | |
elif origin_type is list or origin_type is set or origin_type is frozenset: | |
item_type = get_args(expected_type)[0] | |
return all(check_type(item, item_type) for item in value) | |
elif origin_type is tuple: | |
item_types = get_args(expected_type) | |
return len(value) == len(item_types) and all( | |
check_type(item, ty) for item, ty in zip(value, item_types) | |
) | |
elif origin_type is dict: | |
key_type, value_type = get_args(expected_type) | |
return all( | |
check_type(k, key_type) and check_type(v, value_type) | |
for k, v in value.items() | |
) | |
else: | |
raise TypeError(f"Cannot check generic alias {expected_type}") | |
elif isinstance(expected_type, UnionType): | |
sub_types = get_args(expected_type) | |
return any(check_type(value, t) for t in sub_types) | |
else: | |
return isinstance(value, expected_type) | |
def get_type_name(t: Any) -> str: | |
match t: | |
case type(): | |
return t.__name__ | |
case GenericAlias( | |
__origin__=origin | |
) if origin is list or origin is set or origin is frozenset: | |
return f"{get_type_name(origin)}[{' | '.join(set(get_type_name(a) for a in get_args(t)))}]" | |
case GenericAlias(__origin__=origin) if origin is tuple: | |
return f"{get_type_name(origin)}[{', '.join(get_type_name(a) for a in get_args(t))}]" | |
case GenericAlias(): | |
return str(t) | |
case UnionType(): | |
return " | ".join(get_type_name(a) for a in get_args(t)) | |
case list() | set() | frozenset(): | |
return f"{type(t).__name__}[{' | '.join(set(get_type_name(i) for i in t))}]" | |
case tuple(): | |
return f"tuple[{', '.join(get_type_name(type(i)) for i in t)}]" | |
case dict(): | |
key_t = " | ".join(set(get_type_name(type(k)) for k in t.keys())) | |
val_t = " | ".join(set(get_type_name(type(v)) for v in t.values())) | |
return f"dict[{key_t}, {val_t}]" | |
case _: | |
return get_type_name(type(t)) | |
def raise_if_no_match( | |
arg: Any, name: str, expected: type | UnionType | GenericAlias | |
): | |
if not check_type(arg, expected): | |
raise TypeError( | |
f"{'Return' if name == "return" else f'Argument {name}'} must be of type {get_type_name(expected)}, got {get_type_name(arg)}" | |
) | |
@wraps(func) | |
def wrapper(*args, **kwargs): | |
for name, arg in chain(zip(arg_names, args), kwargs.items()): | |
if name in annotations: | |
raise_if_no_match(arg, name, annotations[name]) | |
ret = func(*args, **kwargs) | |
if check_return: | |
raise_if_no_match(ret, "return", annotations["return"]) | |
return ret | |
return wrapper | |
if __name__ == "__main__":
    # Running this file directly executes the docstring examples as tests.
    from doctest import testmod

    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment