Last active
July 18, 2021 22:44
-
-
Save iscgar/731dfae7a6fbc26c9375624af1e3712a to your computer and use it in GitHub Desktop.
Helper decorator to enforce strict optional args for python fire
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Originally by Tyler Rhodes at https://gist.github.com/trhodeos/5a20b438480c880f7e15f08987bd9c0f
# The original broke in later fire versions because fire unwraps decorated
# functions: it ignores the signature of the wrapper and just doesn't pass
# the needed parameters. So I reimplemented this without functools.wraps()
# and added a bit more validation.
# Unfortunately this solution adds unneeded things to the help output
# due to the added varargs and kwargs, but there's no way around this
# that I know of.
import fire | |
import inspect | |
import operator | |
def only_allow_defined_args(wrapped):
    """Decorator which only allows arguments defined to be used.

    Note, we need to specify this, as Fire allows method chaining. This means
    that extra kwargs are kept around and passed to future methods that are
    called. We don't need this, and should fail early if this happens.

    The wrapper accepts ``*args``/``**kwargs`` so that it receives everything
    the caller supplies, validates it against the signature of ``wrapped``,
    and only then forwards the call. ``functools.wraps()`` is deliberately not
    used; instead the metadata is copied by hand and an explicit
    ``__signature__`` is published: the parameters of ``wrapped`` plus a
    variadic positional and a variadic keyword parameter where ``wrapped``
    does not already declare them.

    Args:
        wrapped: Function which to decorate.

    Returns:
        Wrapped function.

    Raises:
        fire.core.FireError: Raised by the returned wrapper (not by the
            decorator itself) when it is called with keyword arguments that
            ``wrapped`` does not define (and ``wrapped`` has no ``**kwargs``),
            or with more positional values than ``wrapped`` has positional
            parameters left to fill (and ``wrapped`` has no ``*args``).
    """
    # The argspec of `wrapped` does not change between calls: compute it once.
    argspec = inspect.getfullargspec(wrapped)
    keyword_only = set(argspec.kwonlyargs)
    if argspec.defaults:
        defaulted = set(argspec.args[-len(argspec.defaults):])
    else:
        defaulted = set()

    def wrapper(*_, **kwargs):
        # Positional parameters that were not supplied by keyword, in
        # declaration order. Positional values are matched against these; a
        # parameter given by keyword does not consume a positional value.
        open_slots = [name for name in argspec.args if name not in kwargs]
        extraneous = list(_[len(open_slots):])
        unknown_args = [
            name for name in kwargs
            if name not in argspec.args and name not in keyword_only]

        if unknown_args and not argspec.varkw:
            # Suggest what could have been meant: keyword-only parameters,
            # parameters with defaults and positional parameters that got no
            # value, minus everything already supplied by keyword.
            unfilled = set(open_slots[len(_):])
            possible_kwargs = sorted(
                (keyword_only | defaulted | unfilled) - set(kwargs))
            msg = 'Unknown arguments {}'.format(unknown_args)
            if possible_kwargs:
                msg += ', expected: {}'.format(possible_kwargs)
            raise fire.core.FireError(msg)

        if extraneous and not argspec.varargs:
            raise fire.core.FireError(
                'Extraneous arguments specified: {}'.format(extraneous))

        return wrapped(*_, **kwargs)

    # Copy the identifying metadata by hand (no functools.wraps(), see the
    # docstring). Attributes missing on `wrapped` are simply skipped.
    for attr in ('__module__', '__name__', '__qualname__', '__doc__',
                 '__annotations__'):
        try:
            setattr(wrapper, attr, getattr(wrapped, attr))
        except AttributeError:
            pass
    wrapper.__dict__.update(getattr(wrapped, '__dict__', {}))

    sig = inspect.signature(wrapped)
    parameters = list(sig.parameters.values())
    kinds = [p.kind for p in parameters]
    taken_names = set(sig.parameters)

    def synthetic(name, kind):
        # Build an extra variadic parameter whose name cannot clash with a
        # parameter of `wrapped` (duplicate names make Signature() fail).
        while name in taken_names:
            name += '_'
        taken_names.add(name)
        return inspect.Parameter(name, kind)

    if inspect.Parameter.VAR_POSITIONAL not in kinds:
        # *args must sit right after the last positional parameter and before
        # any keyword-only / **kwargs parameter, otherwise Signature() rejects
        # the parameter order. Positional parameters always come first, so
        # their count is the insertion index.
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD)
        varargs_pos = sum(1 for kind in kinds if kind in positional_kinds)
        parameters.insert(
            varargs_pos, synthetic('_', inspect.Parameter.VAR_POSITIONAL))
    if inspect.Parameter.VAR_KEYWORD not in kinds:
        parameters.append(synthetic('kwargs', inspect.Parameter.VAR_KEYWORD))

    wrapper.__signature__ = inspect.Signature(
        parameters, return_annotation=sig.return_annotation)
    return wrapper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment