Skip to content

Instantly share code, notes, and snippets.

@vxgmichel
Created May 26, 2019 09:02
Show Gist options
Save vxgmichel/1b663b51107020a8b7f4f3ae051b24f2 to your computer and use it in GitHub Desktop.
Python helpers for continued fractions
"""Helpers for continued fractions"""
import math
import itertools
def continued_fraction(n, d):
    """Yield the partial quotients of n / d using the Euclidean algorithm.

    Quotients come from ``divmod`` (floor division), so negative inputs follow
    Python's floor semantics. Nothing is yielded when d is 0.
    """
    dividend, divisor = n, d
    while divisor:
        quotient, remainder = divmod(dividend, divisor)
        yield quotient
        dividend, divisor = divisor, remainder
def alternative_continued_fraction(n, d):
    """Yield the second finite continued-fraction expansion of n / d.

    A rational with regular expansion [a0; a1, ..., an] also equals
    [a0; a1, ..., an - 1, 1]. This generator produces that second form from
    ``continued_fraction(n, d)``.

    If the regular expansion is empty (d == 0), nothing is yielded. A bare
    ``next()`` here would let StopIteration escape the generator, which
    Python turns into RuntimeError.
    """
    gen = continued_fraction(n, d)
    last = next(gen, None)
    if last is None:
        return
    # Emit every term except the final one unchanged.
    for term in gen:
        yield last
        last = term
    # Replace the final term an with (an - 1, 1).
    yield last - 1
    yield 1
def convergents(n, d):
    """Yield the convergents (h, k) of n / d, one per continued-fraction term.

    Uses the recurrence h_i = a_i * h_{i-1} + h_{i-2} (and likewise for k),
    seeded with h_{-2}/k_{-2} = 0/1 and h_{-1}/k_{-1} = 1/0.
    """
    h_prev, h_cur = 0, 1
    k_prev, k_cur = 1, 0
    for a in continued_fraction(n, d):
        h_prev, h_cur = h_cur, a * h_cur + h_prev
        k_prev, k_cur = k_cur, a * k_cur + k_prev
        yield h_cur, k_cur
def best_approximations(n, d):
    """Yield the best rational approximations (h, k) of n / d.

    The output is the convergents of n / d together with the admissible
    semiconvergents between them. Each yielded fraction has a strictly larger
    denominator and is strictly closer to n / d than the one before it. The
    last item is the final convergent, i.e. n / d in lowest terms.

    If the continued fraction of n / d is empty (d == 0), nothing is yielded.
    A bare ``next()`` would otherwise turn into RuntimeError inside the
    generator.
    """
    gen = continued_fraction(n, d)
    x = next(gen, None)
    if x is None:
        return
    # (hh, kk) and (h, k) are the two most recent convergents. Seeding the
    # older one with 1/0 makes the recurrence produce x/1 for the first term.
    hh, kk, h, k = 1, 0, x, 1
    yield x, 1
    for x in gen:
        # The half-way semiconvergent (only exists for even x) is a best
        # approximation only when it beats the previous convergent h/k.
        # |n/d - hm/km| < |n/d - h/k| is checked by cross-multiplication.
        if x % 2 == 0:
            m = x // 2
            hm, km = h * m + hh, k * m + kk
            e1 = km * abs(n * k - d * h)
            e2 = k * abs(n * km - d * hm)
            if e1 > e2:
                yield hm, km
        # Semiconvergents strictly past the half-way point always qualify.
        # m == x is the next convergent itself.
        for m in range(x // 2 + 1, x + 1):
            yield h * m + hh, k * m + kk
        # Advance to the next convergent.
        hh, kk, h, k = h, k, h * x + hh, k * x + kk
def zfunction(g1, g2):
    """Extend the common prefix of two continued fractions by one term.

    g1 and g2 are iterables of continued-fraction terms. The shorter one is
    padded with infinity. At the first position where the terms differ, the
    smaller term plus one is appended to the shared prefix. The resulting
    fraction is returned as (h, k).

    Raises:
        ValueError: if the two expansions are identical.
    """
    prev_h, prev_k = 0, 1
    cur_h, cur_k = 1, 0
    pairs = itertools.zip_longest(g1, g2, fillvalue=math.inf)
    for left, right in pairs:
        if left != right:
            step = min(left, right) + 1
            return cur_h * step + prev_h, cur_k * step + prev_k
        # Still in the shared prefix: advance the convergent recurrence.
        prev_h, cur_h = cur_h, cur_h * left + prev_h
        prev_k, cur_k = cur_k, cur_k * left + prev_k
    # Both expansions matched term for term.
    raise ValueError
def best_rational_zf(n1, d1, n2, d2):
    """Return the simplest fraction (h, k) strictly between n1/d1 and n2/d2.

    Each bound has two finite continued-fraction expansions. For each of the
    four pairings, ``zfunction`` builds the fraction that extends their common
    prefix. The first candidate lying strictly inside the interval is
    returned.

    Raises:
        ValueError: if n1/d1 is not strictly below n2/d2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError
    expansions = (continued_fraction, alternative_continued_fraction)
    # product(..., repeat=2) visits pairings in the same order as two
    # nested loops over `expansions`.
    for lower_cf, upper_cf in itertools.product(expansions, repeat=2):
        h, k = zfunction(lower_cf(n1, d1), upper_cf(n2, d2))
        if n1 * k < h * d1 and h * d2 < n2 * k:
            return h, k
    # NOTE(review): presumably unreachable; if no pairing yields an in-range
    # candidate, None is returned (callers unpacking the result would fail).
    return None
def best_rational_cv(n1, d1, n2, d2):
    """Return the simplest fraction (h, k) strictly between n1/d1 and n2/d2.

    Convergent-based method:
    1. Expand the interval's midpoint as a continued fraction.
    2. Stop at the first convergent that falls strictly inside the interval.
    3. Back off to the smallest semiconvergent (h*m + hh) / (k*m + kk) built
       on the two preceding convergents that is already inside.

    Raises:
        ValueError: if n1/d1 is not strictly below n2/d2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError
    # Shortcut: an integer lies strictly inside the interval.
    next_int = n1 // d1 + 1
    if next_int < n2 // d2:
        return next_int, 1

    def inside(p, q):
        # Strict membership test by cross-multiplication.
        return n1 * q < p * d1 and p * d2 < n2 * q

    # Midpoint (n1/d1 + n2/d2) / 2 as a single fraction.
    mid_num = n1 * d2 + n2 * d1
    mid_den = d1 * d2 * 2

    # After the loop, (hh, kk) and (h, k) are the two convergents preceding
    # the first one found inside the interval; `term` is that one's quotient.
    hh, kk, h, k = 0, 1, 1, 0
    for term in continued_fraction(mid_num, mid_den):
        p, q = h * term + hh, k * term + kk
        if inside(p, q):
            break
        hh, kk, h, k = h, k, p, q

    def smallest_m(num, den):
        # Smallest m placing the semiconvergent strictly on the inner side
        # of num/den. A zero divisor means h/k equals that bound, which then
        # does not constrain m, so fall back to `term`.
        try:
            return 1 + (den * hh - num * kk) // (num * k - den * h)
        except ZeroDivisionError:
            return term

    m = min(smallest_m(n1, d1), smallest_m(n2, d2))
    return h * m + hh, k * m + kk
def best_rational_ba(n1, d1, n2, d2):
    """Return the simplest fraction (h, k) strictly between n1/d1 and n2/d2.

    Slow reference method. It scans the best approximations of the interval's
    midpoint in order of increasing denominator and returns the first one
    inside the interval.

    Raises:
        ValueError: if n1/d1 is not strictly below n2/d2.
    """
    if n1 * d2 >= d1 * n2:
        raise ValueError
    # Shortcut: an integer lies strictly inside the interval.
    next_int = n1 // d1 + 1
    if next_int < n2 // d2:
        return next_int, 1
    # Midpoint (n1/d1 + n2/d2) / 2 as a single fraction.
    mid_num = n1 * d2 + n2 * d1
    mid_den = d1 * d2 * 2
    candidates = (
        (h, k)
        for h, k in best_approximations(mid_num, mid_den)
        if n1 * k < h * d1 and h * d2 < n2 * k
    )
    return next(candidates, None)
def best_rational(*args, method="zf"):
    """Return the simplest fraction (h, k) strictly between n1/d1 and n2/d2.

    Positional arguments are ``n1, d1, n2, d2``. They are forwarded unchanged
    to the implementation selected by ``method``:

    * ``"zf"``: continued-fraction prefix comparison (default, fast)
    * ``"cv"``: midpoint convergents plus one semiconvergent (fast)
    * ``"ba"``: scan of the midpoint's best approximations (slow)

    Raises:
        ValueError: if the bounds are not strictly increasing (raised by the
            selected implementation), or if ``method`` is not recognised. The
            message for an unknown method names the valid choices.
    """
    if method == "zf":
        return best_rational_zf(*args)
    if method == "cv":
        return best_rational_cv(*args)
    if method == "ba":
        return best_rational_ba(*args)
    raise ValueError(
        "unknown method {!r}; expected 'zf', 'cv' or 'ba'".format(method))
# Tests
import pytest
import random
def test_trivial():
    """Integers and unit fractions have short, hand-checkable approximations."""
    # Zero replaces a verbatim duplicate of the (1, 1) assertion.
    assert list(best_approximations(0, 1)) == [(0, 1)]
    assert list(best_approximations(1, 1)) == [(1, 1)]
    assert list(best_approximations(2, 1)) == [(2, 1)]
    assert list(best_approximations(3, 1)) == [(3, 1)]
    assert list(best_approximations(1, 2)) == [(0, 1), (1, 2)]
    assert list(best_approximations(1, 3)) == [(0, 1), (1, 2), (1, 3)]
def test_5_dot_4375():
    """Worked example for 584375/100000 = 5.84375 = 187/32.

    The function name drops the 8 in 5.84375.
    """
    n, d = 584375, 100000
    expected_terms = [5, 1, 5, 2, 2]
    expected_convergents = [(5, 1), (6, 1), (35, 6), (76, 13), (187, 32)]
    expected_best = [
        (5, 1), (6, 1), (23, 4), (29, 5),
        (35, 6), (76, 13), (111, 19), (187, 32),
    ]
    assert list(continued_fraction(n, d)) == expected_terms
    assert list(convergents(n, d)) == expected_convergents
    assert list(best_approximations(n, d)) == expected_best
def test_dot_84375():
    """Worked example for 84375/100000 = 0.84375 = 27/32."""
    fraction = (84375, 100000)
    assert list(continued_fraction(*fraction)) == [0, 1, 5, 2, 2]
    assert list(convergents(*fraction)) == [
        (0, 1), (1, 1), (5, 6), (11, 13), (27, 32),
    ]
    assert list(best_approximations(*fraction)) == [
        (0, 1), (1, 1), (3, 4), (4, 5), (5, 6),
        (11, 13), (16, 19), (27, 32),
    ]
def test_random():
    """Best approximations of a random fraction are reduced and strictly improving.

    Each successive approximation must:
    * be in lowest terms,
    * have a strictly larger denominator, and
    * have a strictly smaller error than the previous one.
    """
    from fractions import Fraction

    # A private generator avoids reseeding the global ``random`` module,
    # which other tests may share. The same seed yields the same sample.
    rng = random.Random(0)
    m = 10 ** 10
    n, d = rng.randint(1, m), rng.randint(1, m)
    target = Fraction(n, d)
    previous_error, previous_k = Fraction(1), 0
    for h, k in best_approximations(n, d):
        assert math.gcd(h, k) == 1
        # Exact rational arithmetic keeps float rounding out of the comparison.
        error = abs(target - Fraction(h, k))
        assert error < previous_error
        assert k > previous_k
        previous_error, previous_k = error, k
@pytest.mark.parametrize("method", ["zf", "cv", "ba"])
def test_best_rational(method):
    """Every method returns a fraction strictly inside the interval and agrees with "zf".

    Bounds are compared by integer cross-multiplication (all denominators are
    positive), so the checks are exact.
    """
    assert best_rational(1234, 100, 1235, 100, method=method) == (284, 23)
    for a, b, c, d in itertools.product(range(1, 20), repeat=4):
        if a * d < c * b:
            h, k = best_rational(a, b, c, d, method=method)
            assert k > 0
            # Strictly between a/b and c/d. The upper bound is c/d, not c.
            assert a * k < h * b and h * d < c * k
            assert best_rational(a, b, c, d) == (h, k)
        else:
            with pytest.raises(ValueError):
                best_rational(a, b, c, d, method=method)
# Main execution
import sys
from decimal import Decimal, getcontext
if __name__ == "__main__":
    # Usage: <script> LOWER UPPER
    # Prints the best rational strictly between two decimal bounds.
    if len(sys.argv) != 3:
        sys.exit("usage: {} LOWER UPPER".format(sys.argv[0]))
    # Only affects the precision of the printed quotient. The bounds are
    # converted exactly via as_integer_ratio().
    getcontext().prec = 50
    a, b = map(Decimal, sys.argv[1:])
    h, k = best_rational(*a.as_integer_ratio(), *b.as_integer_ratio())
    print(h, "/", k, "\n =", Decimal(h) / Decimal(k))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment