Last active April 6, 2020 01:32
-
-
Save vxgmichel/590f3e0dbc1a3a841251881686313346 to your computer and use it in GitHub Desktop.
Stackless recursion using generators
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A decorator to turn a generator into a cached, stackless recursive function.

The yield keyword is used to send the arguments to pass to the recursive
function (as a tuple of positional arguments) and to retrieve its return
value, e.g.:

    @stackless
    def fibonacci(n):
        if n < 2:
            return n
        return (yield (n - 1,)) + (yield (n - 2,))

By default, results are memoized for the duration of a single top-level
call; pass ``cached=False`` to disable this.
"""
def stackless(func=None, *, cached=True):
    """Turn a generator function into a stackless recursive function.

    Usable bare (``@stackless``) or with options (``@stackless(cached=False)``).

    Inside the decorated generator, ``value = yield args_tuple`` performs a
    "recursive call": ``args_tuple`` is unpacked as positional arguments to
    the same function, and the ``yield`` expression evaluates to that call's
    return value. Pending calls are kept on an explicit list instead of the
    interpreter call stack, so recursion depth is not limited by
    ``sys.getrecursionlimit()``.

    If a recursive call raises, the exception is thrown into the calling
    generator at its ``yield``. A ``try``/``except``/``finally`` around a
    ``yield`` therefore behaves as it would around a real recursive call.
    The exception only escapes the wrapper if no frame handles it.

    Args:
        func: The generator function to wrap, or None when the decorator is
            invoked with keyword options only.
        cached: If true, return values are memoized by argument tuple. The
            memo lives for one top-level call and is not shared between
            top-level calls. Argument tuples must then be hashable.

    Returns:
        The wrapped function, or a decorator when ``func`` is None. The
        wrapped function accepts positional arguments only.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args):
            cache = {}
            result = None  # value sent into the frame on top of the stack
            error = None  # exception to throw into the frame on top instead
            stack = [(args, func(*args))]
            while stack:
                frame_args, gen = stack[-1]
                try:
                    if error is not None:
                        pending, error = error, None
                        call_args = gen.throw(pending)
                    else:
                        call_args = gen.send(result)
                except StopIteration as stop:
                    # The frame returned: pop it and hand its value to the caller.
                    stack.pop()
                    result = stop.value
                    if cached:
                        cache[frame_args] = result
                    continue
                except Exception as raised:
                    # The frame raised: unwind it and re-raise the exception
                    # inside the caller's pending ``yield`` (or out of the
                    # wrapper if this was the outermost frame).
                    stack.pop()
                    if not stack:
                        raise
                    error = raised
                    continue
                # The frame yielded a new call request.
                if cached and call_args in cache:
                    result = cache[call_args]
                    continue
                result = None  # a fresh generator must first receive None
                stack.append((call_args, func(*call_args)))
            return result

        return wrapper

    return decorator if func is None else decorator(func)
# Testing | |
import pytest | |
@stackless
def stackless_fib(n):
    """Return the n-th Fibonacci number via memoized stackless recursion."""
    if n < 2:
        return n
    first = yield (n - 1,)
    second = yield (n - 2,)
    return first + second
@stackless(cached=False)
def slow_stackless_fib(n):
    """Return the n-th Fibonacci number via stackless recursion, no memo.

    With caching disabled, every sub-call is recomputed from scratch.
    """
    if n < 2:
        return n
    minus_one = yield (n - 1,)
    minus_two = yield (n - 2,)
    return minus_one + minus_two
def recursive_fib(n):
    """Return the n-th Fibonacci number using plain, unmemoized recursion."""
    return n if n < 2 else recursive_fib(n - 1) + recursive_fib(n - 2)
def iterative_fib(n):
    """Return the n-th Fibonacci number by stepping the recurrence n times.

    A non-positive ``n`` performs no steps and returns 0.
    """
    lo, hi = 0, 1
    for _step in range(n):
        lo, hi = hi, lo + hi
    return lo
@stackless
def stackless_sum(n):
    """Return n + (n - 1) + ... + 1, recursing through ``stackless``."""
    if n == 0:
        return 0
    rest = yield (n - 1,)
    return n + rest
def recursive_sum(n):
    """Return n + (n - 1) + ... + 1 using one Python call frame per step."""
    return 0 if n == 0 else n + recursive_sum(n - 1)
@pytest.mark.parametrize(
    "fib",
    [recursive_fib, stackless_fib, slow_stackless_fib],
)
def test_fib(fib):
    """Each Fibonacci implementation must match the iterative reference."""
    n = 28
    actual = fib(n)
    expected = iterative_fib(n)
    assert actual == expected
def test_sum():
    """The stackless sum handles a depth that overflows plain recursion."""
    depth = 10 ** 5
    actual = stackless_sum(depth)
    assert actual == depth * (depth + 1) // 2
    with pytest.raises(RecursionError):
        recursive_sum(depth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.