Last active
December 19, 2018 22:42
-
-
Save supposedly/467e737a16bb96a09940d084a8ac2102 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
from functools import wraps | |
from itertools import islice, starmap | |
def _callable(obj): | |
return callable(obj) and obj is not inspect._empty | |
def convert(hint, val):
    """
    Apply *hint* to *val* when *hint* is a usable converter.

    Otherwise (no annotation, or a non-callable one such as a string),
    *val* is returned untouched.
    """
    if not _callable(hint):
        return val
    return hint(val)
def typecast(func):
    """
    Wrap func so that arguments passed to it are converted according
    to its type hints.

    More specifically, each callable parameter annotation is called on
    the corresponding argument; non-callable annotations (e.g. strings)
    and missing annotations leave the argument untouched. If a callable
    annotates a variadic parameter (*args or **kwargs), it is called on
    each value therein. Default values are never converted -- only
    arguments the caller actually passed are.

    If func has a callable return annotation, it is also applied to
    func's return value.

    The call is checked against func's signature before any conversion
    runs. A malformed call raises TypeError -- too many or missing
    arguments, unexpected keywords, or positional-only parameters passed
    by name -- just as calling func directly would, and func is not
    executed.
    """
    sig = inspect.signature(func)
    params = sig.parameters

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Signature.bind performs all of Python's own argument-matching
        # checks and raises TypeError on a bad call, without invoking func.
        bound = sig.bind(*args, **kwargs)
        # Iterate over a snapshot so values can be replaced while looping.
        for name, value in list(bound.arguments.items()):
            param = params[name]
            # param.annotation is inspect._empty when unannotated; convert()
            # treats that (and any non-callable annotation) as a no-op.
            hint = param.annotation
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                value = tuple(convert(hint, item) for item in value)
            elif param.kind == inspect.Parameter.VAR_KEYWORD:
                value = {key: convert(hint, item) for key, item in value.items()}
            else:
                value = convert(hint, value)
            bound.arguments[name] = value
        result = func(*bound.args, **bound.kwargs)
        # sig.return_annotation is inspect._empty when absent -> no-op.
        return convert(sig.return_annotation, result)

    return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@typecast
def example_positional(thing: str, other: 'not callable', *, one, two: int):
    """
    Show conversion of positional and keyword-only arguments.

    ``thing`` goes through ``str`` and ``two`` through ``int``. ``other``
    (string annotation) and ``one`` (no annotation) are expected to pass
    through unchanged.

    Renamed from ``test``: later definitions in this file reuse that
    name, which shadowed this function and hid its doctest from doctest
    discovery.

    >>> example_positional(1, [], one='1', two='2')
    ('1', <class 'str'>, [], <class 'list'>, '1', <class 'str'>, 2, <class 'int'>)
    """
    return thing, type(thing), other, type(other), one, type(one), two, type(two)
#------------------------------------------------------------# | |
@typecast
def example_variadic(one: float, *args: str, two: str = None, **kwargs: 'not callable') -> list:
    """
    Show conversion of *args, an unpassed default, and the return value.

    Expected behavior pinned by the example below:
    - each ``*args`` value goes through ``str``;
    - ``two`` keeps its default ``None`` unconverted;
    - ``**kwargs`` values are untouched (non-callable annotation);
    - the returned tuple is turned into a list by the ``-> list``
      annotation.

    Renamed from ``test``: later definitions in this file reuse that
    name, which shadowed this function and hid its doctest from doctest
    discovery.

    >>> example_variadic(1, 2, 3, three='5')
    [1.0, <class 'float'>, ('2', '3'), [<class 'str'>, <class 'str'>], None, <class 'NoneType'>, {'three': ('5', <class 'str'>)}]
    """
    extras = {key: (value, type(value)) for key, value in kwargs.items()}
    return one, type(one), args, [type(arg) for arg in args], two, type(two), extras
#------------------------------------------------------------# | |
@typecast
def example_unannotated(one, *args, two, **kwargs) -> str:
    """
    Show that unannotated parameters, including *args and **kwargs, are
    expected to pass through unchanged.

    The example output also relies on the ``-> str`` return annotation
    being applied, which turns the returned tuple into its string form.

    Renamed from ``test``: a later definition in this file reuses that
    name, which shadowed this function and hid its doctest from doctest
    discovery.

    >>> example_unannotated(1, 2, two=3, three='5')
    "(1, <class 'int'>, (2,), [<class 'int'>], 3, <class 'int'>, {'three': ('5', <class 'str'>)})"
    """
    extras = {key: (value, type(value)) for key, value in kwargs.items()}
    return one, type(one), args, [type(arg) for arg in args], two, type(two), extras
#------------------------------------------------------------# | |
@typecast
def test(one: float, *, two: float, **kwargs: float):
    """
    Every annotation here is ``float``, so ``one``, ``two`` and each extra
    keyword value are expected to arrive as floats.

    >>> test(1, two=3, three='5')
    (1.0, <class 'float'>, 3.0, <class 'float'>, {'three': (5.0, <class 'float'>)})
    """
    extras = {key: (value, type(value)) for key, value in kwargs.items()}
    return one, type(one), two, type(two), extras
#------------------------------------------------------------# | |
### DON'T USE THESE ONES IN ANY SERIOUS MANNER EVER ### | |
def TRUNC(num):
    """Return a converter that keeps only the first *num* items of a sequence."""
    def _truncate(seq):
        return seq[:num]
    return _truncate
@typecast
def truncate_args(one: TRUNC(2), two: TRUNC(3)):
    """
    Demonstrate arbitrary callables as converters: each argument is
    sliced to the length given in its ``TRUNC`` annotation.

    >>> truncate_args('abcd', 'efgh')
    ('ab', 'efg')
    """
    pair = (one, two)
    return pair
#------------------------------------------------------------# | |
def EQ_LEN(seq, *, l=[]):
    """
    Converter that checks consecutive arguments have equal length.

    State deliberately lives in the mutable default ``l``, which is
    shared by every call. The first call records ``len(seq)`` and
    returns. The second call compares both recorded lengths, then clears
    ``l`` whether or not the check passed. This assumes exactly two
    arguments are annotated with ``EQ_LEN`` per call (see the hard-coded
    ``2``).

    Returns *seq* unchanged.
    Raises ValueError if the recorded lengths differ.
    """
    l.append(len(seq))
    if len(l) < 2:  # 2 == no. args
        return seq
    try:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which silently disabled this validation.
        if any(length != l[0] for length in l):
            raise ValueError('args must be of equal length')
        return seq
    finally:
        # Reset the shared state so the next decorated call starts fresh.
        l.clear()
@typecast
def what(arg1: EQ_LEN, arg2: EQ_LEN):
    """
    Return the lengths of both arguments, which ``EQ_LEN`` requires to
    be equal.

    >>> what('abcd', 'abcd')
    (4, 4)
    >>> what('abcd', 'abc')
    Traceback (most recent call last):
        ...
    ValueError: args must be of equal length
    """
    return tuple(len(arg) for arg in (arg1, arg2))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment