Created
February 12, 2017 01:36
-
-
Save michaelbartnett/ed189ceab0d5d85973c6964d0c63e5d2 to your computer and use it in GitHub Desktop.
Wrote this because mypy kinda sucks right now, and Matthias Felleisen makes a convincing argument that runtime checks are necessary anyway.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import inspect | |
import marshmallow | |
class ParamConverter:
    """Marker wrapper around a callable, usable as a parameter annotation.

    Instances are recognised by ``_checker_for`` and dispatched to
    ``ParamConverterArgChecker``. Calling the instance simply forwards the
    value to the wrapped callable and returns whatever it returns.
    """

    __slots__ = ('converter_func',)

    def __init__(self, converter):
        """Store ``converter``; raise ``TypeError`` if it is not callable."""
        is_usable = callable(converter)
        if not is_usable:
            raise TypeError('converter must be callable')
        self.converter_func = converter

    def __call__(self, val):
        """Apply the wrapped callable to ``val`` and return its result."""
        convert = self.converter_func
        return convert(val)
class MarshmallowArgChecker:
    """Check an argument by loading it through a Marshmallow schema.

    ``schema`` is a schema class (or any callable accepting ``strict=True``);
    it is instantiated once and the instance is reused for every check.
    """

    __slots__ = ('schema_instance',)

    def __init__(self, schema):
        """Instantiate ``schema`` in strict mode and keep the instance."""
        # NOTE(review): ``strict=True`` looks like the marshmallow 2.x
        # constructor API -- confirm against the marshmallow version in use.
        self.schema_instance = schema(strict=True)

    def __call__(self, argname, val):
        """Return ``schema_instance.load(val)``.

        Raises:
            TypeError: if the schema raises ``marshmallow.ValidationError``;
                the original error is chained as ``__cause__``.
        """
        try:
            result = self.schema_instance.load(val)
        except marshmallow.ValidationError as exc:
            # The message must reference an attribute this class actually has:
            # the previous ``self.converter_func`` is not in ``__slots__`` and
            # turned every validation failure into an AttributeError.
            raise TypeError(
                'value passed for argument {}, {}, failed validation by the Marshmallow schema {}'
                ''.format(argname, val, self.schema_instance)) from exc
        return result
class TypeArgChecker:
    """Check that a value is an instance of a required type."""

    __slots__ = ('required_type',)

    def __init__(self, required_type):
        """Remember the type (anything ``isinstance`` accepts) to check for."""
        self.required_type = required_type

    def __call__(self, argname, val):
        """Return ``val`` unchanged if it passes ``isinstance``.

        Raises:
            TypeError: if ``val`` is not an instance of ``required_type``.
        """
        expected = self.required_type
        if isinstance(val, expected):
            return val
        raise TypeError(
            'value passed for argument {}, {}, was not an instance of {}'
            ''.format(argname, val, expected))
class ParamConverterArgChecker:
    """Check a value by running it through a converter callable."""

    __slots__ = ('converter',)

    def __init__(self, converter):
        """Remember the converter that will be applied to each value."""
        self.converter = converter

    def __call__(self, argname, val):
        """Run the converter on ``val`` and return the original ``val``.

        Only ``TypeError`` from the converter is translated; any other
        exception propagates untouched.

        Raises:
            TypeError: if the converter raises ``TypeError`` (chained as the
                cause).
        """
        run = self.converter
        try:
            # NOTE(review): the converter's result is discarded and the
            # unconverted value is returned -- confirm this is intended.
            run(val)
        except TypeError as exc:
            message = (
                'value passed for argument {}, {}, raised error in the type converter {}'
                ''.format(argname, val, run))
            raise TypeError(message) from exc
        else:
            return val
class ValueEqualArgChecker:
    """Check that a value compares equal to one specific required value."""

    __slots__ = ('required_value',)

    def __init__(self, required_value):
        """Remember the value every checked argument must equal."""
        self.required_value = required_value

    def __call__(self, argname, val):
        """Return ``val`` unchanged unless it differs from the required value.

        Raises:
            TypeError: if ``val != required_value``.
        """
        expected = self.required_value
        differs = val != expected
        if differs:
            message = (
                'value passed for argument {}, {}, was not the required value {}'
                ''.format(argname, val, expected))
            raise TypeError(message)
        return val
def _checker_for(annotation):
    """Build the checker object that enforces ``annotation``.

    Dispatch order:
      * a ``ParamConverter`` instance -> ``ParamConverterArgChecker``
      * a class                       -> ``TypeArgChecker``
      * ``None``, ``[]``, ``{}``, ``''`` or a bool value
                                      -> ``ValueEqualArgChecker``

    Raises:
        TypeError: for any other kind of annotation.
    """
    if isinstance(annotation, ParamConverter):
        return ParamConverterArgChecker(annotation)

    if isinstance(annotation, type):
        return TypeArgChecker(annotation)

    # A small set of literal values is accepted and compared by equality.
    is_supported_literal = (
        annotation is None
        or annotation == []
        or annotation == {}
        or annotation == ''
        or isinstance(annotation, bool)
    )
    if not is_supported_literal:
        raise TypeError('annotation must be a type or a ParamConverter')
    return ValueEqualArgChecker(annotation)
# Checkers shared by all decorated functions, keyed by the annotation object.
_g_checker_cache = {}


def _g_identity_checker(argname, val):
    """Pass-through checker used where there is no annotation to enforce."""
    return val


def _cached_checker_for(annotation):
    """Return the checker for ``annotation``, reusing a cached one if possible."""
    # ``inspect.Parameter.empty`` is itself a class, so handing it to
    # ``_checker_for`` would yield a TypeArgChecker that rejects every value.
    # A missing annotation means "nothing to check".
    if annotation is inspect.Parameter.empty:
        return _g_identity_checker
    try:
        cached = _g_checker_cache.get(annotation, None)
    except TypeError:
        # Unhashable annotations ([] and {} are accepted by ``_checker_for``)
        # cannot be dict keys; build a fresh checker without caching.
        return _checker_for(annotation)
    if cached is None:
        cached = _checker_for(annotation)
        _g_checker_cache[annotation] = cached
    return cached


def checkfunc(enabled=True):
    """Decorator factory that enforces a function's annotations at call time.

    Args:
        enabled: when false, the returned decorator hands the function back
            untouched, so checking costs nothing.

    Returns:
        A decorator. When enabled, the decorated function binds its call
        arguments to the signature, passes each explicitly supplied argument
        through the checker for its annotation (the checker's return value
        replaces the argument), calls the original function, and passes the
        result through the checker for the return annotation. Parameters and
        return values without an annotation are passed through unchecked.

    Raises:
        TypeError: at decoration time if an annotation is not supported by
            ``_checker_for``; at call time if the arguments do not fit the
            signature or a checker rejects a value.
    """
    if not enabled:
        def disabled_decorator(f):
            return f
        return disabled_decorator

    def enabled_decorator(f):
        sig = inspect.signature(f)
        # Resolve every checker once, at decoration time.
        checkers = {
            paramname: _cached_checker_for(param.annotation)
            for paramname, param in sig.parameters.items()
        }
        return_checker = _cached_checker_for(sig.return_annotation)
        return_label = f.__qualname__ + "'s return value"

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            # Only explicitly passed arguments appear here; defaults that were
            # not overridden are not checked.
            # NOTE(review): *args / **kwargs values are checked as a whole
            # tuple / dict against their annotation, not per element.
            for argname, argval in bound.arguments.items():
                bound.arguments[argname] = checkers[argname](argname, argval)
            result = f(*bound.args, **bound.kwargs)
            return return_checker(return_label, result)
        return wrapped

    return enabled_decorator
if __name__ == '__main__':
    # Self-test: decorate a few sample functions and verify that calls which
    # violate their annotations raise TypeError while valid calls do not.

    @checkfunc()
    def test1(pid: int, x: str = 'foo', *, portal: str) -> None:
        print('pid:', pid, 'x:', x, 'portal:', portal)

    @checkfunc(True)
    def test2(pid: int, x: str = 'foo', *, portal: str) -> None:
        # Deliberately returns a value despite the ``-> None`` annotation.
        print('pid:', pid, 'x:', x, 'portal:', portal)
        return pid

    @checkfunc()
    def test3(pid: int, x: str = 'foo', *, portal: str):
        print('pid:', pid, 'x:', x, 'portal:', portal)
        return pid

    @checkfunc(False)
    def test4(pid, portal) -> ParamConverter:
        # Checking is disabled, so the return annotation is never enforced.
        pass

    def _expect_throw(shouldthrow, callthis):
        """Return True when ``callthis`` raised TypeError iff ``shouldthrow``."""
        try:
            callthis()
        except TypeError:
            return shouldthrow
        return not shouldthrow

    # (expect TypeError?, call) pairs, evaluated in order.
    cases = [
        (False, lambda: test1(42, portal='99')),
        (False, lambda: test1(42, '24', portal='99')),
        (True, lambda: test1('42', portal='99')),
        (True, lambda: test1('42', 'nope')),
        (True, lambda: test1('42', 'nope', 'alsonope')),
        (True, lambda: test2(42, portal='99')),
        (False, lambda: test3(42, portal='99')),
        (False, lambda: test4(42, portal='99')),
    ]
    results = [_expect_throw(expected, call) for expected, call in cases]

    if all(results):
        print('All tests passed')
    else:
        print('Some tests failed: ', results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment