Last active
October 4, 2021 06:50
-
-
Save ItsDrike/afc168f5fa592cb50c82a8ade7218c7f to your computer and use it in GitHub Desktop.
Python Function Overloads
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import inspect
import types
from typing import (
    Any,
    Callable,
    Hashable,
    Optional,
    TypeVar,
    Union,
    get_args,
    get_origin,
)
# Private sentinel meaning "nothing stored here". ``None`` cannot play this
# role because ``None`` may be a perfectly valid stored value, whereas a
# fresh ``object()`` is identical only to itself.
_MISSING = object()

# Callable-bound type variable used to annotate the ``overload`` decorator,
# so static type checkers see it return the very function it was given.
T = TypeVar("T", bound=Callable)
class NoMatchingOverload(TypeError):
    """
    Signals that a call to an overloaded method was accepted by none of
    its overloads.

    Derives from ``TypeError``, the same base Python uses when a call's
    arguments do not fit a function's signature.
    """
class MultipleMatchingOverloads(TypeError):
    """
    Signals that a call to an overloaded method was accepted by more than
    one overload, so there is no unambiguous overload to run.

    Derives from ``TypeError``, like ``NoMatchingOverload``.
    """
class OverloadList(list):
    """
    List subclass used only to hold the overloads collected for one name.

    Having a dedicated type lets the rest of the machinery tell an
    accumulated set of overloads apart from an ordinary ``list`` value that
    a class body might legitimately assign.
    """
class OverloadDict(dict):
    """
    Class-body namespace that can hold several overloads under one name.

    Assigning a value whose ``__overload__`` attribute is truthy collects it
    into an ``OverloadList`` instead of replacing whatever was stored before,
    so a class body may define the same method name several times with
    ``@overload``. Mixing overloaded and non-overloaded values under the same
    key raises ``ValueError``.
    """

    def __setitem__(self, key: Hashable, value: Any) -> None:
        """
        Store ``value`` under ``key``, accumulating overloads.

        - No previous value: store ``value``, wrapped in a new
          ``OverloadList`` if it is marked as an overload.
        - Previous value is an ``OverloadList``: append ``value`` if it is an
          overload, otherwise raise.
        - Previous value is a plain value: overwrite it (normal ``dict``
          behaviour) unless ``value`` is an overload, which raises.

        :raises ValueError: if overloaded and non-overloaded values are mixed
            under the same key.
        """
        previous_value = self.get(key, _MISSING)
        is_overloaded = getattr(value, "__overload__", False)

        if previous_value is _MISSING:
            insert_value = OverloadList([value]) if is_overloaded else value
            super().__setitem__(key, insert_value)
        elif isinstance(previous_value, OverloadList):
            if not is_overloaded:
                raise ValueError(
                    "Can't override existing overloaded value with "
                    "non-overloaded value (forgot @overload?)"
                )
            previous_value.append(value)
        else:
            if is_overloaded:
                raise ValueError(
                    "Can't set overloaded value for a key which already "
                    "contains non-overloaded value (forgot @overload?)"
                )
            super().__setitem__(key, value)
class BoundOverloadDispatcher:
    """
    Callable returned when an overloaded method is accessed on an instance.

    Calling it selects the single overload whose signature accepts the call
    arguments (by arity and by type hints) and invokes it with the captured
    instance as ``self``. If no overload defined in ``owner_cls`` matches,
    the call is forwarded to the next implementation in the MRO, looked up
    through ``super(owner_cls, instance)``.
    """

    def __init__(
        self,
        instance: object,
        owner_cls: type[object],
        name: str,
        overload_list: OverloadList,
    ):
        """
        :param instance: object the method was accessed on; passed as ``self``.
        :param owner_cls: class whose body defined these overloads; the
            ``super()`` fallback starts searching after it in the MRO.
        :param name: attribute name of the overloaded method.
        :param overload_list: every overload defined under ``name``.
        """
        self.instance = instance
        self.owner_cls = owner_cls
        self.name = name
        self.overload_list = overload_list
        # One signature per overload, in the same order as overload_list.
        self.signatures = [inspect.signature(f) for f in overload_list]

    def __call__(self, *args, **kwargs):
        """
        Call the overload that matches ``args`` / ``kwargs``.

        If no overload defined in the owner class matches, the attribute of
        the same name on the ``super()`` proxy is called instead (it may be a
        plain method or another overloaded method).

        :raises MultipleMatchingOverloads: if more than one overload matches.
        :raises NoMatchingOverload: if no overload matches and the ``super()``
            proxy has no attribute named ``self.name``.
        """
        try:
            f = self.best_match(*args, **kwargs)
        except NoMatchingOverload:
            pass
        else:
            # Invoked outside the ``try`` so that a NoMatchingOverload raised
            # by the overload's own body is not mistaken for a failed lookup.
            return f(self.instance, *args, **kwargs)

        # Nothing in the owner class matched: defer to the next class in the MRO.
        super_instance = super(self.owner_cls, self.instance)
        super_call = getattr(super_instance, self.name, _MISSING)
        if super_call is _MISSING:
            raise NoMatchingOverload(
                f"No overload of {self.owner_cls.__qualname__}.{self.name}() "
                "matches the given arguments"
            )
        return super_call(*args, **kwargs)  # type: ignore

    def best_match(self, *args, **kwargs):
        """
        Return the single overload whose signature accepts the arguments.

        An overload matches when the arguments, with ``self.instance``
        prepended for ``self``, can be bound to its signature and every bound
        value (defaults included) satisfies its parameter's type hint.

        :raises NoMatchingOverload: if no overload matches.
        :raises MultipleMatchingOverloads: if more than one overload matches.
        """
        matching_functions = []
        for f, sig in zip(self.overload_list, self.signatures):
            try:
                bound_args = sig.bind(self.instance, *args, **kwargs)
            except TypeError:
                continue  # missing/extra/unexpected args or kwargs
            bound_args.apply_defaults()
            if self._signature_matches(sig, bound_args):
                matching_functions.append(f)

        qualified_name = f"{self.owner_cls.__qualname__}.{self.name}()"
        if not matching_functions:
            raise NoMatchingOverload(
                f"No overload of {qualified_name} matches the given arguments"
            )
        if len(matching_functions) > 1:
            raise MultipleMatchingOverloads(
                f"{len(matching_functions)} overloads of {qualified_name} "
                "match the given arguments; the call is ambiguous"
            )
        return matching_functions[0]

    @staticmethod
    def _type_hint_matches(obj, hint):
        """
        Check whether ``obj`` satisfies the type hint ``hint``.

        Supported hints:
        - no annotation and ``typing.Any`` always match;
        - ``None`` matches only ``None``;
        - ``Union[...]`` / ``Optional[...]`` and ``X | Y`` match when any
          member matches;
        - anything else is passed to ``isinstance``, so it must be a class
          (or tuple of classes). Subscripted generics such as ``list[int]``
          and string annotations are not supported; ``isinstance`` raises
          ``TypeError`` for them.
        """
        if hint is inspect.Parameter.empty or hint is Any:
            return True
        if hint is None:
            return obj is None
        origin = get_origin(hint)
        # typing.Union covers Union[...]/Optional[...]; types.UnionType
        # (Python 3.10+) covers the ``X | Y`` syntax.
        if origin is Union or origin is getattr(types, "UnionType", Union):
            return any(
                BoundOverloadDispatcher._type_hint_matches(obj, member)
                for member in get_args(hint)
            )
        return isinstance(obj, hint)

    @classmethod
    def _signature_matches(
        cls, sig: inspect.Signature, bound_args: inspect.BoundArguments
    ) -> bool:
        """
        Check that every bound argument satisfies its parameter's type hint.

        For ``*args`` and ``**kwargs`` parameters the annotation describes
        each individual element, so each collected value is checked on its
        own rather than the collecting tuple/dict as a whole.
        """
        for name, value in bound_args.arguments.items():
            param = sig.parameters[name]
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                values = value  # tuple of the extra positional arguments
            elif param.kind is inspect.Parameter.VAR_KEYWORD:
                values = value.values()  # dict of the extra keyword arguments
            else:
                values = (value,)
            if not all(cls._type_hint_matches(v, param.annotation) for v in values):
                return False
        return True
class OverloadDescriptor:
    """
    Descriptor that replaces an overloaded method in its class namespace.

    ``OverloadMeta`` creates one per overloaded name, passing the
    ``OverloadList`` of every overload defined under that name. Accessing
    the attribute on an instance yields a ``BoundOverloadDispatcher`` that
    chooses the overload at call time; accessing it on the class yields the
    descriptor itself.

    A descriptor is used so that the accessed instance is captured: the
    dispatcher passes it as ``self`` and, when no overload matches, uses it
    with ``super()`` to look for a next-in-line implementation.
    """

    def __init__(self, overload_list: OverloadList) -> None:
        """
        Store the overloads for this name.

        :raises TypeError: if ``overload_list`` is not an ``OverloadList``.
        :raises ValueError: if ``overload_list`` is empty.
        """
        if not isinstance(overload_list, OverloadList):
            raise TypeError("Must use OverloadList.")
        if not overload_list:
            raise ValueError("The overload_list can't be empty")
        self.overload_list = overload_list

    def __set_name__(self, owner: type[object], name: str) -> None:
        """
        Record the defining class and the attribute name.

        Python calls this automatically when the class that contains the
        descriptor is created.
        """
        self.owner = owner
        self.name = name

    def __repr__(self) -> str:
        """Show the descriptor together with the overloads it holds."""
        return f"{self.__class__.__qualname__}({self.overload_list!r})"

    def __get__(self, instance: object, owner: Optional[type[object]] = None):
        """
        Return the descriptor itself (class access) or a dispatcher (instance access).

        This mirrors plain functions: ``Class.method`` gives back the
        underlying object, while ``instance.method`` gives a callable that
        supplies ``self`` automatically.
        """
        if instance is None:
            return self
        # ``self.owner`` (the class that defined the overloads) is passed
        # rather than the ``owner`` argument, which may be a subclass; the
        # dispatcher's super() fallback must start from the defining class.
        # TODO: Consider caching dispatchers keyed on these arguments to
        # avoid building a new one (and re-inspecting every signature) on
        # each attribute access.
        return BoundOverloadDispatcher(
            instance, self.owner, self.name, self.overload_list
        )
class OverloadMeta(type):
    """A metaclass that allows a class to define overloads for its methods."""

    @classmethod
    def __prepare__(
        cls, name: str, bases: tuple[type, ...], **kwargs: Any
    ) -> OverloadDict:
        """
        Return the mapping in which the class body will be executed.

        An ``OverloadDict`` is returned instead of a plain dict so the class
        body may define the same name several times with ``@overload``.

        Class keyword arguments (``class C(Base, flag=True)``) are passed to
        ``__prepare__`` as well as to ``__new__``. They are accepted and
        ignored here, as ``type.__prepare__`` does, so that ``__new__`` can
        forward them to ``type.__new__``.
        """
        return OverloadDict()

    def __new__(mcls, name: str, bases, namespace: OverloadDict, **kwargs):  # type: ignore
        """
        Create the class, replacing each collected ``OverloadList``.

        Every ``OverloadList`` in ``namespace`` becomes an
        ``OverloadDescriptor``; all other values are copied unchanged. The
        resulting plain dict and any class keyword arguments are passed on
        to ``type.__new__``.
        """
        overload_namespace = {
            key: OverloadDescriptor(val) if isinstance(val, OverloadList) else val
            for key, val in namespace.items()
        }
        return super().__new__(mcls, name, bases, overload_namespace, **kwargs)
def overload(f: T) -> T:
    """
    Mark ``f`` as one overload of a method.

    Inside the body of a class whose metaclass is ``OverloadMeta``, several
    methods decorated with ``@overload`` may share a name as long as they
    differ in their number of parameters or in their type hints. When such
    a method is called, the overload whose signature accepts the arguments
    is selected.

    The decorator itself only sets ``f.__overload__ = True`` and returns
    ``f`` unchanged; collecting and dispatching is done by the metaclass.
    """
    setattr(f, "__overload__", True)
    return f
class Overloadable(metaclass=OverloadMeta):
    """
    Convenience base class: subclassing it gives a class the
    ``OverloadMeta`` metaclass, so its body can use ``@overload`` without
    spelling out ``metaclass=OverloadMeta``.
    """
class Example(Overloadable):
    # Ordinary method: stored in the namespace like any normal function.
    def foobar(self, x: str):
        print(f"Called Example.foobar regular method: {x=!r}")

    # Three overloads of ``bar``; the one that runs depends on how many
    # arguments are passed and on their types.
    @overload
    def bar(self, x: int):  # type: ignore # noqa: F811
        print(f"Called Example.bar int overload: {x=!r}")

    @overload
    def bar(self, x: str):  # type: ignore # noqa: F811
        print(f"Called Example.bar str overload: {x=!r}")

    @overload
    def bar(self, x: int, y: int):  # type: ignore # noqa: F811
        print(f"Called Example.bar two argument overload: {x=!r} {y=!r}")
if __name__ == "__main__":
    # Exercise each overload of ``bar`` in turn, then the regular method.
    example = Example()
    for call_args in ((1,), ("hi",), (1, 8)):
        example.bar(*call_args)  # type: ignore
    example.foobar("hello")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment