Last active April 15, 2020 06:20
-
-
Save Phxntxm/eeed98262ac876206fde6481c5dd5998 to your computer and use it in GitHub Desktop.
Overloaded methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import inspect
import traceback
from copy import copy
# Registry of every function decorated with @overload, keyed by the
# function's __name__. Each value lists the candidate implementations in the
# order they were decorated; dispatch tries them in that order.
# NOTE(review): keying on __name__ means same-named functions from different
# classes/modules share one entry (e.g. two overloaded ``__init__`` methods).
overloaded_functions = {}


def _annotation_matches(parameter, value):
    """Return True if *value* satisfies *parameter*'s annotation.

    The check is an exact type match (``type(x) is T``), so subclass instances
    (e.g. ``True`` for an ``int`` annotation) do not match. For ``*args`` and
    ``**kwargs`` the annotation applies to each collected element (PEP 484),
    not to the tuple/dict that holds them.
    """
    annotation = parameter.annotation
    if parameter.kind is parameter.VAR_POSITIONAL:
        return all(type(item) is annotation for item in value)
    if parameter.kind is parameter.VAR_KEYWORD:
        return all(type(item) is annotation for item in value.values())
    return type(value) is annotation


def _accepts(func, args, kwargs):
    """Return True if *func* can be called with *args*/*kwargs* and every
    annotated parameter receives a value matching its annotation."""
    signature = inspect.signature(func)
    try:
        bound = signature.bind(*args, **kwargs)
    except TypeError:
        # The arguments do not fit this signature at all (arity / names).
        return False
    # Fill in defaults so we check the arguments the call will really see;
    # after this every parameter has an entry in bound.arguments.
    bound.apply_defaults()
    for param_name, parameter in signature.parameters.items():
        # Unannotated parameters accept anything.
        if parameter.annotation is parameter.empty:
            continue
        value = bound.arguments[param_name]
        # A value equal to the default is accepted regardless of its type
        # (think of a str/dict parameter defaulting to None).
        if parameter.default is not parameter.empty and value == parameter.default:
            continue
        if not _annotation_matches(parameter, value):
            return False
    return True


def overload(func):
    """Register *func* as one implementation of an overloaded function.

    Every function decorated with ``@overload`` under the same ``__name__`` is
    appended to ``overloaded_functions``. Calling the returned wrapper tries
    those implementations in registration order and calls the first one that
    accepts the arguments. An implementation accepts them when they bind to
    its signature and every annotated parameter receives a value whose exact
    type is the annotation, or a value equal to the parameter's default.

    NOTE(review): annotations are compared as objects, so string annotations
    (e.g. under ``from __future__ import annotations``) never match.

    Args:
        func: The implementation to register.

    Returns:
        A dispatching wrapper carrying *func*'s metadata (name, docstring).

    Raises:
        TypeError: At call time, if no registered implementation accepts the
            given arguments. Exceptions raised inside the chosen
            implementation propagate unchanged.
    """
    name = func.__name__
    overloaded_functions.setdefault(name, []).append(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Look up by the closure's name rather than wrapper.__name__, so
        # dispatch does not depend on mutable function metadata.
        for candidate in overloaded_functions[name]:
            if _accepts(candidate, args, kwargs):
                return candidate(*args, **kwargs)
        raise TypeError(
            "Could not find an overloaded function matching the arguments "
            f"given: {args}, {kwargs}"
        )

    return wrapper
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment