# Python helpers for continued fractions
# Source: GitHub gist vxgmichel/1b663b51107020a8b7f4f3ae051b24f2 (created May 26, 2019 09:02).
# Note: GitHub flagged this file as containing hidden or bidirectional Unicode
# text that may be interpreted or compiled differently than it appears; review
# it in an editor that reveals hidden Unicode characters.
"""Helpers for continued fractions""" | |
import math | |
import itertools | |
def continued_fraction(n, d):
    """Yield the continued-fraction terms of n / d.

    The terms are the successive quotients produced by running the
    Euclidean algorithm on the pair (n, d).  Nothing is yielded when
    d is zero.
    """
    numerator, denominator = n, d
    while denominator:
        term, remainder = divmod(numerator, denominator)
        yield term
        # Euclid step: the old divisor becomes the dividend.
        numerator, denominator = denominator, remainder
def alternative_continued_fraction(n, d):
    """Yield the second continued-fraction representation of n / d.

    The terms are those of continued_fraction(n, d), except that the
    final term p is replaced by the two terms p - 1 and 1 (which denote
    the same value, since p == (p - 1) + 1 / 1).
    """
    terms = continued_fraction(n, d)
    last = next(terms)
    for term in terms:
        # Emit the term we were holding back; the new one might be the final one.
        emitted, last = last, term
        yield emitted
    # Split the final term.
    yield from (last - 1, 1)
def convergents(n, d):
    """Yield the convergents of n / d as (numerator, denominator) pairs.

    Each convergent is built from the two previous ones with the usual
    recurrence h_i = x_i * h_(i-1) + h_(i-2) (and likewise for the
    denominators), seeded with 0/1 and 1/0.
    """
    prev_num, prev_den = 0, 1
    num, den = 1, 0
    for term in continued_fraction(n, d):
        next_num = num * term + prev_num
        next_den = den * term + prev_den
        prev_num, prev_den = num, den
        num, den = next_num, next_den
        yield num, den
def best_approximations(n, d):
    """Yield successive rational approximations of n / d.

    Pairs (numerator, denominator) are produced starting with the integer
    part, then, for every further continued-fraction term x, the
    semiconvergents h * m + hh over k * m + kk for m in x // 2 + 1 .. x
    (the last of which is the next convergent).  When x is even, the
    semiconvergent for m == x // 2 is emitted first, but only if it is
    strictly closer to n / d than the previous convergent.
    """
    terms = continued_fraction(n, d)
    integer_part = next(terms)

    # Previous convergent (1/0) and current convergent (integer_part/1).
    prev_num, prev_den = 1, 0
    num, den = integer_part, 1
    yield num, den

    for term in terms:
        half, remainder = divmod(term, 2)

        # Borderline semiconvergent: only exists for even terms, and is only
        # kept when it beats the previous convergent.
        if remainder == 0:
            mid_num = num * half + prev_num
            mid_den = den * half + prev_den
            # Cross-multiplied comparison of |n/d - num/den| and
            # |n/d - mid_num/mid_den|, avoiding any division.
            convergent_error = mid_den * abs(n * den - d * num)
            candidate_error = den * abs(n * mid_den - d * mid_num)
            if convergent_error > candidate_error:
                yield mid_num, mid_den

        # Remaining semiconvergents, ending on the next convergent.
        for multiplier in range(half + 1, term + 1):
            yield num * multiplier + prev_num, den * multiplier + prev_den

        # Advance to the next convergent.
        prev_num, prev_den, num, den = (
            num,
            den,
            num * term + prev_num,
            den * term + prev_den,
        )
def zfunction(g1, g2):
    """Combine two continued-fraction term sequences into one fraction.

    The convergent of the common leading terms of g1 and g2 is computed,
    then extended with one extra term equal to one more than the smaller
    of the first two differing terms (an exhausted sequence counts as an
    infinite term).  The result is returned as (numerator, denominator).

    Raises ValueError when the two sequences are identical.
    """
    prev_num, prev_den = 0, 1
    num, den = 1, 0
    for left, right in itertools.zip_longest(g1, g2, fillvalue=math.inf):
        if left == right:
            # Still in the shared prefix: fold the term into the convergent.
            prev_num, prev_den, num, den = (
                num,
                den,
                num * left + prev_num,
                den * left + prev_den,
            )
            continue
        # First difference: append the extra term and stop.
        extra = min(left, right) + 1
        return num * extra + prev_num, den * extra + prev_den
    # Both sequences ran out without ever differing.
    raise ValueError
def best_rational_zf(n1, d1, n2, d2):
    """Find a fraction strictly between n1 / d1 and n2 / d2 using zfunction.

    Each bound has two continued-fraction representations; the four
    combinations are tried in order and the first zfunction result that
    lies strictly inside the interval is returned as
    (numerator, denominator).

    Raises ValueError unless n1 * d2 < d1 * n2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError

    representations = (continued_fraction, alternative_continued_fraction)
    for expand_low, expand_high in itertools.product(representations, repeat=2):
        num, den = zfunction(expand_low(n1, d1), expand_high(n2, d2))
        # Keep only a result strictly inside the open interval.
        if n1 * den < num * d1 and num * d2 < n2 * den:
            return num, den
    return None
def best_rational_cv(n1, d1, n2, d2):
    """Find a fraction between n1 / d1 and n2 / d2 via convergents.

    The convergents of the interval midpoint are walked until one falls
    strictly inside the interval; a semiconvergent built from the two
    convergents preceding it is then returned as (numerator, denominator).

    Raises ValueError unless n1 * d2 < d1 * n2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError

    # Shortcut: floor(n1 / d1) + 1 is below floor(n2 / d2), so it is an
    # integer inside the interval.
    floor_low = n1 // d1
    if floor_low + 1 < n2 // d2:
        return floor_low + 1, 1

    # Midpoint of the interval as an (unreduced) fraction.
    mid_num = n1 * d2 + n2 * d1
    mid_den = d1 * d2 * 2

    # Walk the convergents of the midpoint until one is strictly inside.
    prev_num, prev_den, num, den = 0, 1, 1, 0
    for term in continued_fraction(mid_num, mid_den):
        cand_num = num * term + prev_num
        cand_den = den * term + prev_den
        if n1 * cand_den < cand_num * d1 and cand_num * d2 < n2 * cand_den:
            break
        prev_num, prev_den, num, den = num, den, cand_num, cand_den

    # One multiplier candidate per bound.
    # NOTE(review): presumably each is the smallest multiplier that puts the
    # semiconvergent on the inner side of that bound - confirm the derivation.
    # A zero divisor means the current convergent equals the bound itself;
    # the full term is used in that case.
    try:
        low_multiplier = 1 + (d1 * prev_num - n1 * prev_den) // (n1 * den - d1 * num)
    except ZeroDivisionError:
        low_multiplier = term
    try:
        high_multiplier = 1 + (d2 * prev_num - n2 * prev_den) // (n2 * den - d2 * num)
    except ZeroDivisionError:
        high_multiplier = term

    multiplier = min(low_multiplier, high_multiplier)
    return num * multiplier + prev_num, den * multiplier + prev_den
def best_rational_ba(n1, d1, n2, d2):
    """Find a fraction between n1 / d1 and n2 / d2 via best approximations.

    The approximations of the interval midpoint produced by
    best_approximations are scanned in order and the first one lying
    strictly inside the interval is returned as (numerator, denominator).

    Raises ValueError unless n1 * d2 < d1 * n2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError

    # Shortcut: floor(n1 / d1) + 1 is below floor(n2 / d2), so it is an
    # integer inside the interval.
    floor_low = n1 // d1
    if floor_low + 1 < n2 // d2:
        return floor_low + 1, 1

    # Midpoint of the interval as an (unreduced) fraction.
    mid_num = n1 * d2 + n2 * d1
    mid_den = d1 * d2 * 2

    inside = (
        (num, den)
        for num, den in best_approximations(mid_num, mid_den)
        if n1 * den < num * d1 and num * d2 < n2 * den
    )
    return next(inside, None)
def best_rational(*args, method="zf"):
    """Dispatch to one of the best_rational_* implementations.

    method selects the implementation: "zf" (best_rational_zf, fast),
    "cv" (best_rational_cv, fast) or "ba" (best_rational_ba, slow).
    The positional arguments are forwarded unchanged.

    Raises ValueError for any other method.
    """
    if method not in ("zf", "cv", "ba"):
        raise ValueError
    if method == "zf":
        return best_rational_zf(*args)
    if method == "cv":
        return best_rational_cv(*args)
    return best_rational_ba(*args)
# Tests | |
import pytest | |
import random | |
def test_trivial():
    """Check best_approximations on small integers and unit fractions."""
    # Integers: the only approximation is the integer itself.
    for whole in (1, 2, 3):
        assert list(best_approximations(whole, 1)) == [(whole, 1)]

    # Unit fractions 1/1, 1/2 and 1/3.
    expected_by_denominator = {
        1: [(1, 1)],
        2: [(0, 1), (1, 2)],
        3: [(0, 1), (1, 2), (1, 3)],
    }
    for denominator, expected in expected_by_denominator.items():
        assert list(best_approximations(1, denominator)) == expected
def test_5_dot_4375():
    """Check terms, convergents and best approximations of 584375 / 100000."""
    n, d = 584375, 100000

    expected_terms = [5, 1, 5, 2, 2]
    assert list(continued_fraction(n, d)) == expected_terms

    expected_convergents = [(5, 1), (6, 1), (35, 6), (76, 13), (187, 32)]
    assert list(convergents(n, d)) == expected_convergents

    expected_approximations = [
        (5, 1),
        (6, 1),
        (23, 4),
        (29, 5),
        (35, 6),
        (76, 13),
        (111, 19),
        (187, 32),
    ]
    assert list(best_approximations(n, d)) == expected_approximations
def test_dot_84375():
    """Check terms, convergents and best approximations of 84375 / 100000."""
    n, d = 84375, 100000

    expected_terms = [0, 1, 5, 2, 2]
    assert list(continued_fraction(n, d)) == expected_terms

    expected_convergents = [(0, 1), (1, 1), (5, 6), (11, 13), (27, 32)]
    assert list(convergents(n, d)) == expected_convergents

    expected_approximations = [
        (0, 1),
        (1, 1),
        (3, 4),
        (4, 5),
        (5, 6),
        (11, 13),
        (16, 19),
        (27, 32),
    ]
    assert list(best_approximations(n, d)) == expected_approximations
def test_random():
    """Check invariants of best_approximations on one seeded random fraction.

    Every approximation must be in lowest terms, strictly closer to the
    target than the previous one, and have a strictly larger denominator.
    """
    random.seed(0)
    upper = 10 ** 10
    n, d = random.randint(1, upper), random.randint(1, upper)

    previous_error, previous_denominator = 1, 0
    for num, den in best_approximations(n, d):
        assert math.gcd(num, den) == 1
        error = abs(n * den - num * d) / (d * den)
        assert error < previous_error
        assert den > previous_denominator
        previous_error, previous_denominator = error, den
@pytest.mark.parametrize("method", ["zf", "cv", "ba"])
def test_best_rational(method):
    """Exhaustively check best_rational for the given method.

    For every a/b and c/d with terms in 1..19, a valid interval
    (a/b < c/d) must produce a fraction strictly inside it that matches
    the default method's result; any other input must raise ValueError.
    """
    assert best_rational(1234, 100, 1235, 100, method=method) == (284, 23)
    for a, b, c, d in itertools.product(range(1, 20), repeat=4):
        if a / b < c / d:
            h, k = best_rational(a, b, c, d, method=method)
            # The upper bound is c / d; comparing against c alone would
            # accept results lying above the interval whenever d > 1.
            assert a / b < h / k < c / d
            # Every method must agree with the default one.
            assert best_rational(a, b, c, d) == (h, k)
        else:
            with pytest.raises(ValueError):
                best_rational(a, b, c, d, method=method)
# Main execution | |
import sys | |
from decimal import Decimal, getcontext | |
if __name__ == "__main__":
    # Use 50 significant digits for the decimal quotient printed below.
    getcontext().prec = 50
    # Exactly two command-line arguments are expected: the lower and the
    # upper bound, each parsed as a Decimal.
    lower, upper = (Decimal(argument) for argument in sys.argv[1:])
    lower_ratio = lower.as_integer_ratio()
    upper_ratio = upper.as_integer_ratio()
    numerator, denominator = best_rational(*lower_ratio, *upper_ratio)
    print(
        numerator,
        "/",
        denominator,
        "\n =",
        Decimal(numerator) / Decimal(denominator),
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment