Skip to content

Instantly share code, notes, and snippets.

@dk949
Created September 13, 2024 20:23
Show Gist options
  • Save dk949/acf1ce2ccea035c86b2a7ed71fbc6b66 to your computer and use it in GitHub Desktop.
Save dk949/acf1ce2ccea035c86b2a7ed71fbc6b66 to your computer and use it in GitHub Desktop.
Python decorator to do runtime typechecking based on function type annotations. See docstring for details.
def typecheck(
func=None,
*,
check_return=False,
):
"""
Perform a runtime type check of the function arguments and (optionally) its
return type.
All annotated arguments and keyword arguments will be typechecked. Known
container types (list, set, frozenset, dict) will be typechecked recursively
if annotated as a generic alias (i.e `list[int]` rather than `list`).
If `check_return` is True (default False), the returned value will also be
checked. If return value is not annotated and `check_return` is True, a
`TypeError` will be raised.
Usage (doubles as a doctest 🙂)
>>> class Foo:
... @typecheck
... def foo(self, i: int, j):
... return "passed"
...
>>> @typecheck(check_return=True)
... def bar(
... first: list[int | str],
... /,
... x: str | Foo,
... y: dict[float, float],
... *,
... z: set[Foo | int | tuple[str, str]],
... ) -> str:
... return "passed"
...
>>> f = Foo()
>>> f.foo(1, None) # First arg should be int, second can be anything
'passed'
>>> bar([1, "2"], "string", y={1.2: 2.3}, z={Foo(), 23, ("a", "b")})
'passed'
>>> f.foo("1", "hello")
Traceback (most recent call last):
...
TypeError: Argument i must be of type int, got str
>>> bar([1, "2"], 42, y={1.2: 2.3}, z={Foo(), 23, ("a", "b")})
Traceback (most recent call last):
...
TypeError: Argument x must be of type str | Foo, got int
"""
from functools import wraps, partial
from types import GenericAlias, UnionType
from typing import Any, get_args, get_origin
from inspect import getfullargspec
from itertools import chain
# type | UnionType | GenericAlias
if func is None:
return partial(
typecheck,
check_return=check_return,
)
arg_names, *_, annotations = getfullargspec(func)
if check_return and "return" not in annotations:
raise TypeError(
"Cannot specify `check_return=True` without annotated return type"
)
def check_type(value, expected_type: type | UnionType | GenericAlias):
# Recursively checks if the value matches the expected type.
if isinstance(expected_type, GenericAlias):
origin_type = get_origin(expected_type)
if not isinstance(value, origin_type):
return False
elif origin_type is list or origin_type is set or origin_type is frozenset:
item_type = get_args(expected_type)[0]
return all(check_type(item, item_type) for item in value)
elif origin_type is tuple:
item_types = get_args(expected_type)
return len(value) == len(item_types) and all(
check_type(item, ty) for item, ty in zip(value, item_types)
)
elif origin_type is dict:
key_type, value_type = get_args(expected_type)
return all(
check_type(k, key_type) and check_type(v, value_type)
for k, v in value.items()
)
else:
raise TypeError(f"Cannot check generic alias {expected_type}")
elif isinstance(expected_type, UnionType):
sub_types = get_args(expected_type)
return any(check_type(value, t) for t in sub_types)
else:
return isinstance(value, expected_type)
def get_type_name(t: Any) -> str:
match t:
case type():
return t.__name__
case GenericAlias(
__origin__=origin
) if origin is list or origin is set or origin is frozenset:
return f"{get_type_name(origin)}[{' | '.join(set(get_type_name(a) for a in get_args(t)))}]"
case GenericAlias(__origin__=origin) if origin is tuple:
return f"{get_type_name(origin)}[{', '.join(get_type_name(a) for a in get_args(t))}]"
case GenericAlias():
return str(t)
case UnionType():
return " | ".join(get_type_name(a) for a in get_args(t))
case list() | set() | frozenset():
return f"{type(t).__name__}[{' | '.join(set(get_type_name(i) for i in t))}]"
case tuple():
return f"tuple[{', '.join(get_type_name(type(i)) for i in t)}]"
case dict():
key_t = " | ".join(set(get_type_name(type(k)) for k in t.keys()))
val_t = " | ".join(set(get_type_name(type(v)) for v in t.values()))
return f"dict[{key_t}, {val_t}]"
case _:
return get_type_name(type(t))
def raise_if_no_match(
arg: Any, name: str, expected: type | UnionType | GenericAlias
):
if not check_type(arg, expected):
raise TypeError(
f"{'Return' if name == "return" else f'Argument {name}'} must be of type {get_type_name(expected)}, got {get_type_name(arg)}"
)
@wraps(func)
def wrapper(*args, **kwargs):
for name, arg in chain(zip(arg_names, args), kwargs.items()):
if name in annotations:
raise_if_no_match(arg, name, annotations[name])
ret = func(*args, **kwargs)
if check_return:
raise_if_no_match(ret, "return", annotations["return"])
return ret
return wrapper
# When executed as a script, run the usage examples embedded in the
# `typecheck` docstring as doctests.
if __name__ == "__main__":
    import doctest

    doctest.testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment