Last active
August 17, 2022 20:34
-
-
Save erewok/9bae82e851e37222645fc8b5b4c32470 to your computer and use it in GitHub Desktop.
Python function composition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""compose.py - Function composition in Python | |
""" | |
# Python module for function composition | |
from functools import reduce, wraps
from itertools import chain
from typing import Callable, Generic, TypeVar
def _compose(second_function, first_function):
    """
    Return a function that runs `first_function`, then `second_function`.

    `first_function` receives all positional and keyword arguments, and
    `second_function` is called with its single return value:

        _compose(f, g)(*args, **kwargs) == f(g(*args, **kwargs))

    Metadata (``__module__``, ``__doc__``, ``__dict__``, ``__wrapped__``) is
    copied from `first_function`, because that is the function that actually
    accepts the caller's arguments. ``inspect.signature`` follows
    ``__wrapped__``, so it reports `first_function`'s parameters. The name is
    set to "second . first", mirroring the naming used by `compose`.
    """
    @wraps(first_function)
    def inner(*args, **kwargs):
        return second_function(first_function(*args, **kwargs))

    # Not every callable has __name__ (e.g. functools.partial); fall back to repr.
    second_name = getattr(second_function, "__name__", repr(second_function))
    first_name = getattr(first_function, "__name__", repr(first_function))
    inner.__name__ = inner.__qualname__ = f"{second_name} . {first_name}"
    return inner
def flip(fn1, fn2):
    """
    Compose two functions in pipeline (left-to-right) order.

    ``flip(fn1, fn2)`` runs `fn1` first and passes its result to `fn2`; it is
    ``_compose`` with its arguments swapped, which makes it a suitable folding
    function for ``reduce`` over functions listed innermost-first.
    """
    return _compose(second_function=fn2, first_function=fn1)
C = TypeVar("C") | |
class compose(Generic[C]): | |
""" | |
compose(function1, function2, *other_funcs) -> new-function | |
Create a new function from the composition of two functions. | |
`function1` runs "after" `function2` and on the result of `function2`. | |
Example: | |
def to_upper(s): return s.upper() | |
def join_dots(s): return ".".join(s) | |
assert compose(join_dots, to_upper)("hello") == "H.E.L.L.O" | |
Any other functions supplied will be added on at the *beginning* of the | |
resulting computation, for example: | |
compose(f, g, h) -> f(g(h())) | |
""" | |
__slots__ = "composition", "__name__" | |
def __init__( | |
self, func_after: C | Callable, func_before: C | Callable, *other_funcs | |
): | |
self.__name__ = f"{func_after.__name__} . {func_before.__name__}" | |
if not other_funcs: | |
self.composition = _compose(func_after, func_before) | |
else: | |
func_chain_reversed = iter(chain( | |
other_funcs[::-1], | |
(func_before, func_after) | |
)) | |
self.composition = reduce(flip, func_chain_reversed) | |
def __call__(self, *args, **kwargs): | |
"""Invoke this function composition""" | |
return self.composition(*args, **kwargs) | |
def run_after(self, func: C | Callable): | |
"""Run `func` and then run this function composition on its result""" | |
self.composition = _compose(self.composition, func) | |
return self.composition | |
def run_before(self, func: C | Callable): | |
"""Run this function composition and then run `func` on its result""" | |
self.composition = _compose(func, self.composition) | |
return self.composition | |
def __and__(self, func: C): | |
""" | |
Overloads the `&` operator for `run_after` function composition | |
Results in a TypeError if called with something that's not | |
an instance of `compose`. | |
""" | |
return self.run_after(func) | |
def __repr__(self): | |
return f"<λ({self.__name__}>)" | |
# # # # # TESTS # # # # # | |
def test_compose_func():
    """`_compose(f, g)` behaves like ``f(g(x))``, also when wrapping a `compose`."""
    def add1(n):
        return n + 1

    def mul3(n):
        return n * 3

    def div10(n):
        return n // 10

    mul3_of_add1 = _compose(mul3, add1)
    assert mul3_of_add1(8) == 27  # (8 + 1) * 3

    inner_pipeline = compose(mul3, add1)
    assert _compose(div10, inner_pipeline)(8) == 2  # 27 // 10
def to_title(s):
    """Return ``s.title()``: each word's first letter upper-cased."""
    titled = s.title()
    return titled


def split_spaces(s):
    """Split `s` on single spaces; repeated or trailing spaces yield "" items."""
    delimiter = " "
    return s.split(delimiter)


def join_dots(strlist):
    """Join the items of `strlist` with "." between them."""
    separator = "."
    return separator.join(strlist)


def replace1s(s):
    """Replace every "1" in `s` with "ø"."""
    old, new = "1", "ø"
    return s.replace(old, new)


def replace2s(s):
    """Replace every "2" in `s` with "ü"."""
    old, new = "2", "ü"
    return s.replace(old, new)
def debug(val):
    """Print `val` and pass it through unchanged, so it can sit in a pipeline."""
    passthrough = val
    print(passthrough)
    return passthrough
def test_compose_class():
    """Exercise `compose` with run_before/run_after, 3-way composition and `&`."""
    lyrics = [
        "hello darkness my old ",
        "friend I've come to talk ",
        "with you again ",
        "because a vision softly creeping ",
    ]

    title_then_dots = compose(split_spaces, to_title).run_before(join_dots)
    titled_result = title_then_dots(lyrics[0])
    assert titled_result.startswith("Hello.Darkness.My.Old.")

    dots_after_title = compose(join_dots, split_spaces).run_after(to_title)
    other_titled_result = dots_after_title(lyrics[0])
    assert other_titled_result == titled_result

    # Three-function compositions: compose(f, g, h)(x) == f(g(h(x)))
    replace_nums = compose(replace1s, replace2s, join_dots)
    power_of_2s = compose(str, debug, lambda n: 2**n)
    # NB: run_after and `&` both mutate the instance they are called on.
    pipeline = replace_nums & power_of_2s.run_after(debug)
    operators_result1 = tuple(pipeline(n) for n in range(1, 10))
    assert operators_result1 == (
        'ü', '4', '8', 'ø.6', '3.ü', '6.4', 'ø.ü.8', 'ü.5.6', '5.ø.ü'
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment