Created
February 21, 2022 17:56
-
-
Save Kautenja/760d69e2c871a814b0ed75ea17a6ec85 to your computer and use it in GitHub Desktop.
A wrapper for enforcing statically typed inputs to Python methods.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A python method wrapper to enforce static typing.""" | |
import typing | |
from functools import wraps | |
import inspect | |
def static_type(function):
    """
    Wrap a method with a guard condition to statically type input variables.

    Args:
        function: An arbitrary function or method to type check inputs to.
            ``classmethod`` and ``staticmethod`` objects are accepted too.

    Returns:
        The method wrapped to check the types of its inputs on every call.
        A ``classmethod``/``staticmethod`` input is returned wrapped in the
        same descriptor type, so it still binds correctly in a class body.

    Raises:
        TypeError: raised by the wrapper when an argument is not an instance
            of its annotated type, or when the call does not match the
            signature of ``function``.

    Notes:
        This is a decorator method. The expected usage is:
        ```python
        @static_type
        def foo(*args, **kwargs):
            pass
        ```
        Parameters without an annotation are not checked, and neither are
        annotations that ``isinstance`` cannot handle (for example
        ``typing.Any`` or ``typing.List[int]``). Annotated ``*args`` and
        ``**kwargs`` parameters are checked element by element. A leading
        ``self`` or ``cls`` parameter is never checked.
    """
    # classmethod / staticmethod objects are descriptors, not plain
    # functions: guard the underlying function, then re-apply the descriptor.
    if isinstance(function, (classmethod, staticmethod)):
        return type(function)(static_type(function.__func__))
    # Map parameter names to their (resolved) type hints.
    hint_map = typing.get_type_hints(function)
    signature = inspect.signature(function)
    parameters = signature.parameters
    names = list(parameters)
    # There is no need to check the type of the self or cls argument.
    skipped = names[0] if names and names[0] in {'self', 'cls'} else None
    # Keep only the hints that belong to a checked parameter and that can be
    # enforced with isinstance().
    checks = {}
    for name, hint in hint_map.items():
        # 'return' is not a parameter; self/cls is deliberately skipped.
        if name not in parameters or name == skipped:
            continue
        try:
            isinstance(None, hint)
        except TypeError:
            # Subscripted generics, typing.Any, etc. cannot be checked.
            continue
        checks[name] = hint

    def check(name, value, hint):
        """Raise TypeError if ``value`` is not an instance of ``hint``."""
        if not isinstance(value, hint):
            hint_name = hint.__name__ if isinstance(hint, type) else repr(hint)
            raise TypeError(
                f'expected argument {repr(name)} to be of type {hint_name}, '
                f'but it is an instance of {type(value).__name__}!'
            )

    # It looks confusing to decorate this internal wrapper function, but this
    # decorator helps ensure the function name, documentation, etc. propagate
    # from the input function to the execute function.
    @wraps(function)
    def execute(*args, **kwargs):
        # Match positional and keyword arguments to parameter names in one
        # step; this raises TypeError for calls that do not fit the signature.
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            if name not in checks:
                continue
            hint = checks[name]
            kind = parameters[name].kind
            if kind is inspect.Parameter.VAR_POSITIONAL:
                # *args: the annotation applies to each extra positional value.
                for item in value:
                    check(name, item, hint)
            elif kind is inspect.Parameter.VAR_KEYWORD:
                # **kwargs: the annotation applies to each extra keyword value.
                for key, item in value.items():
                    check(key, item, hint)
            else:
                check(name, value, hint)
        # At this point all types have been checked, call the function.
        return function(*args, **kwargs)
    return execute
@static_type
def foo(a: int, b: float, opt: int = 1) -> None:
    """Demo target: take an int, a float and an optional int; do nothing."""
    return None
class Foo:
    """Demo class whose constructor and methods are guarded by static_type."""

    @static_type
    def __init__(self, a: int, b: float, opt: int = 1) -> None:
        return

    # The descriptor decorator is placed outermost so that static_type wraps
    # the plain function and the result still binds as a class method.
    @classmethod
    @static_type
    def class_bar(cls, a: int, b: float, opt: int = 1) -> None:
        return

    # Same ordering for the static method: static_type wraps the plain
    # function, staticmethod is applied last.
    @staticmethod
    @static_type
    def static_bar(a: int, b: float, opt: int = 1) -> None:
        return
if __name__ == '__main__':
    # Argument sets exercised against every target. The first, fourth and
    # sixth sets are well typed; the others should be rejected by the
    # static_type guard with a TypeError.
    CASES = [
        ((1, 1.0), {}),
        (('asdf', 1.0), {}),
        ((1, 'asdf'), {}),
        ((1, 1.0, 2), {}),
        ((1, 1.0, 2.0), {}),
        ((1, 1.0), {'opt': 2}),
        ((1, 1.0), {'opt': 2.0}),
    ]
    # Run every case and report the outcome, instead of letting the first
    # expected TypeError abort the script (and any import of this module).
    for target in (foo, Foo, Foo.class_bar, Foo.static_bar):
        for args, kwargs in CASES:
            call = ', '.join(
                [*map(repr, args), *(f'{k}={v!r}' for k, v in kwargs.items())]
            )
            label = f'{target.__qualname__}({call})'
            try:
                target(*args, **kwargs)
            except TypeError as error:
                print(f'{label} -> TypeError: {error}')
            else:
                print(f'{label} -> ok')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment