Created November 20, 2011 05:14
Save mvasilkov/1379824 to your computer and use it in GitHub Desktop.
Tail recursive power function
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial, wraps | |
def power_tail(x, a, n):
    """
    Compute ``x * a**n`` by binary exponentiation, one step per call.

    This is a trampoline step function: instead of recursing directly, every
    non-final step returns a zero-argument ``functools.partial`` that performs
    the next step.  A driver (e.g. ``tail_recursive``) keeps calling the
    returned object until it receives a non-callable value.

    :param x: accumulator multiplied into the result (callers start with 1).
    :param a: current base; squared on every even step.
    :param n: remaining exponent, a non-negative integer.
    :returns: ``x`` when ``n == 0``, otherwise a thunk for the next step.
    :raises ValueError: if ``n`` is negative.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    # For negative n, ``n >> 1`` floors toward -1 and never reaches 0, so
    # without this check the trampoline would loop forever.
    if n < 0:
        raise ValueError('exponent must be non-negative, got %r' % (n,))
    if n == 0:
        return x
    if n & 1:
        # Odd exponent: move one factor of ``a`` into the accumulator.
        return partial(power_tail, x * a, a, n - 1)
    # Even exponent: a**n == (a*a)**(n/2).
    return partial(power_tail, x, a * a, n >> 1)
def tail_recursive(fun):
    """
    Decorate ``fun`` so that its trampolined result is fully evaluated.

    ``fun`` must return either a final value or a zero-argument callable
    ("thunk") representing the next step.  The wrapper calls thunks in a
    loop until a non-callable value appears, and returns that value.  As long
    as each thunk returns instead of recursing, the steps never nest on the
    call stack.

    Caveat: every callable result is treated as a thunk, so this decorator
    cannot be used on functions whose real return value is callable.

    :param fun: the step-producing function to wrap.
    :returns: a wrapper carrying ``fun``'s metadata (via ``wraps``) that
        accepts the same positional and keyword arguments as ``fun``.
    """
    @wraps(fun)
    def wrapper(*args, **kwargs):
        # Forward keyword arguments too; the old ``lambda *args`` rejected them.
        result = fun(*args, **kwargs)
        while callable(result):
            result = result()
        return result
    return wrapper
@tail_recursive
def power(a, n):
    """
    Raise ``a`` to the non-negative integer power ``n``.

    Tail recursive: ``power_tail`` does the work one step at a time, starting
    from an accumulator of 1, and the ``tail_recursive`` trampoline drives the
    steps to completion.

    >>> power(2, 3)
    8
    >>> power(8, 4)
    4096
    >>> power(5, 0)
    1
    """
    return power_tail(x=1, a=a, n=n)
def power_rec(a, n):
    """
    Reference implementation: plain recursive binary exponentiation.

    Recurses into itself (not into ``power``), so it is an independent check
    on the tail-recursive version.  Odd exponents drop by one and even ones
    halve, so recursion depth is roughly 2*log2(n).

    >>> power_rec(2, 3)
    8
    >>> power_rec(8, 4)
    4096
    >>> power_rec(5, 0)
    1

    :raises ValueError: if ``n`` is negative.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if n < 0:
        raise ValueError('exponent must be non-negative, got %r' % (n,))
    if n == 0:
        return 1
    if n & 1:
        return a * power_rec(a, n - 1)
    return power_rec(a * a, n >> 1)
if __name__ == '__main__':
    # Run the doctest examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.