Created
January 30, 2011 05:03
-
-
Save ejconlon/802557 to your computer and use it in GitHub Desktop.
Tail-Recursion helper in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Tail-Recursion helper in Python. | |
Inspired by the trampoline function at | |
http://jasonmbaker.com/tail-recursion-in-python-using-pysistence | |
A function is tail-recursive when its recursive call is the very last thing
it does: it returns the result of that call directly, without doing any
further work on it (usually the call is to itself). For example, this is
tail-recursive:
sum [] acc = acc
sum (x:xs) acc = sum xs (acc+x)
And this is not: | |
fib n | n == 0 || n == 1 = 1 | |
| otherwise = (fib (n-1)) + (fib (n-2)) | |
because fib n returns an application of (+), not directly of fib. | |
Suppose we wanted to write sum in Python like we could in Haskell: | |
""" | |
# The iterator argument must provide has_next() and next() methods
# (LookAheadIterator below is written for exactly this purpose).
def nontrampsum(iterator, accumulator):
    """Sum the remaining elements of iterator onto accumulator, recursively.

    This is the naive, directly recursive version: it makes one nested
    Python call per element, so long inputs exceed the interpreter's
    recursion limit.
    """
    if iterator.has_next():
        accumulator += iterator.next()
        return nontrampsum(iterator, accumulator)
    return accumulator
""" | |
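It looks elegant, but it would blow up the stack pretty quickly.
Unlike Haskell, Python does not eliminate tail calls: every recursive call
pushes a new stack frame, and the default recursion limit (about 1000 frames)
is reached long before a large input has been summed.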
We'll need some help: | |
""" | |
# Factory for consuming tail-recursive functions that, instead of recursing
# directly, return zero-argument suspensions of their next call (see partial).
def trampoline(f, *args, **kwargs):
    """Return a trampolined version of f.

    The returned function calls f with the arguments it receives, then keeps
    calling the result for as long as it is callable, and returns the first
    non-callable value. Every suspension is invoked from this loop rather than
    from inside f, so the Python stack does not grow with the recursion depth.

    Caveats:
      * A trampolined function can never return a callable as its final
        answer; such a value would be invoked as if it were a suspension.
      * The extra *args/**kwargs passed to trampoline itself are accepted only
        for backward compatibility and are ignored. Arguments are supplied
        when the trampolined function is called.
    """
    import functools

    # Preserve f's __name__/__doc__ so profiles and help() show the real name.
    @functools.wraps(f)
    def trampolined_f(*call_args, **call_kwargs):
        result = f(*call_args, **call_kwargs)
        # Bounce: keep running suspensions until a plain value comes back.
        while callable(result):
            result = result()
        return result

    return trampolined_f
# Suspend a call: the returned zero-argument thunk performs
# f(*args, **kwargs) only when it is invoked.
# Unlike functools.partial, the thunk accepts no further arguments.
def partial(f, *args, **kwargs):
    """Return a thunk that, when called, returns f(*args, **kwargs)."""
    return lambda: f(*args, **kwargs)
""" | |
First, we can make our tail-recursive function not directly call itself, | |
but instead return a closure in which it is applied. Then we'll decorate | |
it with trampoline to call the suspensions it returns until the base case of | |
the recursion is reached. | |
""" | |
def trampsum_inner(iterator, acc):
    """Perform one step of the trampolined sum.

    When the iterator is exhausted, return the accumulated total. Otherwise
    add the next element to acc and hand back a suspension of the following
    step rather than recursing directly.
    """
    if iterator.has_next():
        acc += iterator.next()
        return partial(trampsum_inner, iterator, acc)
    return acc


# Public entry point: trampoline repeatedly runs the suspensions that
# trampsum_inner returns until the final total comes back.
trampsum = trampoline(trampsum_inner)
""" | |
And a digression: we'll need to define an iterator with a has_next method. | |
I'd like to be able to pattern-match on iterators the way Haskell
pattern-matches on lists:
sum [] acc = acc
sum (x:xs) acc = sum xs (acc+x)
We can just wrap an iterator and look ahead lazily. | |
""" | |
import collections
try:
    from collections.abc import Iterator as _IteratorBase  # Python 3.3+
except ImportError:
    from collections import Iterator as _IteratorBase  # Python 2
class LookAheadIterator(_IteratorBase):
    """Iterator wrapper that can report whether another element exists.

    has_next() peeks at most one element ahead of the wrapped iterable and
    caches it, so a has_next()/next() pair consumes exactly one element of the
    underlying iterator. Once the wrapped iterator is exhausted, has_next()
    keeps returning False without polling it again. Both the Python 2 (next)
    and the Python 3 (__next__) iterator protocols are supported.
    """

    def __init__(self, wrapped):
        # Accept any iterable; iter() turns sequences into iterators.
        self._wrapped = iter(wrapped)
        # True when nothing unconsumed is cached, so the next peek must pull
        # a fresh element from the wrapped iterator.
        self._need_to_advance = True
        self._has_next = False
        self._cache = None

    def has_next(self):
        """Return True if another element is available, peeking if needed."""
        if self._need_to_advance:
            self._advance()
        return self._has_next

    def _advance(self):
        """Pull one element from the wrapped iterator into the cache."""
        try:
            # Builtin next() works on Python 2.6+ and Python 3, unlike the
            # .next() method, which Python 3 iterators do not provide.
            self._cache = next(self._wrapped)
            self._has_next = True
        except StopIteration:
            self._has_next = False
        self._need_to_advance = False

    def next(self):
        """Return the next element, or raise StopIteration when exhausted."""
        if self._need_to_advance:
            self._advance()
        if self._has_next:
            # The cached element is handed out now, so the next peek advances.
            self._need_to_advance = True
            return self._cache
        raise StopIteration()

    def __next__(self):
        # Python 3 protocol entry point. The value must be returned; without
        # the return, Python 3 iteration would yield None for every element.
        return self.next()
""" | |
Let's prove (sadly) that it's not the speediest: | |
""" | |
import cProfile | |
def test(f, n=1000000):
    """Sum the integers 0..n-1 with f, then print and return the result.

    f is called as f(iterator, accumulator), where iterator is a
    LookAheadIterator over 0..n-1 and accumulator starts at 0. Every summing
    function in this file shares that signature. The default n matches the
    original benchmark size.
    """
    try:
        numbers = xrange(n)  # Python 2: lazy, no million-element list
    except NameError:
        numbers = range(n)  # Python 3: range is already lazy
    iterator = LookAheadIterator(numbers)
    accumulator = 0
    result = f(iterator, accumulator)
    # Single-argument print(...) prints the same on Python 2 and 3.
    print(result)
    return result
# Profile the built-in sum against the trampolined sum.
# cProfile.run executes its command string in __main__'s namespace, so these
# lines only work when this file is run as a script.
# Single-argument print(...) behaves identically on Python 2 and parses on 3.
print("Summing with built-in sum")
cProfile.run('test(sum)')
print("Summing with trampolined sum")
cProfile.run('test(trampsum)')
""" | |
499999500000 | |
2000009 function calls in 2.254 CPU seconds | |
Ordered by: standard name | |
ncalls tottime percall cumtime percall filename:lineno(function) | |
1 0.000 0.000 2.254 2.254 <string>:1(<module>) | |
1 0.000 0.000 0.000 0.000 _abcoll.py:66(__iter__) | |
1000001 0.909 0.000 1.747 0.000 cool.py:107(next) | |
1 0.000 0.000 2.254 2.254 cool.py:125(test) | |
1 0.000 0.000 0.000 0.000 cool.py:88(__init__) | |
1000001 0.838 0.000 0.838 0.000 cool.py:99(_advance) | |
1 0.000 0.000 0.000 0.000 {iter} | |
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} | |
1 0.506 0.506 2.254 2.254 {sum} | |
499999500000 | |
7000010 function calls in 6.139 CPU seconds | |
Ordered by: standard name | |
ncalls tottime percall cumtime percall filename:lineno(function) | |
1 0.000 0.000 6.139 6.139 <string>:1(<module>) | |
1000000 0.588 0.000 0.588 0.000 cool.py:107(next) | |
1 0.000 0.000 6.139 6.139 cool.py:125(test) | |
1 0.761 0.761 6.139 6.139 cool.py:44(trampolined_f) | |
1000000 0.358 0.000 0.358 0.000 cool.py:54(partial) | |
1000000 0.784 0.000 5.242 0.000 cool.py:55(partial_f) | |
1000001 1.763 0.000 4.457 0.000 cool.py:66(trampsum_inner) | |
1 0.000 0.000 0.000 0.000 cool.py:88(__init__) | |
1000001 0.736 0.000 1.748 0.000 cool.py:94(has_next) | |
1000001 1.012 0.000 1.012 0.000 cool.py:99(_advance) | |
1000001 0.136 0.000 0.136 0.000 {callable} | |
1 0.000 0.000 0.000 0.000 {iter} | |
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} | |
But it won't topple the stack: | |
""" | |
#print cProfile.run('test(nontrampsum)') | |
# RuntimeError: maximum recursion depth exceeded while calling a Python object | |
""" | |
We could fold with it: | |
""" | |
def foldl_inner(f, accumulator, iterator):
    """Perform one step of a trampolined left fold.

    Return accumulator once the iterator is exhausted. Otherwise combine it
    with the next element via f and hand back a suspension of the next step.
    """
    if iterator.has_next():
        accumulator = f(accumulator, iterator.next())
        return partial(foldl_inner, f, accumulator, iterator)
    return accumulator


# Left fold with constant stack depth: foldl(f, initial, iterator).
foldl = trampoline(foldl_inner)
def add(a, b):
    """Return a + b; used as the binary step function for foldlsum."""
    total = a + b
    return total
def foldlsum(iterator, accumulator):
    """Sum the iterator's elements onto accumulator using the trampolined foldl.

    Takes (iterator, accumulator) in the same order as trampsum, so it can be
    handed to test() unchanged.
    """
    step = add
    return foldl(step, accumulator, iterator)
# Profile the foldl-based sum. cProfile.run executes its command string in
# __main__'s namespace, so this only works when the file is run as a script.
# Single-argument print(...) behaves identically on Python 2 and parses on 3.
print("Summing with trampolined foldl")
cProfile.run('test(foldlsum)')
""" | |
499999500000 | |
8000011 function calls in 6.874 CPU seconds | |
Ordered by: standard name | |
ncalls tottime percall cumtime percall filename:lineno(function) | |
1 0.000 0.000 6.874 6.874 <string>:1(<module>) | |
1000000 0.601 0.000 0.601 0.000 cool.py:107(next) | |
1 0.000 0.000 6.874 6.874 cool.py:125(test) | |
1000001 2.159 0.000 5.109 0.000 cool.py:187(foldl_inner) | |
1000000 0.206 0.000 0.206 0.000 cool.py:196(add) | |
1 0.000 0.000 6.874 6.874 cool.py:198(foldlsum) | |
1 0.824 0.824 6.874 6.874 cool.py:44(trampolined_f) | |
1000000 0.389 0.000 0.389 0.000 cool.py:54(partial) | |
1000000 0.819 0.000 5.927 0.000 cool.py:55(partial_f) | |
1 0.000 0.000 0.000 0.000 cool.py:88(__init__) | |
1000001 0.750 0.000 1.754 0.000 cool.py:94(has_next) | |
1000001 1.004 0.000 1.004 0.000 cool.py:99(_advance) | |
1000001 0.122 0.000 0.122 0.000 {callable} | |
1 0.000 0.000 0.000 0.000 {iter} | |
1 0.000 0.000 0.000 0.000 {method 'disable' of '_lsprof.Profiler' objects} | |
Use it wisely, I guess. (Or not at all.) Look into functools.partial too. | |
""" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment