Skip to content

Instantly share code, notes, and snippets.

@mildsunrise
Last active December 10, 2023 23:47
Show Gist options
  • Save mildsunrise/e21ae2b1649532813f2594932f9e9371 to your computer and use it in GitHub Desktop.
Save mildsunrise/e21ae2b1649532813f2594932f9e9371 to your computer and use it in GitHub Desktop.
Integer (and polynomial) modular arithmetic for Python!
"""
INTEGER MODULAR ARITHMETIC
These functions implement integer modular arithmetic (Z/nZ).
As an implied precondition, parameters are assumed to be integers unless otherwise noted.
This code is not constant-time (its running time depends on the inputs), so it is NOT safe to use for online cryptography.
"""
from typing import Iterable, Tuple, NamedTuple
from functools import reduce
# descriptive aliases (assumed not to be negative)
# Natural is plain `int` at runtime: it only documents that a value is
# expected to be >= 0; nothing checks this.
Natural = int
class Congruence(NamedTuple):
    """ an (x, m) tuple describing the congruence: a = x (mod m) """
    x: int
    modulus: int
    # FIXME: in constructor, assert modulus != 0
    # (currently unchecked: with modulus == 0, matches() and normalized()
    # raise ZeroDivisionError because they reduce modulo it)

    def matches(self, a: int) -> bool:
        """ check if an integer a satisfies this congruence """
        return congruent(a, self.x, self.modulus)

    def normalized(self) -> 'Congruence':
        """ return the equivalent normalized congruence:
        positive modulus and 0 <= x < modulus """
        modulus = abs(self.modulus)
        return Congruence(self.x % modulus, modulus)

    def is_normalized(self) -> bool:
        """ check if this congruence is normalized """
        return self.modulus > 0 and 0 <= self.x < self.modulus

    @staticmethod
    def trivial() -> 'Congruence':
        """ return the normalized identity congruence (matches all integers) """
        return Congruence(0, 1)

    # useful operations

    def intersect(self, a: 'Congruence') -> 'Congruence':
        """ Intersection of two congruences.
        Returns the solution in the form of an x congruence: the integers
        matching both self and a are exactly those matching x.
        Raises ValueError if there are no solutions.
        Postcondition: If self.is_normalized() and a.is_normalized(), then x.is_normalized() """
        # Write a solution as k = a.x + a.modulus * t, then solve
        # a.modulus * t = self.x - a.x (mod self.modulus) for t.
        t = mod_div(self.x - a.x, a.modulus, self.modulus)
        # t repeats every t.modulus, so k repeats every a.modulus * t.modulus
        return Congruence(a.x + a.modulus * t.x, a.modulus * t.modulus)

    @staticmethod
    def intersection(xs: Iterable['Congruence']) -> 'Congruence':
        """ Intersection of N congruences (see intersect()).
        Returns trivial() when xs is empty. """
        return reduce(Congruence.intersect, xs, Congruence.trivial())
def gcd(a: int, b: int) -> int:
    """ Euclidean algorithm (iterative).
    Returns the Greatest Common Divisor of a and b. The result is always
    >= 0, like math.gcd; gcd(0, 0) == 0. """
    while b:
        a, b = b, a % b
    # Python's % takes the sign of the divisor, so with negative inputs the
    # last nonzero remainder can be negative; the GCD itself is nonnegative.
    return abs(a)
def egcd(a: int, b: int) -> Tuple[int, int, int]:
    """ Extended Euclidean algorithm (iterative).
    Returns (d, x, y) with a*x + b*y == d, where d is a Greatest Common
    Divisor of a and b (d may be negative when b is negative).
    Precondition: b != 0
    Postcondition (for a != 0): abs(x) <= abs(b//d) and abs(y) <= abs(a//d) """
    # Invariant for both rows (r, s, t): a*s + b*t == r
    prev_r, prev_s, prev_t = a, 1, 0
    r, s, t = b, 0, 1
    while True:
        q, rem = divmod(prev_r, r)
        if rem == 0:
            return (r, s, t)
        prev_r, prev_s, prev_t, r, s, t = r, s, t, rem, prev_s - q * s, prev_t - q * t
def mod_pow(x: int, exponent: Natural, modulus: int) -> int:
    """ Modular exponentiation by squaring (iterative).
    Returns: (x**exponent) % modulus, i.e. the same value as pow(x, exponent, modulus)
    Precondition: exponent >= 0 and modulus > 0
    Precondition: not (x == 0 and exponent == 0) """
    factor = x % modulus
    # Start from 1 reduced mod modulus: for modulus == 1 every value is 0,
    # including x**0, so the result must be 0 even when the loop never runs.
    result = 1 % modulus
    while exponent:
        if exponent & 1:
            result = (result * factor) % modulus
        factor = (factor * factor) % modulus
        exponent >>= 1
    return result
def mod_inv(a: int, modulus: int) -> int:
    """ Modular multiplicative inverse.
    Returns b with (a * b) % modulus == 1, derived from the Bezout
    coefficient of a in egcd(a, modulus).
    Precondition: modulus > 0 and coprime(a, modulus)
    Postcondition: 0 < b < modulus """
    g, coeff, _ = egcd(a, modulus)
    assert g == 1 # inverse exists
    # shift a non-positive coefficient into the positive range
    if coeff <= 0:
        coeff += modulus
    return coeff
def mod_div(a: int, b: int, modulus: int) -> Congruence:
    """ Modular division (a / b).
    Returns the solutions k of b*k = a (mod modulus) as a congruence x:
    for any k, congruent(b * k, a, modulus) == x.matches(k)
    Raises ValueError if there are no solutions (gcd(b, modulus) does not divide a).
    Precondition: modulus > 0
    Postcondition: x.is_normalized() """
    g, bezout_b, _ = egcd(b, modulus)
    quotient, remainder = divmod(a, g)
    if remainder != 0:
        raise ValueError('invalid modular division')
    # b * bezout_b = g (mod modulus), so quotient * bezout_b solves b*k = a;
    # solutions repeat every modulus // g.
    solution = Congruence(quotient * bezout_b, modulus // g)
    return solution.normalized()
def isqrt(n: int) -> int:
    """ Integer square root (Newton's method) (copied from cpython).
    Returns greatest x so that x*x <= n.
    Raises AssertionError if n < 0 (an assert, so unchecked under python -O).
    Precondition: n >= 0 """
    assert n >= 0
    if n == 0: return 0
    # c = floor(log2(n) / 2), hence 2**c <= isqrt(n) < 2**(c+1)
    c = (n.bit_length() - 1) // 2
    a = 1
    d = 0
    # Walk the bits of c from most significant down: d is the current prefix
    # c >> s and e the previous one; a is refined at each step.
    for s in reversed(range(c.bit_length())):
        e = d
        d = c >> s
        # NOTE: the shift counts are `d - e - 1` and `2*c - e - d + 1`
        # (+ and - bind tighter than << and >>). Per the CPython algorithm this
        # mirrors, this is one Newton-style refinement step.
        a = (a << d - e - 1) + (n >> 2*c - e - d + 1) // a
    # Final correction: subtract 1 (the bool True) if the estimate overshot.
    return a - (a*a > n)
def congruent(a: int, b: int, modulus: int) -> bool:
    """ Return True when a and b are congruent under modulus,
    i.e. modulus divides a - b.
    Precondition: modulus > 0 """
    difference = a - b
    return not difference % modulus
def coprime(a: int, b: int) -> bool:
    """ Checks if a and b are coprime, i.e. their only common divisors are 1 and -1.
    Works for negative arguments too. """
    # abs() so the answer does not depend on the sign gcd() reports: with
    # Euclid's algorithm over Python's %, the last remainder can be negative
    # for negative inputs.
    return abs(gcd(a, b)) == 1
"""
BINARY POLYNOMIAL ARITHMETIC
These functions operate on binary polynomials (Z/2Z[x]), expressed as coefficient bitmasks, e.g.:
0b100111 -> x^5 + x^2 + x + 1
As an implied precondition, parameters are assumed to be *nonnegative* integers unless otherwise noted.
This code is not constant-time (its running time depends on the inputs), so it is NOT safe to use for online cryptography.
"""
from typing import Tuple
# descriptive aliases (assumed not to be negative)
# Both are plain `int` at runtime and only document intent; nothing checks it.
# BPolynomial: a binary polynomial as a coefficient bitmask, where bit i is
# the coefficient of x^i (e.g. 0b100111 is x^5 + x^2 + x + 1).
Natural = int
BPolynomial = int
def p_mul(a: BPolynomial, b: BPolynomial) -> BPolynomial:
    """ Binary polynomial multiplication (shift-and-XOR, "peasant").
    Returns the carry-less product: for each set bit i of a,
    b * x^i (b shifted left by i) is XORed into the result. """
    product = 0
    for i in range(a.bit_length()):
        if (a >> i) & 1:
            product ^= b << i
    return product
def p_mod(a: BPolynomial, b: BPolynomial) -> BPolynomial:
""" Binary polynomial remainder / modulus.
Divides a by b and returns resulting remainder polynomial.
Precondition: b != 0 """
bl = b.bit_length()
while True:
shift = a.bit_length() - bl
if shift < 0: return a
a ^= b << shift
def p_divmod(a: BPolynomial, b: BPolynomial) -> Tuple[BPolynomial, BPolynomial]:
    """ Binary polynomial division.
    Divides a by b and returns resulting (quotient, remainder) polynomials,
    with p_mul(quotient, b) ^ remainder == a and deg(remainder) < deg(b).
    Raises ZeroDivisionError if b == 0, like the builtin divmod().
    Precondition: b != 0 """
    if b == 0:
        # without this guard the loop below would never terminate
        raise ZeroDivisionError('binary polynomial division by zero')
    q = 0
    bl = b.bit_length()
    while True:
        shift = a.bit_length() - bl
        if shift < 0:
            return (q, a)
        # record the term x^shift in the quotient and cancel a's leading term
        q ^= 1 << shift
        a ^= b << shift
def p_mod_mul(a: BPolynomial, b: BPolynomial, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular multiplication (peasant).
    Returns p_mod(p_mul(a, b), modulus)
    Precondition: modulus != 0 and degree(b) < degree(modulus), where
    degree(p) == p.bit_length() - 1. Note that b < modulus is NOT enough:
    e.g. b=0b110 with modulus=0b111 has equal degree and is rejected by the
    assert. a may have any degree. """
    deg = modulus.bit_length() - 1  # degree of modulus (as p_degree computes it)
    assert b.bit_length() - 1 < deg  # b must already be reduced mod modulus
    result = 0
    while a and b:
        if a & 1:
            result ^= b
        a >>= 1
        b <<= 1
        # b is now (original b) * x^i; once it reaches degree deg, reduce it
        if (b >> deg) & 1:
            b ^= modulus
    return result
def p_exp(a: BPolynomial, exponent: Natural) -> BPolynomial:
    """ Binary polynomial exponentiation by squaring (iterative).
    Returns polynomial `a` multiplied by itself `exponent` times
    (the constant polynomial 1 when exponent == 0).
    Precondition: exponent >= 0 and not (a == 0 and exponent == 0) """
    acc = 1
    power = a  # a raised to 2**k after k iterations
    remaining = exponent
    while remaining:
        remaining, low_bit = divmod(remaining, 2)
        if low_bit:
            acc = p_mul(acc, power)
        power = p_mul(power, power)
    return acc
def p_gcd(a: BPolynomial, b: BPolynomial) -> BPolynomial:
    """ Binary polynomial euclidean algorithm (iterative).
    Returns the Greatest Common Divisor of polynomials a and b
    (a itself when b == 0). """
    x, y = a, b
    while y != 0:
        x, y = y, p_mod(x, y)
    return x
def p_egcd(a: BPolynomial, b: BPolynomial) -> Tuple[BPolynomial, BPolynomial, BPolynomial]:
    """ Binary polynomial Extended Euclidean algorithm (iterative).
    Returns (d, x, y) where d is the Greatest Common Divisor of polynomials a and b.
    x, y are polynomials that satisfy: p_mul(a,x) ^ p_mul(b,y) == d
    Precondition: b != 0
    Postcondition: if a != 0, then x <= p_divmod(b,d)[0] and y <= p_divmod(a,d)[0]
    (for a == 0 the result is (b, 0, 1)) """
    # Invariant for both rows (r, s, t): p_mul(a, s) ^ p_mul(b, t) == r
    # (in GF(2)[x] subtraction is XOR, so "r - q*r'" becomes "r ^ p_mul(q, r')").
    prev = (a, 1, 0)
    cur = (b, 0, 1)
    while True:
        q, r = p_divmod(prev[0], cur[0])
        if not r:
            return cur
        prev, cur = cur, (r, prev[1] ^ p_mul(q, cur[1]), prev[2] ^ p_mul(q, cur[2]))
def p_mod_inv(a: BPolynomial, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular multiplicative inverse.
    Returns the Bezout coefficient b of a from p_egcd(a, modulus), so that
    p_mod(p_mul(a, b), modulus) == 1
    Precondition: modulus != 0 and p_coprime(a, modulus)
    Postcondition: b < modulus """
    common, inverse, _ = p_egcd(a, modulus)
    assert common == 1 # inverse exists
    return inverse
def p_mod_pow(x: BPolynomial, exponent: Natural, modulus: BPolynomial) -> BPolynomial:
    """ Binary polynomial modular exponentiation by squaring (iterative).
    Returns: p_mod(p_exp(x, exponent), modulus)
    Precondition: modulus > 0
    Precondition: not (x == 0 and exponent == 0) """
    # factor must be reduced: p_mod_mul requires deg(b) < deg(modulus)
    factor = p_mod(x, modulus)
    # Reduce the initial 1 as well: for modulus == 1 every polynomial is 0,
    # so exponent == 0 must yield 0 rather than 1.
    result = p_mod(1, modulus)
    while exponent:
        if exponent & 1:
            result = p_mod_mul(result, factor, modulus)
        factor = p_mod_mul(factor, factor, modulus)
        exponent >>= 1
    return result
def p_degree(a: BPolynomial) -> int:
    """ Returns the degree of polynomial a: the index of its highest set bit,
    or -1 for the zero polynomial. """
    highest = a.bit_length()
    return highest - 1
def p_congruent(a: BPolynomial, b: BPolynomial, modulus: BPolynomial) -> bool:
    """ Return True when polynomials a and b are congruent under modulus,
    i.e. modulus divides a - b (which is a ^ b over GF(2)).
    Precondition: modulus > 0 """
    difference = a ^ b
    return not p_mod(difference, modulus)
def p_coprime(a: BPolynomial, b: BPolynomial) -> bool:
    """ Return True when polynomials a and b share no common factor,
    i.e. their GCD is the constant polynomial 1. """
    common = p_gcd(a, b)
    return common == 1
"""
UTILITIES
"""
from typing import Iterator, Iterable
def to_bits(n: int) -> Iterator[int]:
    """ Yields the positions of the 1 bits of n, lowest position first.
    For negative n the generator never ends (n >> 1 never reaches 0). """
    position = 0
    while n != 0:
        n, low = n >> 1, n & 1
        if low:
            yield position
        position += 1
def from_bits(bits: Iterable[int], strict=False) -> int:
    """ Builds an integer whose 1 bits are exactly the given positions.
    A negative position raises ValueError (from the negative shift count).
    With strict=True, a position that appears twice raises ValueError;
    otherwise duplicates are harmless. """
    value = 0
    for position in bits:
        flag = 1 << position
        if value & flag and strict:
            raise ValueError('duplicated bit')
        value |= flag
    return value
def reverse_bits(n, width):
    """ Returns the `width`-bit integer n with its bit order reversed
    (bit 0 moves to bit width-1, and so on).
    Raises AssertionError if n has any bit at position >= width.
    Precondition: 0 <= n < (1 << width) """
    reversed_value = 0
    remaining = n
    for _ in range(width):
        reversed_value = (reversed_value << 1) | (remaining & 1)
        remaining >>= 1
    assert not remaining
    return reversed_value
def polynomial_str(n: int, variable: str="x", unicode: bool=False, separator: str=" + ", constant: str="1") -> str:
    """ Formats binary polynomial 'n' as a nice string, e.g. "x^10 + x^4 + x + 1".
    Terms are listed from highest to lowest power. If unicode=True,
    superscript digits are used instead of ^n notation. The zero
    polynomial formats as an empty string. """
    superscripts = "⁰¹²³⁴⁵⁶⁷⁸⁹"

    def format_term(exponent: int) -> str:
        if exponent == 0:
            return constant
        if exponent == 1:
            return variable
        digits = str(exponent)
        if unicode:
            return variable + "".join(superscripts[int(ch)] for ch in digits)
        return variable + "^" + digits

    exponents = sorted(to_bits(n), reverse=True)
    return separator.join(format_term(e) for e in exponents)
def bits(*positions):
    """ Builds an integer from the given bit positions, rejecting duplicates
    (shorthand for from_bits(positions, strict=True)). """
    return from_bits(positions, strict=True)
def bit_str(n, width=0):
    """ Binary digits of abs(n) (the sign is dropped), left-padded with "0"
    to at least `width` characters. """
    digits = "{:b}".format(abs(n))
    return digits.rjust(width, "0")
@Anupam9830
Copy link

thank you @mildsunrise .

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment