Skip to content

Instantly share code, notes, and snippets.

@michaelbartnett
Created February 12, 2017 01:36
Show Gist options
  • Save michaelbartnett/ed189ceab0d5d85973c6964d0c63e5d2 to your computer and use it in GitHub Desktop.
Save michaelbartnett/ed189ceab0d5d85973c6964d0c63e5d2 to your computer and use it in GitHub Desktop.
Wrote this because mypy is still rough around the edges right now, and Matthias Felleisen makes a convincing argument that runtime checks are necessary anyway.
import functools
import inspect
import marshmallow
class ParamConverter:
    """Wrap a callable so it can be used as a parameter annotation with checkfunc.

    Calling the wrapper forwards the value to the wrapped callable and returns
    its result.
    """

    __slots__ = ('converter_func',)

    def __init__(self, converter):
        # Validate eagerly so a bad converter is reported when the wrapper is
        # built, not the first time a checked function is called.
        if callable(converter):
            self.converter_func = converter
            return
        raise TypeError('converter must be callable')

    def __call__(self, val):
        convert = self.converter_func
        return convert(val)
class MarshmallowArgChecker:
    """Argument checker that validates (and loads) a value with a marshmallow schema.

    NOTE(review): ``strict=True`` is a marshmallow 2.x constructor option that
    was removed in marshmallow 3, and in 2.x ``load`` returns an
    ``UnmarshalResult`` rather than the bare data — confirm which marshmallow
    version this targets.
    NOTE(review): _checker_for does not currently dispatch to this class.
    """

    __slots__ = ('schema_instance',)

    def __init__(self, schema):
        # `schema` is a schema class; it is instantiated once here and reused
        # for every call.
        self.schema_instance = schema(strict=True)

    def __call__(self, argname, val):
        """Return the schema's ``load`` result for *val*.

        Raises:
            TypeError: if the schema raises ``marshmallow.ValidationError``;
                the original error is chained as ``__cause__``.
        """
        try:
            result = self.schema_instance.load(val)
        except marshmallow.ValidationError as exc:
            # Report the schema instance: it is the only attribute this class
            # has (see __slots__). Referencing anything else would turn the
            # intended TypeError into an AttributeError.
            raise TypeError('value passed for argument {}, {}, failed validation by the Marshmallow schema {}'
                            ''.format(argname, val, self.schema_instance)) from exc
        return result
class TypeArgChecker:
    """Argument checker that requires a value to be an instance of a given type.

    Uses ``isinstance``, so subclasses are accepted (for example, bools
    satisfy ``int``).
    """

    __slots__ = ('required_type',)

    def __init__(self, required_type):
        self.required_type = required_type

    def __call__(self, argname, val):
        """Return *val* unchanged if it is an instance of ``required_type``.

        Raises:
            TypeError: if the isinstance check fails.
        """
        expected = self.required_type
        if isinstance(val, expected):
            return val
        template = 'value passed for argument {}, {}, was not an instance of {}'
        raise TypeError(template.format(argname, val, expected))
class ParamConverterArgChecker:
    """Argument checker that validates a value by passing it through a converter.

    The converter's result is discarded and the original value is returned.
    Only a TypeError raised by the converter is reported as a check failure;
    any other exception (e.g. ValueError) propagates unchanged.
    """

    __slots__ = ('converter',)

    def __init__(self, converter):
        self.converter = converter

    def __call__(self, argname, val):
        convert = self.converter
        try:
            convert(val)
        except TypeError as exc:
            template = 'value passed for argument {}, {}, raised error in the type converter {}'
            raise TypeError(template.format(argname, val, convert)) from exc
        else:
            return val
class ValueEqualArgChecker:
    """Argument checker that requires a value to compare equal to a fixed value.

    The comparison uses ``!=``, so equal-but-different values pass: for
    example, with a required value of True the integer 1 is accepted.
    """

    __slots__ = ('required_value',)

    def __init__(self, required_value):
        self.required_value = required_value

    def __call__(self, argname, val):
        """Return *val* unchanged if it equals ``required_value``; otherwise raise TypeError."""
        wanted = self.required_value
        if val != wanted:
            template = 'value passed for argument {}, {}, was not the required value {}'
            raise TypeError(template.format(argname, val, wanted))
        return val
def _checker_for(annotation):
    """Return the checker object that enforces *annotation*.

    Accepted annotation forms, tried in this order:
      * a ParamConverter instance -> ParamConverterArgChecker
      * any class (an instance of ``type``) -> TypeArgChecker
      * None, [], {}, '' or a bool -> ValueEqualArgChecker (compared with ==)

    Raises:
        TypeError: if *annotation* is none of the above.
    """
    if isinstance(annotation, ParamConverter):
        return ParamConverterArgChecker(annotation)
    if isinstance(annotation, type):
        return TypeArgChecker(annotation)
    if (annotation is None or
            annotation == [] or
            annotation == {} or
            annotation == '' or
            isinstance(annotation, bool)):
        return ValueEqualArgChecker(annotation)
    # List every accepted form and echo the offending annotation, so mistakes
    # such as a string forward reference ('int') are easy to spot.
    raise TypeError("annotation must be a type, a ParamConverter, or one of the values "
                    "None, [], {{}}, '', True, False; got {!r}".format(annotation))
_g_checker_cache = {}
_g_identity_checker = lambda argname, val: val
def checkfunc(enabled=True):
    """Decorator factory that adds runtime checks derived from annotations.

    Each parameter annotation, and the return annotation, of the decorated
    function is turned into a checker by ``_checker_for``. On every call, the
    explicitly passed arguments are checked in parameter order before the
    function runs. Its return value is checked afterwards. Checkers raise
    TypeError on failure. ``Signature.bind`` also raises TypeError for calls
    that do not match the signature.

    * Parameters without an annotation are not checked. The return value is
      not checked when there is no return annotation.
    * An annotation on ``*args`` or ``**kwargs`` is applied to each element
      (the PEP 484 convention), not to the tuple/dict as a whole.
    * Default values that the caller did not pass are not checked.

    Args:
        enabled: when falsy, the returned decorator hands the function back
            unchanged, so disabled checks add no call overhead.

    Returns:
        A decorator to apply to the function to be checked.
    """
    if not enabled:
        def disabled_decorator(f):
            return f
        return disabled_decorator

    def lookup_checker(annotation):
        """Return a (possibly cached) checker for *annotation*."""
        # Key on the annotation's type as well as its value. True == 1 and both
        # hash alike, so a value-only key would let `x: 1` silently reuse the
        # checker built for `x: True`.
        key = (type(annotation), annotation)
        try:
            checker = _g_checker_cache.get(key)
        except TypeError:
            # Unhashable annotations ([] or {}) are accepted by _checker_for
            # but cannot be cache keys; build a fresh checker instead.
            return _checker_for(annotation)
        if checker is None:
            checker = _checker_for(annotation)
            _g_checker_cache[key] = checker
        return checker

    def check_argument(param, checker, value):
        """Apply *checker* to the value bound to *param*, element-wise for varargs."""
        if param.kind == param.VAR_POSITIONAL:
            return tuple(checker('{}[{}]'.format(param.name, index), item)
                         for index, item in enumerate(value))
        if param.kind == param.VAR_KEYWORD:
            return {key: checker(key, item) for key, item in value.items()}
        return checker(param.name, value)

    def enabled_decorator(f):
        sig = inspect.signature(f)
        # Resolve checkers once, at decoration time, and leave unannotated
        # parameters out entirely. inspect.Parameter.empty is itself a class,
        # so passing it to _checker_for would build a TypeArgChecker that
        # rejects every value.
        arg_checkers = {
            name: (param, lookup_checker(param.annotation))
            for name, param in sig.parameters.items()
            if param.annotation is not param.empty
        }
        if sig.return_annotation is sig.empty:
            return_checker = _g_identity_checker
        else:
            return_checker = lookup_checker(sig.return_annotation)

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            arguments = bound.arguments
            for name, (param, checker) in arg_checkers.items():
                # Parameters left at their default (and empty *args/**kwargs)
                # are absent from bound.arguments, so they are not checked.
                if name in arguments:
                    arguments[name] = check_argument(param, checker, arguments[name])
            return return_checker(f.__qualname__ + "'s return value",
                                  f(*bound.args, **bound.kwargs))
        return wrapped
    return enabled_decorator
if __name__ == '__main__':
    # Minimal self-test. Each case records whether the call raised TypeError
    # exactly when it was expected to.

    @checkfunc()
    def test1(pid: int, x: str = 'foo', *, portal: str) -> None:
        print('pid:', pid, 'x:', x, 'portal:', portal)

    @checkfunc(True)
    def test2(pid: int, x: str = 'foo', *, portal: str) -> None:
        # Deliberately violates its `-> None` annotation.
        print('pid:', pid, 'x:', x, 'portal:', portal)
        return pid

    @checkfunc()
    def test3(pid: int, x: str = 'foo', *, portal: str):
        # No return annotation, so the return value is not checked.
        print('pid:', pid, 'x:', x, 'portal:', portal)
        return pid

    @checkfunc(False)
    def test4(pid, portal) -> ParamConverter:
        # Checking is disabled, so the mismatched return annotation is ignored.
        pass

    def _expect_throw(shouldthrow, callthis):
        """Return True iff calling *callthis* raises TypeError exactly when *shouldthrow* is set."""
        try:
            callthis()
        except TypeError:
            return shouldthrow
        return not shouldthrow

    results = [
        _expect_throw(False, lambda: test1(42, portal='99')),
        _expect_throw(False, lambda: test1(42, '24', portal='99')),
        _expect_throw(True, lambda: test1('42', portal='99')),          # pid is not an int
        _expect_throw(True, lambda: test1('42', 'nope')),               # keyword-only portal missing
        _expect_throw(True, lambda: test1('42', 'nope', 'alsonope')),   # too many positional args
        _expect_throw(True, lambda: test2(42, portal='99')),            # returns pid, not None
        _expect_throw(False, lambda: test3(42, portal='99')),
        _expect_throw(False, lambda: test4(42, portal='99')),
    ]
    if not all(results):
        print('Some tests failed: ', results)
    else:
        print('All tests passed')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment