Last active: October 18, 2021 13:34
Save vxgmichel/ba03a07defd7742c48221b3537e9b295 to your computer and use it in GitHub Desktop.
Write recursive functions as coroutines to make them stackless
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Write recursive functions as coroutines to make them stackless.

Example:

>>> @runable  # Allow the function to run with a simple sync call
... @functools.cache  # Cache the results
... @stackless  # Make the coroutine stackless
... async def fib(n, m=10**9):
...     if n <= 1:
...         return 1 if n == 1 else 0
...     result = await fib(n - 1, m) + await fib(n - 2, m)
...     return result % m
>>> fib(10**6)
242546875

Note that the order of the decorators is important:
- `stackless` should be applied first so that the function becomes compatible
  with `cache`: a plain coroutine object can only be awaited once, whereas the
  object returned by a `stackless` function can be awaited again and then
  returns its stored result
- `runable` should be applied last to bridge the sync and async worlds

The `stackless_cache` decorator can be used as a shortcut. Also note that this
approach adds significant overhead to the execution time (roughly 10 times
slower).
""" | |
import functools | |
from collections.abc import Awaitable | |
class CachableStacklessCoroutine(Awaitable):
    """Awaitable wrapper that lets a coroutine be cached and awaited repeatedly.

    A native coroutine object can only be awaited once. This wrapper memoizes
    the outcome of the wrapped coroutine: after it has completed, every
    further ``await`` returns the stored result (or re-raises the stored
    exception) without running the coroutine again.

    Instead of running the wrapped coroutine inline, ``__await__`` *yields*
    the coroutine's iterator to the driver (``run_stackless``), which runs it
    and sends its return value back in. This hand-off keeps the interpreter's
    call stack flat during deep recursion.
    """

    __slots__ = ("coro", "result", "error", "pending")

    # Sentinel distinguishing "no result yet" from a legitimate ``None`` result.
    unset = object()

    def __init__(self, coro):
        self.coro = coro
        self.result = self.unset
        # Exception raised by the wrapped coroutine; re-raised on later awaits
        # because the coroutine itself cannot be run a second time.
        self.error = None
        # True while the wrapped coroutine has been handed to the driver and
        # has not finished yet.
        self.pending = False

    def __await__(self):
        if self.error is not None:
            raise self.error
        if self.result is not self.unset:
            return self.result
        if self.pending:
            # Resuming the suspended coroutine here would feed it a bogus
            # ``None`` and silently corrupt the cached result. This happens on
            # a recursion cycle or after an earlier evaluation was aborted.
            raise RuntimeError(
                "stackless coroutine awaited while already in progress "
                "(recursion cycle or interrupted earlier evaluation)"
            )
        self.pending = True
        try:
            # Hand the coroutine to the driver; it sends the result back here.
            self.result = yield self.coro.__await__()
        except Exception as exc:
            # Remember the failure so later awaits re-raise it instead of
            # trying to reuse an already finished coroutine.
            self.error = exc
            raise
        finally:
            self.pending = False
        return self.result
def run_stackless(coro):
    """Drive ``coro`` to completion without growing the Python call stack.

    Whenever the running coroutine awaits a stackless sub-coroutine, the
    sub-coroutine's iterator is yielded up to this loop. The loop pushes the
    awaiting iterator onto an explicit list and runs the new one instead, so
    recursion depth is limited by memory rather than the recursion limit.

    When a sub-coroutine returns, its value is sent into the awaiting one.
    When it raises, the exception is thrown into the awaiting one. This way
    ``try``/``except``/``finally`` in callers behave as with ordinary
    recursion, and no caller is left suspended.

    Args:
        coro: An awaitable whose ``__await__`` iterator only ever yields
            further awaitable iterators to be run.

    Returns:
        The value returned by ``coro``.

    Raises:
        Any exception that propagates out of ``coro`` after passing through
        every pending caller.
    """
    stack = []
    current = coro.__await__()
    value, error = None, None
    while True:
        try:
            if error is None:
                child = current.send(value)
            else:
                child = current.throw(error)
        except StopIteration as exc:
            # ``current`` returned: resume its caller with the return value.
            if not stack:
                return exc.value
            current = stack.pop()
            value, error = exc.value, None
        except BaseException as exc:
            # ``current`` raised: re-raise inside its caller, as a regular
            # call would, or out of the driver once the stack is empty.
            if not stack:
                raise
            current = stack.pop()
            value, error = None, exc
        else:
            # ``current`` awaited ``child``: suspend it and run ``child``.
            stack.append(current)
            current = child
            value, error = None, None
def stackless(corofn):
    """Make every call of ``corofn`` return a re-awaitable stackless wrapper.

    The returned function has the same signature and metadata as ``corofn``.
    Instead of a raw coroutine, it returns a ``CachableStacklessCoroutine``
    around it, which can be stored in a cache and awaited more than once.
    """

    def wrapper(*args, **kwargs):
        coroutine = corofn(*args, **kwargs)
        return CachableStacklessCoroutine(coroutine)

    return functools.wraps(corofn)(wrapper)
def runable(corofn):
    """Let a stackless coroutine function be called like a regular function.

    Outside of any stackless run, calling the returned wrapper drives the
    coroutine to completion with ``run_stackless`` and returns its result.
    While a run is already in progress (the call comes from inside another
    runable coroutine), the wrapper returns the coroutine unchanged so the
    caller can ``await`` it on the same explicit stack.

    Note: the "in progress" flag is a single attribute on this function, so
    it is shared by every decorated function and every thread.
    """
    # Initialize the shared flag explicitly rather than as a side effect of an
    # ``assert`` raising AttributeError: asserts are stripped under
    # ``python -O``, which would leave the flag unset and make every wrapper
    # call fail with AttributeError.
    if not hasattr(runable, "running"):
        runable.running = False
    # Decorating from inside a running stackless call is not supported.
    assert not runable.running

    @functools.wraps(corofn)
    def wrapper(*args, **kwargs):
        coro = corofn(*args, **kwargs)
        if runable.running:
            # Nested call: let the enclosing run_stackless drive it.
            return coro
        try:
            runable.running = True
            return run_stackless(coro)
        finally:
            runable.running = False

    return wrapper
def stackless_cache(corofn):
    """Shortcut that makes ``corofn`` stackless, memoized and sync-callable.

    Equivalent to stacking ``@runable``, ``@functools.lru_cache(maxsize=None)``
    and ``@stackless`` (outermost first). The cache is unbounded, so every
    distinct argument tuple stays in memory.
    """
    memoized = functools.lru_cache(maxsize=None)(stackless(corofn))
    return runable(memoized)
def test_fib():
    """Check the stackless cached Fibonacci against known values (mod 10**9)."""

    @stackless_cache
    async def fib(n, m=10 ** 9):
        if n > 1:
            # Awaiting n - 1 first caches n - 2 before it is awaited below.
            left = await fib(n - 1, m)
            right = await fib(n - 2, m)
            return (left + right) % m
        return 1 if n == 1 else 0

    expected = [
        (100, 261915075),
        (1000, 849228875),
        (10 ** 4, 947366875),
        (10 ** 6, 242546875),
    ]
    for n, value in expected:
        assert fib(n) == value
def test_fib_recursive_ref():
    """Reference check using plain recursion memoized with ``functools.cache``."""

    @functools.cache
    def fib(n, m=10 ** 9):
        if n > 1:
            return (fib(n - 1, m) + fib(n - 2, m)) % m
        return 1 if n == 1 else 0

    # Warm the cache in steps of 400 so that no single call recurses much
    # deeper than ~400 frames, which avoids RecursionError.
    for x in range(0, 10 ** 6, 400):
        fib(x)

    expected = [
        (100, 261915075),
        (1000, 849228875),
        (10 ** 4, 947366875),
        (10 ** 6, 242546875),
    ]
    for n, value in expected:
        assert fib(n) == value
def test_fib_iterative_ref():
    """Reference check using the linear iterative Fibonacci recurrence."""

    def fib(n, m=10 ** 9):
        prev, cur = 0, 1
        for _ in range(n):
            prev, cur = cur, (prev + cur) % m
        return prev

    expected = [
        (100, 261915075),
        (1000, 849228875),
        (10 ** 4, 947366875),
        (10 ** 6, 242546875),
    ]
    for n, value in expected:
        assert fib(n) == value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.