Last active
November 13, 2021 10:07
-
-
Save hurryabit/8d1a62c5824271e7902a6d5b093d4a2c to your computer and use it in GitHub Desktop.
Easily eliminate recursion in Python using generators
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import sys
def recurse(f):
    """
    Decorator that turns a "yield transformed" recursive function into a
    function that computes the same value as the actual recursive function but
    does not blow the stack, even for large inputs. The "yield transformed"
    version of a recursive function `f(x)` is obtained by replacing all
    recursive calls `f(x')` with `(yield x')`.

    Let us look at the factorial function as an example. The following
    recursive implementation will blow the stack for sufficiently large values
    of `n`:
    ```
    def factorial(n):
        if n <= 0:
            return 1
        else:
            return n * factorial(n - 1)
    ```
    Its "yield transformed" version annotated with `@recurse` will compute
    exactly the same function but not blow the stack:
    ```
    @recurse
    def factorial(n):
        if n <= 0:
            return 1
        else:
            return n * (yield (n - 1))
    ```
    The factorial function might be a questionable candidate to demonstrate a
    general technique for transforming recursive functions into iterative
    functions since it is almost trivial to do so by hand. However, the
    implementation of
    [Tarjan's strongly connected components algorithm](https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm)
    below should provide a better glimpse of how intricate it can be to
    eliminate recursion without a general technique to do so.

    The decorated function takes a single argument `x` and keeps the name and
    docstring of `f`. An exception raised inside a nested "call" is re-raised
    at the `yield` expression of the caller, so `try`/`except`/`finally`
    around a `yield` behaves as it would around the real recursive call; an
    exception that no level handles propagates out of the decorated function.
    """
    @functools.wraps(f)
    def recursive(x):
        # `frames` plays the role of the call stack: every entry is the
        # suspended generator of one pending "call" of `f`.
        frames = [f(x)]
        # Value to deliver to the generator on top of the stack. It is None
        # for a freshly created generator, as `send` requires.
        result = None
        # Exception raised by a callee that still has to be delivered to
        # its caller.
        pending = None
        while frames:
            top = frames[-1]
            try:
                if pending is None:
                    request = top.send(result)
                else:
                    to_raise, pending = pending, None
                    request = top.throw(to_raise)
            except StopIteration as stop:
                # The generator returned: hand its value to the caller.
                frames.pop()
                result = stop.value
            except BaseException as exc:
                # The generator failed: unwind one level and let the caller
                # (if any) see the exception at its `yield`.
                frames.pop()
                if not frames:
                    raise
                pending = exc
            else:
                # The generator asked for a recursive call on `request`.
                frames.append(f(request))
                result = None
        return result
    return recursive
def triangular_bad(n):
    """
    Return the n-th triangular number `1 + 2 + ... + n` (0 for `n <= 0`).

    The implementation is deliberately plain recursion with one Python call
    per unit of `n`; it serves as the example of a function that exceeds the
    interpreter's recursion limit for large inputs.
    """
    return 0 if n <= 0 else n + triangular_bad(n - 1)
@recurse
def triangular_good(n):
    """
    Return the n-th triangular number `1 + 2 + ... + n` (0 for `n <= 0`).

    This is the "yield transformed" counterpart of `triangular_bad`: the
    recursive call is written as a `yield` of its argument and the @recurse
    decorator drives the computation, so inputs far beyond the recursion
    limit do not blow the stack.
    """
    if n <= 0:
        return 0
    # "Call" ourselves on n - 1; the decorator sends the result back in.
    smaller = yield n - 1
    return n + smaller
RECURSION_LIMIT = sys.getrecursionlimit()

# The plainly recursive version must overflow the stack at the limit...
try:
    triangular_bad(RECURSION_LIMIT)
except RecursionError:
    pass
else:
    assert False, "The stack did unexpectedly not blow."

# ...while the yield-transformed version copes with 100 times that depth and
# still agrees with the closed form n * (n + 1) / 2.
assert triangular_good(100 * RECURSION_LIMIT) == (
    (100 * RECURSION_LIMIT) * (100 * RECURSION_LIMIT + 1) // 2
)
def tarjan(graph):
    """
    Compute the strongly connected components of a directed graph.

    The vertices must be named `0, 1, ..., n - 1` for some number `n`, and
    `graph[v]` must list every vertex `w` for which there is an edge from `v`
    to `w`. For example, the graph

        0 ---> 1 ---> 2
               |      |
               V      V
               3 ---> 4

    is represented as
    ```
    graph = [[1], [2, 3], [4], [4], []]
    ```
    The result is a list of components, each a list of vertices in the order
    in which they were popped off the algorithm's stack.
    """
    n = len(graph)
    index = 0                # next discovery number to hand out
    indices = [-1] * n       # discovery number per vertex, -1 = unvisited
    lowlinks = [-1] * n      # smallest discovery number known reachable
    components = []
    stack = []               # visited vertices not yet assigned a component
    on_stack = set()         # same vertices as `stack`, for fast membership

    @recurse
    def dfs(v):
        # Only `index` is rebound here; the containers are mutated in place.
        nonlocal index
        indices[v] = lowlinks[v] = index
        index += 1
        stack.append(v)
        on_stack.add(v)

        for succ in graph[v]:
            if indices[succ] < 0:
                # "Yield transformed" recursive call `dfs(succ)`; doing this
                # deeply nested call by hand would be far more involved.
                yield succ
                lowlinks[v] = min(lowlinks[v], lowlinks[succ])
            elif succ in on_stack:
                lowlinks[v] = min(lowlinks[v], indices[succ])

        if lowlinks[v] == indices[v]:
            # `v` is the root of a component: pop everything down to `v`.
            component = []
            member = -1
            while member != v:
                member = stack.pop()
                on_stack.remove(member)
                component.append(member)
            components.append(component)

    for root in range(n):
        if indices[root] < 0:
            dfs(root)
    return components
# Self-check: 1, 2 and 3 lie on a common cycle; 0 and 4 are on their own.
assert tarjan(
    [[1], [2, 3], [1, 4], [2], []]
) == [[4], [3, 2, 1], [0]]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment