Last active
December 17, 2021 10:51
-
-
Save MischaPanch/30b25d82093cdef6577146af75badcff to your computer and use it in GitHub Desktop.
Allows operators such as +, *, @ and others to be overloaded for functions in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging
import operator as o
from typing import Any, Callable, Optional, Type, Union
# Logger named after this module (instead of the root logger) so that its records can be filtered,
# raised or silenced without affecting the logging configuration of the whole application.
log = logging.getLogger(__name__)
class _FunctionWrapper:
    """
    Base class for callables that wrap another callable and expose a human-readable ``__name__``.

    Calling an instance forwards all positional and keyword arguments to the wrapped callable and returns
    its result. Both ``repr()`` and ``str()`` return ``__name__``, which subclasses use to build readable
    names for composed expressions.
    """

    def __init__(self, function: Callable, name: Optional[str] = None):
        """
        :param function: the callable to wrap
        :param name: the name to expose as ``__name__``. If omitted, the wrapped callable's own ``__name__`` is
            used. Callables that have no ``__name__`` (e.g. instances of classes defining ``__call__``, or
            ``functools.partial`` objects) fall back to the name of their type instead of raising AttributeError.
        """
        self.function = function
        if name is None:
            name = getattr(function, "__name__", type(function).__name__)
        self.__name__ = name

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)

    def __repr__(self):
        return self.__name__

    def __str__(self):
        return self.__name__
class Numerical(_FunctionWrapper):
    """
    Using this as decorator allows standard numerical operators to be used for combining functions with other callables
    or with any non-callable objects to create new callables. The composed object will have a meaningful __name__
    and representation that are useful for inspection.

    Once a function has been wrapped by this decorator, operators can be used with all other objects. The wrapped
    function may stand on either side of a binary operator, because the reflected methods (``__radd__``,
    ``__rsub__``, ...) are implemented as well. Note that Python tries the left operand's own implementation first,
    so a left operand that implements the operator itself (a NumPy array, for instance) is not guaranteed to
    defer to these methods.

    ``f % g`` composes the remainder, and ``divmod(f, g)`` composes ``divmod`` (named ``"f divmod g"``).
    Three-argument ``pow(f, g, m)`` is not supported and raises TypeError.

    Example composition:
    >>> import numpy as np
    >>> @Numerical
    ... def f(x):
    ...     return 2 * x
    >>> def g(x):
    ...     return x**2
    >>> c = f + g
    >>> c.__name__
    'f + g'
    >>> c = c / g
    >>> c.__name__
    '(f + g) / g'
    >>> c = (c + np.array([1, 2])) * 5
    >>> c.__name__
    '((f + g) / g + [1 2]) * 5'
    >>> c(10)
    array([11., 16.])

    Example with the wrapped function as right operand, and with the remainder operator:
    >>> r = 1 + f
    >>> r.__name__
    '1 + f'
    >>> r(3)
    7
    >>> (f % 3)(4)
    2

    Example comparison:
    >>> d = f < g
    >>> d.__name__
    'f < g'
    >>> d(10)
    True
    >>> d(1)
    False
    >>> d(np.array([1, 10]))
    array([False,  True])
    """

    def __add__(self, other: Union[Callable, Any]):
        return _get_apply_binary_operator(self, other, o.add, "+", add_braces=False)

    def __radd__(self, other):
        return _get_apply_binary_operator(other, self, o.add, "+", add_braces=False)

    def __sub__(self, other):
        return _get_apply_binary_operator(self, other, o.sub, "-")

    def __rsub__(self, other):
        return _get_apply_binary_operator(other, self, o.sub, "-")

    def __mul__(self, other):
        return _get_apply_binary_operator(self, other, o.mul, "*")

    def __rmul__(self, other):
        return _get_apply_binary_operator(other, self, o.mul, "*")

    def __matmul__(self, other):
        return _get_apply_binary_operator(self, other, o.matmul, "@")

    def __rmatmul__(self, other):
        return _get_apply_binary_operator(other, self, o.matmul, "@")

    def __truediv__(self, other):
        return _get_apply_binary_operator(self, other, o.truediv, "/")

    def __rtruediv__(self, other):
        return _get_apply_binary_operator(other, self, o.truediv, "/")

    def __floordiv__(self, other):
        return _get_apply_binary_operator(self, other, o.floordiv, "//")

    def __rfloordiv__(self, other):
        return _get_apply_binary_operator(other, self, o.floordiv, "//")

    def __mod__(self, other):
        return _get_apply_binary_operator(self, other, o.mod, "%")

    def __rmod__(self, other):
        return _get_apply_binary_operator(other, self, o.mod, "%")

    def __divmod__(self, other):
        # divmod is not an infix operator; the builtin is used as the binary operation and named "f divmod g".
        return _get_apply_binary_operator(self, other, divmod, "divmod")

    def __rdivmod__(self, other):
        return _get_apply_binary_operator(other, self, divmod, "divmod")

    def __pow__(self, other, modulo=None):
        # Ternary pow() has modular-arithmetic semantics that a composition of binary operators cannot
        # reproduce. Returning NotImplemented makes Python raise TypeError instead of silently dropping `modulo`.
        if modulo is not None:
            return NotImplemented
        return _get_apply_binary_operator(self, other, o.pow, "**")

    def __rpow__(self, other, modulo=None):
        if modulo is not None:
            return NotImplemented
        return _get_apply_binary_operator(other, self, o.pow, "**")

    def __abs__(self):
        return _get_apply_unary_operator(self, o.abs, name=f"|{self.__name__}|")

    def __neg__(self):
        return _get_apply_unary_operator(self, o.neg, operator_symbol="-")

    def __pos__(self):
        return _get_apply_unary_operator(self, o.pos, operator_symbol="+")

    def __le__(self, other):
        return _get_apply_binary_operator(self, other, o.le, "<=")

    def __lt__(self, other):
        return _get_apply_binary_operator(self, other, o.lt, "<")

    def __ge__(self, other):
        return _get_apply_binary_operator(self, other, o.ge, ">=")

    def __gt__(self, other):
        return _get_apply_binary_operator(self, other, o.gt, ">")
class Boolean(_FunctionWrapper):
    """
    Using this as decorator allows standard logical operators to be used for combining functions with other callables
    or with any non-callable objects to create new callables. The composed object will have a meaningful __name__
    and representation that are useful for inspection.

    Once a function has been wrapped by this decorator, operators can be used with all other objects. The wrapped
    function may stand on either side of ``&``, ``|`` and ``^`` because the reflected methods are implemented too.
    ``~`` negates plain ``bool`` results with ``not`` (bitwise ``~True`` would be ``-2``) and applies
    ``operator.inv`` to any other result.

    Example:
    >>> @Boolean
    ... def smaller2(x):
    ...     return x < 2
    >>> def divisible_by4(x):
    ...     return x % 4 == 0
    >>> and_composite = smaller2 & divisible_by4
    >>> and_composite.__name__
    'smaller2 & divisible_by4'
    >>> and_composite(4)
    False
    >>> and_composite(-4)
    True
    >>> or_composite = divisible_by4 | smaller2
    >>> or_composite.__name__
    'divisible_by4 | smaller2'
    >>> or_composite(8)
    True
    >>> not_smaller2 = ~smaller2
    >>> not_smaller2.__name__
    '~smaller2'
    >>> not_smaller2(5)
    True
    """

    @staticmethod
    def _logical_not(value):
        """
        Logically negate a single result of the wrapped function.

        On a plain ``bool``, ``~`` is bitwise inversion of the underlying int (``~True == -2``, ``~False == -1``,
        both truthy), so plain bools are negated with ``not``. Every other value is passed to ``operator.inv``
        (for NumPy boolean arrays, ``~`` already is element-wise logical negation).
        """
        if isinstance(value, bool):
            return not value
        return o.inv(value)

    def __and__(self, other):
        return _get_apply_binary_operator(
            self, other, o.and_, "&", wrapper_class=Boolean
        )

    def __rand__(self, other):
        return _get_apply_binary_operator(
            other, self, o.and_, "&", wrapper_class=Boolean
        )

    def __or__(self, other):
        return _get_apply_binary_operator(
            self, other, o.or_, "|", wrapper_class=Boolean
        )

    def __ror__(self, other):
        return _get_apply_binary_operator(
            other, self, o.or_, "|", wrapper_class=Boolean
        )

    def __xor__(self, other):
        return _get_apply_binary_operator(
            self, other, o.xor, "^", wrapper_class=Boolean
        )

    def __rxor__(self, other):
        return _get_apply_binary_operator(
            other, self, o.xor, "^", wrapper_class=Boolean
        )

    def __invert__(self):
        return _get_apply_unary_operator(
            self, self._logical_not, operator_symbol="~", wrapper_class=Boolean
        )
def _get_name(f: Union[Callable, Any]) -> str:
    """
    Return the display name of an operand, used to build the names of composed functions.

    :param f: a callable or any other object (e.g. a number or an array)
    :return: ``str(f)`` if ``f`` is not callable. Otherwise ``f.__name__``, falling back to the name of ``f``'s
        type for callables that have no ``__name__`` (e.g. instances of classes defining ``__call__``, or
        ``functools.partial`` objects) instead of raising AttributeError.
    """
    # Same callability test as _to_callable, so both helpers agree on which operands are constants.
    if not isinstance(f, Callable):
        return str(f)
    return getattr(f, "__name__", type(f).__name__)
def _to_callable(f: Union[Callable, Any]) -> Callable:
    """
    Coerce an operand into something that can be called.

    Callables are handed back unchanged. Any other value is wrapped in a function that accepts and ignores
    arbitrary arguments and always returns that value, so constants can be composed just like functions.
    """
    if isinstance(f, Callable):
        return f

    def _constant(*_ignored_args, **_ignored_kwargs):
        return f

    return _constant
def _get_apply_binary_operator(
    f1: Union[Callable, Any],
    f2: Union[Callable, Any],
    operator: Callable[[Any, Any], Any],
    operator_symbol: str,
    add_braces=True,
    wrapper_class: Type[_FunctionWrapper] = Numerical,
):
    """
    Build a ``wrapper_class`` instance computing ``x -> operator(f1(x), f2(x))``.

    An operand that is not callable is treated as a constant: its value is used in place of a call result.
    On every call the left operand is evaluated before the right one, and both receive the same arguments.

    :param f1: left operand, a callable or a constant
    :param f2: right operand, a callable or a constant
    :param operator: binary function combining the two per-call results
    :param operator_symbol: text placed between the operand names in the composed name
    :param add_braces: whether an operand name containing a space is wrapped in parentheses in the composed name
    :param wrapper_class: class used to wrap the composed function
    :return: a ``wrapper_class`` instance named ``"<left> <operator_symbol> <right>"``
    """

    def _display(operand) -> str:
        label = _get_name(operand)
        return f"({label})" if add_braces and " " in label else label

    left_label = _display(f1)
    right_label = _display(f2)
    left_fn = _to_callable(f1)
    right_fn = _to_callable(f2)

    def composed_function(*args, **kwargs):
        left_value = left_fn(*args, **kwargs)
        right_value = right_fn(*args, **kwargs)
        return operator(left_value, right_value)

    return wrapper_class(
        composed_function, name=f"{left_label} {operator_symbol} {right_label}"
    )
def _get_apply_unary_operator(
    f: Union[Callable, Any],
    operator: Callable[[Any], Any],
    name: Optional[str] = None,
    operator_symbol: Optional[str] = None,
    add_braces=True,
    wrapper_class: Type[_FunctionWrapper] = Numerical,
):
    """
    Returns an instance of the wrapper_class performing x -> operator(f(x)), where if f is not
    a callable, the value of f is taken instead of f(x).

    :param f: the operand, a callable or a constant
    :param operator: unary function applied to the per-call result of f
    :param name: explicit name for the result; if given, operator_symbol and add_braces are ignored
    :param operator_symbol: prefix placed before f's name (e.g. "-" or "~"); required when name is None
    :param add_braces: whether to wrap f's name in parentheses when it contains a space
    :param wrapper_class: class used to wrap the resulting function
    :return: a wrapper_class instance named either ``name`` or ``operator_symbol + <f's name>``
    :raises ValueError: if neither name nor operator_symbol is given
    """
    # Coerce constants the same way the binary helper does, so the docstring's contract actually holds.
    operand = _to_callable(f)

    def operator_applied(*args, **kwargs):
        return operator(operand(*args, **kwargs))

    if name is not None:
        if add_braces:
            log.debug(
                "Ignoring add_braces=True b/c explicit name was provided: %s", name
            )
        # The explicit name wins: return here rather than falling through to the symbol-based naming
        # (which would raise ValueError when operator_symbol is None, as it is for abs()).
        return wrapper_class(operator_applied, name=name)
    if operator_symbol is None:
        raise ValueError(
            "operator_symbol cannot be None when name is not provided explicitly"
        )
    operand_name = _get_name(f)
    if add_braces and " " in operand_name:
        operand_name = f"({operand_name})"
    return wrapper_class(operator_applied, name=operator_symbol + operand_name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment