Created
April 19, 2019 10:24
-
-
Save bivoje/ad7d9c57c3ecb4673bf6dea6bbafe0b3 to your computer and use it in GitHub Desktop.
Python code that calculates (and demonstrates) big-number addition and subtraction using the Chinese Remainder Theorem
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3.6
# https://stackoverflow.com/a/8748146
def _(x):
    """Identity function: return ``x`` unchanged (passed as the ``op`` of ``mod_op`` by ``int2crt``)."""
    return x
# =================== | |
# flow | |
# =================== | |
def hmm():
    """Emit a blank line, used to separate steps of the printed walkthrough."""
    print("")
def then():
    """Emit a blank line followed by the word ``then,`` to introduce the next step."""
    print("\nthen,")
def by(theorem):
    """Print ``by <theorem>`` to cite the justification for the following step."""
    print(f'by {theorem}')
# =================== | |
# declaration | |
# =================== | |
#def for_variable(name, val): | |
# print("for {name} = {val},") | |
# globals()[name] = val | |
def declaration(vardic):
    """Bind every (name, value) pair of ``vardic`` as a module-level global.

    Returns:
        The bindings rendered as ``"name = value, name = value"`` in the
        mapping's iteration order (empty string for an empty mapping).
    """
    namespace = globals()
    namespace.update(vardic.items())
    return ", ".join(f'{name} = {value}' for name, value in vardic.items())
def for_variables(**it):
    """Bind each keyword argument as a module global and print ``for name = value, ...``."""
    rendered = declaration(it)
    print(f'for {rendered}')
def using_variables(**it):
    """Bind each keyword argument as a module global and print ``using name = value, ...``."""
    rendered = declaration(it)
    print(f'using {rendered}')
def assign_wrap(f):
    """Wrap ``f`` so one keyword argument both names the result and supplies the arguments.

    Calling ``wrapped(name=(a, b, ...))`` evaluates ``f(a, b, ...)``, binds the
    result to the module-level global ``name``, prints ``as name = result`` and
    returns the result.

    Raises:
        TypeError: if the wrapper is not called with exactly one keyword argument.
    """
    from functools import wraps

    @wraps(f)
    def func(**it):
        # Exactly one keyword: its name is the assignment target and its value
        # the argument tuple. Extra keywords would otherwise be silently dropped.
        if len(it) != 1:
            raise TypeError(
                f'assigned call expects exactly one keyword argument, got {len(it)}')
        [(name, vals)] = it.items()
        ret = f(*vals)
        globals()[name] = ret
        print(f'as {name} = {ret}')
        return ret
    return func
# =================== | |
# number theory | |
# =================== | |
# TODO remove n, closure f should manage
def exGCD(r0, r1, s0, s1, t0, t1, f=None, n=1):
    """Extended Euclidean algorithm over the remainder sequence starting at (r0, r1).

    Each row keeps the invariant ``r_i = a*s_i + b*t_i``; callers seed it with
    ``(a, b, 1, 0, 0, 1)`` so the result ``(d, s, t)`` satisfies
    ``d = a*s + b*t`` with ``d`` the last non-zero remainder, i.e. gcd(a, b).

    Args:
        r0, r1: the two most recent remainders (initially the inputs a, b).
        s0, s1, t0, t1: the matching Bezout coefficients.
        f: optional tracing callback, invoked once per division step as
            ``f(n, q, r, s, t)`` with the step number, the quotient and the
            divisor row ``(r, s, t)`` of that step.
        n: the step number reported for the first division.

    Returns:
        ``(d, s, t)``: the last non-zero remainder and its coefficients.
    """
    # gcd(r0, 0) = r0 = a*s0 + b*t0; without this the division below fails.
    if r1 == 0:
        return (r0, s0, t0)
    while True:
        q1 = r0 // r1
        if f:
            f(n, q1, r1, s1, t1)
        r2 = r0 - r1 * q1  # equals r0 % r1, since Python's // and % agree
        if r2 == 0:  # base
            return (r1, s1, t1)
        s2 = s0 - s1 * q1
        t2 = t0 - t1 * q1
        r0, r1, s0, s1, t0, t1 = r1, r2, s1, s2, t1, t2
        n += 1
def gcd(a, b):
    """Return gcd(a, b) as computed by ``exGCD``; the Bezout coefficients are discarded."""
    return exGCD(a, b, 1, 0, 0, 1)[0]
def coprime(a, b):
    """Return True iff ``gcd(a, b)`` equals 1."""
    divisor = gcd(a, b)
    return divisor == 1
# whether pairwise relatively prime
def coprimes(*ns):
    """Return True iff every pair of numbers at distinct positions in ``ns`` is coprime.

    A single number is trivially pairwise coprime. Empty input fails the
    non-empty assertion, matching ``mgcd``.
    """
    assert(ns)  # not empty
    for i, a in enumerate(ns):
        # Only later positions (j > i): comparing an element with itself gives
        # gcd(n, n) == n and would wrongly fail for every n != 1.
        for b in ns[i + 1:]:
            if not coprime(a, b):
                return False
    return True
def mgcd(ns, f=None):
    """Return the gcd of every number in ``ns`` by splitting it in halves recursively.

    Args:
        ns: non-empty sequence of integers.
        f: optional callback, called as ``f(slice, g)`` after the gcds of the
            two halves of ``slice`` are merged into ``g`` (left half first).
    """
    assert(ns)  # not empty
    if len(ns) > 1:
        half = len(ns) // 2
        left_gcd = mgcd(ns[:half], f)
        right_gcd = mgcd(ns[half:], f)
        merged = gcd(left_gcd, right_gcd)
        if f:
            f(ns, merged)
        return merged
    return ns[0]
def muprimes(*ns):
    """Return True iff the numbers are mutually coprime, i.e. the gcd of all of them is 1."""
    overall = mgcd(ns)
    return overall == 1
def calcBezout(a, b):
    """Print the extended-Euclid table for (a, b) and return ``(d, s, t)`` with ``d = a*s + b*t``."""
    def row(n, q, r, s, t):
        print(f'>> {n:>3} {q:>5} {r:>5} {s:>5} {t:>5}')

    print(f'calculating gcd({a},{b}) = {a} s + {b} t')
    row("n", "q", "r", "s", "t")  # column header
    row(0, 0, a, 1, 0)            # seed row: a = a*1 + b*0
    result = exGCD(a, b, 1, 0, 0, 1, row)
    d, s, t = result
    print(f'>> gcd({a},{b}) = {d} = {s} a + {t} b')
    return d, s, t
from math import log2, floor | |
def calcGCD(*ns):
    """Compute and print gcd(*ns), showing each partial gcd merged by ``mgcd``.

    Args:
        *ns: the integers; at least one is required.

    Returns:
        The gcd of all arguments as computed by ``mgcd``.
    """
    assert(ns)  # not empty

    def fmt(vals):
        return ", ".join(str(v) for v in vals)

    def show(part, g):
        # mgcd calls this after merging the gcds of the two halves of `part`
        print(f'> gcd({fmt(part)}) = {g}')

    print(f'calculating gcd({fmt(ns)})')
    ret = mgcd(ns, show)
    print(f'>> gcd({fmt(ns)}) = {ret}')
    return ret
def zipWith(f, *lss):
    """Return ``[f(a, b, ...)]`` over the lists ``lss`` walked in lockstep (stops at the shortest)."""
    out = []
    for group in zip(*lss):
        out.append(f(*group))
    return out
def mod_op(op, ms, *lss):
    """Apply ``op`` element-wise across ``lss``, then reduce each result modulo the matching ``ms`` entry."""
    raw = zipWith(op, *lss)
    return zipWith(lambda modulus, value: value % modulus, ms, raw)
def int2crt(ms, n):
    """Return the CRT residue list of ``n``, i.e. ``[n % m for m in ms]``."""
    copies = [n] * len(ms)  # one copy of n per modulus, reduced unchanged by `_`
    return mod_op(_, ms, copies)
def convert_int2crt_(ms, n):
    """Announce, then return, the CRT residue list of ``n`` for the moduli ``ms``."""
    print(f'convert {n} into CRT tuple,')
    residues = int2crt(ms, n)
    return residues


# keyword form: convert_int2crt(name=(ms, n)) binds the result to global `name`
convert_int2crt = assign_wrap(convert_int2crt_)
def mod_sum(ms, *lss):
    """Add CRT residue lists element-wise, reducing each sum modulo its modulus."""
    def add_all(*terms):
        return sum(terms)
    return mod_op(add_all, ms, *lss)
from operator import neg | |
def mod_neg(ms, xs):
    """Negate a CRT residue list element-wise, reducing each result modulo its modulus."""
    return mod_op(lambda value: -value, ms, xs)
def mod_inv(m, x):
    """Return the multiplicative inverse of ``x`` modulo ``m``, as a value in ``[0, m)``.

    Args:
        m: the modulus (assumed positive).
        x: the value to invert; any integer, reduced modulo ``m`` first.

    Raises:
        ValueError: if ``x`` has no inverse modulo ``m`` (gcd(m, x) != 1).
    """
    orig_x = x
    # Reduce first: with a negative x the remainder sequence can end on -1,
    # which would make an invertible value look non-invertible.
    x %= m
    if x == 0:
        # gcd(m, 0) = m, so 0 is only invertible modulo 1 (where everything is 0).
        if m == 1:
            return 0
        raise ValueError(f'{orig_x} has no inverse modulo {m}')
    d, _s, t = exGCD(m, x, 1, 0, 0, 1)
    if d != 1:
        raise ValueError(f'{orig_x} has no inverse modulo {m}')
    return t % m  # % with a positive modulus is never negative in Python
# the remainder operator (%) in Python takes the sign of the divisor, so `x % m` is never negative for a positive modulus m
from operator import mul, add | |
from functools import reduce | |
from math import log10 | |
def crt2int(ms, xs):
    """Reconstruct the integer whose CRT residues are ``xs``, printing the work as a table.

    Uses ``sum(x_i * M_i * inv(M_i mod m_i)) mod M``, where ``M`` is the product
    of all moduli and ``M_i = M // m_i``. One row is printed per modulus,
    followed by the final reduction.

    Args:
        ms: the moduli; they must be pairwise coprime, otherwise ``mod_inv``
            rejects some ``M_i``.
        xs: the residues, one per modulus.

    Returns:
        The integer in ``[0, M)`` congruent to ``xs[i]`` modulo ``ms[i]`` for all i.
    """
    M_ = reduce(mul, ms)
    Ms = [M_ // m for m in ms]
    # Column widths from decimal string length: floor(log10(v)) + 1 raised a
    # math domain error for a zero residue and can be off by one for big ints.
    xs_width = max(len(str(x)) for x in xs)
    Ms_width = max(len(str(M)) for M in Ms)
    ms_width = max(len(str(m)) for m in ms)  # inverses are < m, so this fits

    def printrow(x, M, m, term=None):
        print(f'{str(x).rjust(xs_width)} * ' +
              f'{str(M).rjust(Ms_width)} * ' +
              f'{str(m).ljust(ms_width)}' +
              # `is not None` so that a zero term is still shown
              (f' = {str(term)}' if term is not None else ""))

    printrow("xs", "Ms", "Ms-1")

    def digest(x, M, m):
        m_ = mod_inv(m, M)  # inverse of M_i modulo m_i
        ret = x * M * m_
        printrow(x, M, m_, ret)
        return ret

    sums = sum(zipWith(digest, xs, Ms, ms))
    ret = sums % M_
    print(f' => {sums} mod {M_} = {ret}')
    return ret
def convert_crt2int_(ms, xs):
    """Announce, then return, the integer reconstructed from the CRT residues ``xs``."""
    print(f'convert {xs} to integer')
    value = crt2int(ms, xs)
    return value


# keyword form: convert_crt2int(name=(ms, xs)) binds the result to global `name`
convert_crt2int = assign_wrap(convert_crt2int_)
# =================== | |
# algorithm | |
# =================== | |
# Demonstration: add and subtract two 20-digit integers residue-wise, then
# reconstruct both results with the CRT. NOTE: for_variables, using_variables
# and the convert_* wrappers bind module globals (x, y, ms, xs, ys, sum_int,
# subs) as a side effect, and later statements read those globals.
for_variables(x=12345678901234567890, y=98765432109876543210)
# alternative: ms = [10007, 10009, 10037, 10039, 10061]
using_variables(ms=[10007, 10009, 10037, 10039, 10061, 10067])

hmm()
convert_int2crt(xs=(ms, x))
hmm()
convert_int2crt(ys=(ms, y))

then()
sums = mod_sum(ms, xs, ys)               # residues of x + y
subs = mod_sum(ms, mod_neg(ms, xs), ys)  # residues of y - x
convert_crt2int(sum_int=(ms, sums))
hmm()
convert_crt2int(subs=(ms, subs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment