-
-
Save silky/c8497b91c11e2a140e30f8eabf8f90d8 to your computer and use it in GitHub Desktop.
Count solutions to linear Diophantine equations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
from fractions import Fraction | |
def _gcd(a, b): | |
while a: | |
a, b = b % a, a | |
return b | |
def _mod_inv(a: int, m: int) -> int: | |
x, y = a, m | |
u0, u1 = 0, 1 | |
v0, v1 = 1, 0 | |
while y: | |
x, (q, y) = y, divmod(x, y) | |
u0, u1 = u1 - q * u0, u0 | |
v0, v1 = v1 - q * v0, v0 | |
if x != 1: | |
raise ValueError | |
return u1 * (-1 if a < 0 else 1) % m | |
def _factors(n): | |
# TODO Prime factorisation and Cartesian products | |
for i in range(1, n//2 + 1): | |
if n % i == 0: | |
yield i | |
yield n | |
def linear_diophantine_counts_for(a, modulus=None):
    """
    Returns a function which maps n to the number of solutions of a.x = n in natural numbers.

    Solutions are vectors x of non-negative integers, one entry per coefficient
    in a. All the expensive work is done once here; the returned function only
    does a table lookup or evaluates one precomputed polynomial.

    Args:
        a: Sequence of positive integer coefficients. An empty sequence gives
            the trivial counter (one solution for n == 0, none otherwise).
        modulus: If given, a positive integer; counts are returned reduced
            modulo this value. It must be coprime to the denominators that
            arise in the interpolation (powers of the period and the integers
            below len(a)).

    Returns:
        A function count(n) returning the number of solutions (0 for negative
        n), reduced modulo ``modulus`` when one was supplied.

    Raises:
        ValueError: If a coefficient is not positive, if modulus is not
            positive, or if modulus is not coprime to an interpolation
            denominator.

    >>> foo = linear_diophantine_counts_for([1, 2, 6])
    >>> foo(10)
    9
    >>> foo(100)
    459
    >>> foo(1000)
    42084
    >>> foo(10000)
    4170834
    >>> bar = linear_diophantine_counts_for([1, 2, 6], 17)
    >>> bar(10)
    9
    >>> bar(100)
    0
    >>> bar(1000)
    9
    >>> bar(10000)
    3
    """
    a = list(a)
    if modulus is not None and modulus < 1:
        raise ValueError("modulus must be a positive integer, got %r" % (modulus,))
    if any(a_i <= 0 for a_i in a):
        raise ValueError("all coefficients must be positive integers, got %r" % (a,))

    # The count for n == 0 (the empty solution), reduced if necessary.
    one = 1 if modulus is None else 1 % modulus

    if not a:
        # Empty product: f(z) = 1, so only n == 0 has a solution.
        def count_empty(n):
            return one if n == 0 else 0
        return count_empty

    # Divide out the common factor; count() rescales n to match.
    gcd_a = a[0]
    for a_i in a[1:]:
        gcd_a = _gcd(a_i, gcd_a)
    a = [a_i // gcd_a for a_i in a]

    # Let f(z) = \prod_{a_i in a} (1 - z^{a_i})^{-1}
    # Then we want to be able to evaluate [z^n]f(z)
    # Start by decomposing into a partial fraction whose denominators are powers of cyclotomics.
    cyclotomic_frequencies = Counter()
    for a_i in a:
        for factor in _factors(a_i):
            cyclotomic_frequencies[factor] += 1

    # We can prove that, after M base cases, the partial fraction gives a quasi-polynomial
    # with period lcm(a) and degree len(a) - 1.
    # For details see http://cheddarmonk.org/papers/linear-diophantine-equations.pdf
    M = max((w_d - 1) * d for d, w_d in cyclotomic_frequencies.items())

    period = 1
    for a_i in a:
        period = period * (a_i // _gcd(period, a_i))

    degree_inc = len(a)  # degree of quasi-polynomial + 1
    specials = M + period * degree_inc

    # Tabulate the first `specials` counts with the classic coin-change recurrence.
    precalc = [0] * specials
    precalc[0] = one
    for a_i in a:
        for i in range(a_i, specials):
            precalc[i] += precalc[i - a_i]
            if modulus is not None and precalc[i] >= modulus:
                precalc[i] -= modulus

    # The point is to precalculate as much as possible, so expand out to polynomial coefficients.
    polys = [[0] * degree_inc for _ in range(period)]
    denominators = [1] * period  # Only used when modulus is None
    for n in range(period):
        # Lagrange interpolation
        # We want the last degree_inc indices and values from precalc where the index equals n (mod period)
        # Here Python's non-negative % result for a positive divisor is what makes this work.
        first = M + (n - M) % period
        lagrange_points = [(x, precalc[x]) for x in range(first, specials, period)]
        if modulus is None:
            # We really have to work in rationals. Use the built-in ones.
            for x, y in lagrange_points:
                term = [0] * degree_inc
                term[0] = y
                for x2, _ in lagrange_points:
                    if x == x2:
                        continue
                    # term = term * (z - x2) / (x - x2)
                    for i in range(degree_inc - 1, 0, -1):
                        term[i] = (term[i - 1] - x2 * term[i]) / Fraction(x - x2)
                    term[0] = term[0] * Fraction(-x2, x - x2)
                for i in range(degree_inc):
                    polys[n][i] += term[i]
            # Although we used rationals for laziness, when it comes to evaluating in the callback
            # we probably want to stick to integers: scale to a common denominator.
            lcm_denom = 1
            for coeff in polys[n]:
                lcm_denom = coeff.denominator * (lcm_denom // _gcd(coeff.denominator, lcm_denom))
            denominators[n] = lcm_denom
            for i in range(degree_inc):
                polys[n][i] = polys[n][i].numerator * (lcm_denom // polys[n][i].denominator)
        else:
            for x, y in lagrange_points:
                term = [0] * degree_inc
                term[0] = y
                denom = 1
                for x2, _ in lagrange_points:
                    if x == x2:
                        continue
                    # term = term * (z - x2)
                    for i in range(degree_inc - 1, 0, -1):
                        term[i] = (term[i - 1] - x2 * term[i]) % modulus
                    term[0] = term[0] * -x2 % modulus
                    # Optimise by only doing one _mod_inv
                    denom = denom * (x - x2) % modulus
                try:
                    recip = _mod_inv(denom, modulus)
                except ValueError as err:
                    raise ValueError(
                        "modulus %r is not coprime to the interpolation denominator "
                        "(period %r, %r coefficients)" % (modulus, period, degree_inc)
                    ) from err
                for i in range(degree_inc):
                    polys[n][i] = (polys[n][i] + term[i] * recip) % modulus

    def count(n):
        # No solutions for negative n or n not divisible by the common factor.
        if n < 0 or n % gcd_a:
            return 0
        n //= gcd_a
        if n < len(precalc):
            return precalc[n]
        if modulus is None:
            # Horner evaluation of the scaled integer polynomial, then unscale.
            rv = 0
            for coeff in reversed(polys[n % period]):
                rv = rv * n + coeff
            return rv // denominators[n % period]
        else:
            rv = 0
            # Select the residue class before reducing n modulo the modulus.
            poly = polys[n % period]
            n %= modulus
            for coeff in reversed(poly):
                rv = (rv * n + coeff) % modulus
            return rv

    return count
# When executed as a script, run the doctests embedded in this module.
if __name__ == "__main__":
    from doctest import testmod

    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment