Last active
December 10, 2023 23:47
-
-
Save mildsunrise/e21ae2b1649532813f2594932f9e9371 to your computer and use it in GitHub Desktop.
Integer (and polynomial) modular arithmetic for Python!
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
INTEGER MODULAR ARITHMETIC | |
These functions implement modular arithmetic on integers (Z/nZ).
As an implied precondition, parameters are assumed to be integers unless otherwise noted.
This code is not constant-time (its running time depends on the input values), so it is NOT safe to use for online cryptography.
""" | |
from typing import Iterable, Tuple, NamedTuple | |
from functools import reduce | |
# Descriptive alias for integers that callers promise are never negative.
# It only documents intent; nothing checks it at runtime.
Natural = int
class Congruence(NamedTuple):
    """ an (x, m) tuple describing the congruence: a = x (mod m) """
    x: int
    modulus: int
    # FIXME: modulus must not be 0. typing.NamedTuple does not allow overriding
    # __new__/__init__, so this cannot be validated on construction; callers
    # must uphold it.

    def matches(self, a: int) -> bool:
        """ check if an integer a satisfies this congruence """
        return congruent(a, self.x, self.modulus)

    def normalized(self) -> 'Congruence':
        """ return the equivalent normalized congruence
        (positive modulus and 0 <= x < modulus) """
        modulus = abs(self.modulus)
        return Congruence(self.x % modulus, modulus)

    def is_normalized(self) -> bool:
        """ check if this congruence is normalized """
        return self.modulus > 0 and 0 <= self.x < self.modulus

    @staticmethod
    def trivial() -> 'Congruence':
        """ return the normalized identity congruence (matches all integers) """
        return Congruence(0, 1)

    # useful operations

    def intersect(self, a: 'Congruence') -> 'Congruence':
        """ Intersection of two congruences.
        Returns the solution in the form of an x congruence: an integer k
        satisfies both self and a exactly when x.matches(k).
        Raises ValueError if there are no solutions.
        Precondition: self.modulus > 0 (mod_div documents modulus > 0)
        Postcondition: If self.is_normalized() and a.is_normalized(), then x.is_normalized() """
        # Every solution of `a` has the form k = a.x + a.modulus * t. Substituting
        # into `self` gives a.modulus * t = self.x - a.x (mod self.modulus),
        # which mod_div solves for t.
        t = mod_div(self.x - a.x, a.modulus, self.modulus)
        return Congruence(a.x + a.modulus * t.x, a.modulus * t.modulus)

    @staticmethod
    def intersection(xs: Iterable['Congruence']) -> 'Congruence':
        """ Intersection of N congruences (see intersect()).
        An empty iterable yields the trivial congruence. """
        return reduce(Congruence.intersect, xs, Congruence.trivial())
def gcd(a: int, b: int) -> int:
    """ Euclidean algorithm (iterative).
    Returns the Greatest Common Divisor of a and b, always nonnegative
    (gcd(0, 0) == 0), matching math.gcd. """
    while b:
        a, b = b, a % b
    # Python's % gives the result the divisor's sign, so the loop can end on
    # a negative value when an argument is negative (e.g. gcd(4, -3) -> -1).
    return abs(a)
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    """ Extended Euclidean algorithm (iterative).
    Returns (d, x, y) where d is the (positive) Greatest Common Divisor of a and b.
    x, y are integers that satisfy: a*x + b*y = d
    Precondition: b != 0 (ZeroDivisionError otherwise)
    Postcondition (for a != 0): abs(x) <= abs(b//d) and abs(y) <= abs(a//d) """
    # Each row is (remainder, coefficient of a, coefficient of b), maintaining
    # remainder == a*coef_a + b*coef_b for both rows throughout.
    a = (a, 1, 0)
    b = (b, 0, 1)
    while True:
        q, r = divmod(a[0], b[0])
        if not r:
            d, x, y = b
            # divmod gives remainders the divisor's sign, so every remainder
            # here has the sign of the original b and d < 0 when b < 0.
            # Negating all three keeps a*x + b*y == d while making d positive.
            return (d, x, y) if d > 0 else (-d, -x, -y)
        a, b = b, (r, a[1] - q*b[1], a[2] - q*b[2])
def mod_pow(x: int, exponent: Natural, modulus: int) -> int:
    """ Modular exponentiation by squaring (iterative).
    Returns: (x**exponent) % modulus   (same as pow(x, exponent, modulus))
    Precondition: exponent >= 0 and modulus > 0
    Precondition: not (x == 0 and exponent == 0) """
    factor = x % modulus
    # Start from 1 reduced mod `modulus`: for modulus == 1 every result is 0,
    # including exponent == 0 where the loop never runs.
    result = 1 % modulus
    while exponent:
        if exponent & 1:
            result = (result * factor) % modulus
        factor = (factor * factor) % modulus
        exponent >>= 1
    return result
def mod_inv(a: int, modulus: int) -> int:
    """ Modular multiplicative inverse.
    Returns b so that: (a * b) % modulus == 1   (for modulus > 1)
    Raises ValueError if a and modulus are not coprime (no inverse exists).
    Precondition: modulus > 0
    Postcondition: 0 <= b < modulus, and b > 0 whenever modulus > 1 """
    d, x, _ = egcd(a, modulus)
    if d != 1:
        # an explicit raise (unlike assert) still fires under `python -O`
        raise ValueError('modular inverse does not exist')
    # bring the Bezout coefficient into the canonical range [0, modulus)
    return x % modulus
def mod_div(a: int, b: int, modulus: int) -> Congruence:
    """ Solve b * k = a (mod modulus) for k, i.e. modular division a / b.
    The answer is a normalized Congruence x with
        congruent(b * k, a, modulus) == x.matches(k)   for every integer k.
    A solution exists only when g = gcd(b, modulus) divides a; otherwise
    ValueError is raised.
    Precondition: modulus > 0 """
    g, b_coef, _ = egcd(b, modulus)
    if a % g:
        raise ValueError('invalid modular division')
    # b * b_coef = g (mod modulus), so k = (a / g) * b_coef solves the equation,
    # and it is unique modulo modulus / g.
    particular = (a // g) * b_coef
    return Congruence(particular, modulus // g).normalized()
def isqrt(n: int) -> int:
    """ Integer square root (Newton's method) (copied from cpython).
    Returns greatest x so that x*x <= n.
    Precondition: n >= 0 (checked with an assert) """
    assert n >= 0
    if n == 0: return 0
    # c = floor(log2(n)) // 2, so 4**c <= n < 4**(c+1): the root has c+1 bits.
    c = (n.bit_length() - 1) // 2
    a = 1
    d = 0
    # Each pass advances d to the next binary prefix of c (d = c >> s) and
    # refines a with a Newton-style step on the top bits of n.
    # Loop invariant (per CPython's math.isqrt notes):
    #     (a - 1)**2 < (n >> 2*(c - d)) < (a + 1)**2
    for s in reversed(range(c.bit_length())):
        e = d
        d = c >> s
        a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
    # After the loop d == c, so a is isqrt(n) or isqrt(n) + 1; correct downward.
    return a - (a*a > n)
def congruent(a: int, b: int, modulus: int) -> bool:
    """ True when a and b leave the same remainder modulo `modulus`,
    i.e. when modulus divides a - b.
    Precondition: modulus > 0 """
    difference = a - b
    return not difference % modulus
def coprime(a: int, b: int) -> bool:
    """ Checks if a and b are coprime (their only common divisors are 1 and -1).
    Works for negative arguments too, e.g. coprime(4, -3) is True. """
    # gcd() may yield a negative value for negative arguments (Python's %
    # follows the divisor's sign); abs() makes the check sign-independent.
    return abs(gcd(a, b)) == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
BINARY POLYNOMIAL ARITHMETIC | |
These functions operate on binary polynomials (Z/2Z[x]), expressed as coefficient bitmasks, e.g.:
0b100111 -> x^5 + x^2 + x + 1
As an implied precondition, parameters are assumed to be *nonnegative* integers unless otherwise noted.
This code is not constant-time (its running time depends on the input values), so it is NOT safe to use for online cryptography.
""" | |
from typing import Tuple | |
# Descriptive aliases (documentation only, never checked at runtime):
#   Natural     -- an int assumed to be nonnegative
#   BPolynomial -- a binary polynomial encoded as a nonnegative int bitmask
Natural = BPolynomial = int
def p_mul(a: BPolynomial, b: BPolynomial) -> BPolynomial:
    """ Multiply two binary polynomials (carry-less "peasant" multiplication).
    For every set bit i of a, the shifted copy b * x^i is XOR-added into
    the product, since addition over GF(2) is XOR. """
    product = 0
    multiplier, addend = a, b
    while multiplier and addend:
        if multiplier & 1:
            product ^= addend
        multiplier, addend = multiplier >> 1, addend << 1
    return product
def p_mod(a: BPolynomial, b: BPolynomial) -> BPolynomial:
    """ Binary polynomial remainder / modulus.
    Divides a by b and returns the resulting remainder polynomial, whose
    degree is lower than that of b.
    Raises ZeroDivisionError if b == 0 (mirroring int % 0). """
    if not b:
        # Without this guard the loop never ends: b << shift is always 0,
        # so a never shrinks.
        raise ZeroDivisionError('polynomial division by zero')
    bl = b.bit_length()
    while True:
        shift = a.bit_length() - bl
        if shift < 0:
            return a
        # cancel a's leading term with b aligned underneath it
        a ^= b << shift
def p_divmod(a: BPolynomial, b: BPolynomial) -> Tuple[BPolynomial, BPolynomial]:
    """ Binary polynomial division.
    Divides a by b and returns resulting (quotient, remainder) polynomials,
    where the remainder has lower degree than b.
    Raises ZeroDivisionError if b == 0 (mirroring divmod(n, 0)). """
    if not b:
        # Without this guard the loop never ends: b << shift is always 0,
        # so a never shrinks.
        raise ZeroDivisionError('polynomial division by zero')
    q = 0
    bl = b.bit_length()
    while True:
        shift = a.bit_length() - bl
        if shift < 0:
            return (q, a)
        # record quotient term x^shift and cancel a's leading term
        q ^= 1 << shift
        a ^= b << shift
def p_mod_mul(a: BPolynomial, b: BPolynomial, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular multiplication (peasant).
    Returns p_mod(p_mul(a, b), modulus).
    a and b may have any degree; b is reduced first when needed.
    Raises ZeroDivisionError if modulus == 0. """
    if not modulus:
        raise ZeroDivisionError('polynomial division by zero')
    deg = p_degree(modulus)
    if p_degree(b) >= deg:
        # The loop below needs p_degree(b) < deg so one XOR per step keeps b
        # reduced. This used to be an assert, which `python -O` strips,
        # leaving a silently wrong result.
        b = p_mod(b, modulus)
    result = 0
    while a and b:
        if a & 1:
            result ^= b
        a >>= 1
        b <<= 1
        # b just gained one degree; if it reached deg(modulus), reduce it
        if (b >> deg) & 1:
            b ^= modulus
    return result
def p_exp(a: BPolynomial, exponent: Natural) -> BPolynomial:
    """ Binary polynomial exponentiation by squaring (iterative).
    Returns polynomial `a` multiplied by itself `exponent` times
    (exponent == 0 yields 1).
    Precondition: not (a == 0 and exponent == 0) """
    factor = a
    result = 1
    while exponent:
        if exponent & 1:
            result = p_mul(result, factor)
        exponent >>= 1
        if exponent:
            # Skip the final squaring: it would never be used, and it is the
            # most expensive multiplication (factor has its highest degree).
            factor = p_mul(factor, factor)
    return result
def p_gcd(a: BPolynomial, b: BPolynomial) -> BPolynomial:
    """ Greatest common divisor of binary polynomials a and b, using the
    iterative Euclidean algorithm. p_gcd(a, 0) == a. """
    while b != 0:
        remainder = p_mod(a, b)
        a = b
        b = remainder
    return a
def p_egcd(a: BPolynomial, b: BPolynomial) -> Tuple[BPolynomial, BPolynomial, BPolynomial]:
    """ Binary polynomial Extended Euclidean algorithm (iterative).
    Returns (d, x, y) where d is the Greatest Common Divisor of polynomials a and b.
    x, y are polynomials that satisfy: p_mul(a,x) ^ p_mul(b,y) == d
    Precondition: b != 0
    Postcondition (for a != 0): x <= p_divmod(b,d)[0] and y <= p_divmod(a,d)[0] """
    # Each row is (remainder, coefficient of a, coefficient of b), maintaining
    # remainder == p_mul(a, coef_a) ^ p_mul(b, coef_b). Over GF(2) subtraction
    # is XOR, hence the ^ in the coefficient update.
    a = (a, 1, 0)
    b = (b, 0, 1)
    while True:
        q, r = p_divmod(a[0], b[0])
        if not r:
            return b
        a, b = b, (r, a[1] ^ p_mul(q, b[1]), a[2] ^ p_mul(q, b[2]))
def p_mod_inv(a: BPolynomial, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular multiplicative inverse.
    Returns b so that: p_mod(p_mul(a, b), modulus) == 1
    Raises ValueError if a and modulus are not coprime (no inverse exists).
    Precondition: modulus != 0
    Postcondition: b < modulus """
    d, x, _ = p_egcd(a, modulus)
    if d != 1:
        # an explicit raise (unlike assert) still fires under `python -O`
        raise ValueError('modular inverse does not exist')
    return x
def p_mod_pow(x: BPolynomial, exponent: Natural, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular exponentiation by squaring (iterative).
    Returns: p_mod(p_exp(x, exponent), modulus)
    Precondition: modulus > 0
    Precondition: not (x == 0 and exponent == 0) """
    # p_mod_mul requires its second operand to be reduced, so reduce x first.
    factor = p_mod(x, modulus)
    # 1 reduced mod `modulus`: this is 0 when modulus == 1, so the
    # exponent == 0 case agrees with p_mod(p_exp(x, 0), modulus) there too.
    result = p_mod(1, modulus)
    while exponent:
        if exponent & 1:
            result = p_mod_mul(result, factor, modulus)
        factor = p_mod_mul(factor, factor, modulus)
        exponent >>= 1
    return result
def p_degree(a: BPolynomial) -> int:
    """ Degree of binary polynomial a: the index of its highest set bit.
    By this convention the zero polynomial has degree -1. """
    significant_bits = a.bit_length()
    return significant_bits - 1
def p_congruent(a: BPolynomial, b: BPolynomial, modulus: BPolynomial) -> bool:
    """ True when a and b are congruent modulo `modulus`, i.e. when their
    difference (a XOR b over GF(2)) leaves no remainder.
    Precondition: modulus > 0 """
    difference = a ^ b
    return not p_mod(difference, modulus)
def p_coprime(a: BPolynomial, b: BPolynomial) -> bool:
    """ True when binary polynomials a and b share no common factor but 1. """
    common_factor = p_gcd(a, b)
    return common_factor == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
UTILITIES | |
""" | |
from typing import Iterator, Iterable | |
def to_bits(n: int) -> Iterator[int]:
    """ Generates the positions of the bits of n that are 1, in ascending order.
    Raises ValueError (on first iteration) if n is negative: in two's
    complement it has infinitely many set bits, and the loop below would
    never terminate because -1 >> 1 == -1. """
    if n < 0:
        raise ValueError('negative number has infinitely many set bits')
    bit = 0
    while n:
        if n & 1:
            yield bit
        bit += 1
        n >>= 1
def from_bits(bits: Iterable[int], strict=False) -> int:
    """ Build an integer whose set bits are exactly the given positions.
    A negative position raises ValueError (from the shift itself).
    With strict=True, a position appearing more than once raises ValueError. """
    value = 0
    for position in bits:
        flag = 1 << position
        if strict and value & flag:
            raise ValueError('duplicated bit')
        value |= flag
    return value
def reverse_bits(n, width):
    """ Mirror the low `width` bits of n: bit i moves to bit width-1-i.
    Precondition: 0 <= n < (1 << width); a violation trips the assert. """
    mirrored = 0
    for _ in range(width):
        mirrored = (mirrored << 1) | (n & 1)
        n >>= 1
    assert not n
    return mirrored
def polynomial_str(n: int, variable: str="x", unicode: bool=False, separator: str=" + ", constant: str="1") -> str:
    """ Formats binary polynomial 'n' as a nice string, e.g. "x^10 + x^4 + x + 1".
    Terms are listed from highest to lowest degree; the degree-0 term is
    rendered as `constant` and the degree-1 term as the bare `variable`.
    If unicode=True, then superscript digits will be used instead of ^n notation.
    NOTE(review): n == 0 currently renders as an empty string. """
    superscripts = str.maketrans("0123456789", "⁰¹²³⁴⁵⁶⁷⁸⁹")

    def format_term(bit: int) -> str:
        if bit == 0:
            return constant
        if bit == 1:
            return variable
        exponent = str(bit)
        if unicode:
            return variable + exponent.translate(superscripts)
        return variable + "^" + exponent

    terms = [format_term(bit) for bit in sorted(to_bits(n), reverse=True)]
    return separator.join(terms)
def bits(*positions: int) -> int:
    """ Build an int with exactly the given bit positions set, e.g. bits(0, 2) == 0b101.
    Uses from_bits(..., strict=True), so duplicate or negative positions
    raise ValueError. """
    return from_bits(positions, strict=True)


def bit_str(n: int, width: int = 0) -> str:
    """ Binary digits of abs(n) (the sign is dropped), left-padded with
    zeros to at least `width` characters. """
    digits = format(abs(n), "b")
    return digits.rjust(width, "0")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
thank you @mildsunrise .