Skip to content

Instantly share code, notes, and snippets.

@Sam-Belliveau
Last active November 28, 2022 16:53
Show Gist options
  • Save Sam-Belliveau/e169eee11f666f43df53d1806192cc19 to your computer and use it in GitHub Desktop.
Save Sam-Belliveau/e169eee11f666f43df53d1806192cc19 to your computer and use it in GitHub Desktop.
from math import log2, floor, ceil, sqrt
class Polynomial:
    """A polynomial in Z_n[X] / (X^r - 1).

    ``coeff[i]`` is the coefficient of X^i for 0 <= i < r, kept reduced
    mod n. Because X^r == 1 in this ring, every exponent is taken mod r.

    Instances are mutable (``add`` edits in place); since ``__eq__`` is
    defined without ``__hash__``, they are unhashable.
    """

    def __init__(self, r, n, coeff=None):
        """Create the zero polynomial, or one with the given coefficients.

        Args:
            r: degree of the modulus X^r - 1, i.e. the number of stored
                coefficients.
            n: modulus applied to every coefficient.
            coeff: optional iterable of exactly r integers (coefficient of
                X^i at index i). It is copied and reduced mod n, so the
                caller's list is never aliased or mutated by ``add``.

        Raises:
            ValueError: if ``coeff`` does not contain exactly r values.
        """
        self.r = r
        self.n = n
        if coeff is None:
            self.coeff = [0] * r
        else:
            self.coeff = [c % n for c in coeff]
            if len(self.coeff) != r:
                raise ValueError(
                    f"expected {r} coefficients, got {len(self.coeff)}")

    def add(self, power, coeff):
        """Add ``coeff * X^power`` in place and return self for chaining.

        ``power`` may be any integer; it is reduced mod r since X^r == 1.
        """
        i = power % self.r
        self.coeff[i] = (self.coeff[i] + coeff) % self.n
        return self

    def _check_compatible(self, other):
        """Raise ValueError unless ``other`` lives in the same ring (r, n)."""
        if self.r != other.r or self.n != other.n:
            raise ValueError("polynomials must share the same r and n")

    def __add__(a, b):
        """Return a + b mod (X^r - 1, n) as a new polynomial."""
        a._check_compatible(b)
        # __init__ reduces the sums mod n.
        return Polynomial(a.r, a.n, [ac + bc for ac, bc in zip(a.coeff, b.coeff)])

    def __mul__(a, b):
        """Return a * b mod (X^r - 1, n) as a new polynomial.

        Cyclic convolution: the product of X^i and X^j lands on
        X^((i + j) mod r). Zero coefficients are skipped, which makes
        products of sparse polynomials such as (X + a) much cheaper than
        the dense O(r^2) sum.
        """
        a._check_compatible(b)
        r = a.r
        out = [0] * r
        b_terms = [(j, bj) for j, bj in enumerate(b.coeff) if bj]
        for i, ai in enumerate(a.coeff):
            if ai:
                for j, bj in b_terms:
                    out[(i + j) % r] += ai * bj
        # __init__ reduces the accumulated sums mod n.
        return Polynomial(r, a.n, out)

    def __pow__(self, pow):
        """Return self**pow mod (X^r - 1, n) by square-and-multiply.

        The parameter keeps its original name (it shadows the builtin only
        inside this method). The result is always a new object, never
        ``self``. Uses O(log2(pow)) polynomial multiplications.

        Raises:
            ValueError: if ``pow`` is negative (no inverses in this ring).
        """
        if pow < 0:
            raise ValueError("negative exponents are not supported")
        result = Polynomial(self.r, self.n).add(0, 1)
        base = self
        while pow > 0:
            if pow & 1:
                result = result * base
            pow >>= 1
            if pow:  # skip the final, unused squaring
                base = base * base
        return result

    def __eq__(a, b):
        """Polynomials are equal when r, n and all coefficients match.

        Returns NotImplemented for non-Polynomial operands so Python can
        fall back to its default comparison instead of raising.
        """
        if not isinstance(b, Polynomial):
            return NotImplemented
        return a.r == b.r and a.n == b.n and a.coeff == b.coeff

    def __str__(self):
        """Render nonzero terms as ``c`` / ``c*X^p`` joined by `` + ``.

        The zero polynomial renders as ``0``.
        """
        terms = [f"{c}*X^{p}" if p != 0 else f"{c}"
                 for p, c in enumerate(self.coeff) if c != 0]
        return " + ".join(terms) if terms else "0"

    def __repr__(self):
        """Unambiguous constructor-style representation for debugging."""
        return f"{type(self).__name__}({self.r!r}, {self.n!r}, {self.coeff!r})"
def is_nth_power(n, r):
    """Return True if n == x**r for some integer x >= 1.

    Binary search for the integer r-th root: the candidate root is built one
    bit at a time, from bit ceil(log2(n) / r) down to bit 0, keeping each bit
    whose inclusion leaves candidate**r below n. Returns as soon as an exact
    match is found.

    Performs at most ceil(log2(n) / r) + 1 iterations, each computing one
    r-th power. n must be positive (math.log2 raises ValueError otherwise).
    """
    root = 0
    bit = ceil(log2(n) / r)
    while bit >= 0:
        trial = root + (1 << bit)
        value = trial ** r
        if value == n:
            return True
        if value < n:
            # Keep this bit: the true root is at least `trial`.
            root = trial
        bit -= 1
    return False
def is_power(n):
    """Return True if n == b**k for some integers b >= 2 and k >= 2.

    Anything below 4 (the smallest such power) is rejected immediately.
    Otherwise each exponent 2..ceil(log2(n)) is tried with is_nth_power;
    an exponent larger than log2(n) cannot work for a base >= 2. Makes at
    most ceil(log2(n)) - 1 calls to is_nth_power, stopping at the first hit.
    """
    if n < 4:
        return False
    max_exponent = ceil(log2(n))
    for exponent in range(2, max_exponent + 1):
        if is_nth_power(n, exponent):
            return True
    return False
def coprime(a, b):
    """Return True if integers a and b are coprime, i.e. gcd(a, b) == 1.

    Uses Euclid's algorithm, which leaves +/-gcd(a, b) in ``a``. Python's
    ``%`` takes the sign of the divisor, so a negative operand can leave a
    negative gcd. Comparing the absolute value therefore makes e.g.
    coprime(3, -2) and coprime(-1, 0) correctly True. coprime(0, 0) is False.
    Runs in O(log(min(|a|, |b|))) division steps.
    """
    while b:
        a, b = b, a % b
    return abs(a) == 1
def find_smallest_r(n):
    """Return the smallest r with gcd(n, r) == 1 and ord_r(n) > log2(n)^2.

    ord_r(n) is the multiplicative order of n modulo r. The condition is
    checked as "n^k mod r != 1 for every 1 <= k <= floor(log2(n)^2)".
    Lemma 4.3 of the AKS paper ("PRIMES is in P") guarantees such an r
    exists with r <= max(3, ceil(log2(n)^5)), so that bound is searched
    inclusively.

    Returns 0 for n < 2.

    Raises:
        RuntimeError: if no r up to the bound qualifies. Lemma 4.3 rules
            this out; it would indicate floating-point trouble in log2.
            Returning an unsuitable r instead would silently break AKS.
    """
    from math import gcd

    if n < 2:
        return 0
    log_n = log2(n)
    max_r = max(3, ceil(log_n ** 5))
    max_k = floor(log_n ** 2)
    for r in range(2, max_r + 1):  # the bound itself is a valid candidate
        # With gcd(n, r) == 1 and r >= 2, n^k mod r is never 0, so only
        # hitting 1 (order <= max_k) disqualifies r.
        if gcd(n, r) == 1 and all(
                pow(n, k, r) != 1 for k in range(1, max_k + 1)):
            return r
    raise RuntimeError(f"no suitable r <= {max_r} found for n = {n}")
def totient(r):
    """Return Euler's totient phi(r): the count of 1 <= i <= r coprime to r.

    Computed from r's prime factorisation as phi(r) = r * prod(1 - 1/p)
    over the distinct primes p dividing r. Trial division makes this
    O(sqrt(r)) steps instead of one gcd per integer up to r. In AKS, r is
    bounded by max(3, ceil(log2(n)^5)), so either approach is polynomial
    in log n.

    Returns 0 for r < 1, matching a count over an empty range.
    """
    if r < 1:
        return 0
    result = r
    remaining = r
    p = 2
    while p * p <= remaining:
        if remaining % p == 0:
            while remaining % p == 0:
                remaining //= p
            result -= result // p  # multiply by (1 - 1/p), exactly
        p += 1
    if remaining > 1:
        # Whatever is left is a single prime factor larger than sqrt(r).
        result -= result // remaining
    return result
def factoring_prime_test(n):
    """Return True if n is prime, using trial division.

    Used here as a reference to check aks_prime_test against. It takes
    O(sqrt(n)) divisions, which is exponential in the bit length of n.

    The loop bound ``i * i <= n`` uses exact integer arithmetic. It
    therefore stays correct for arbitrarily large n, where the old
    ``ceil(sqrt(n))`` could be imprecise or raise OverflowError
    (above ~1e308).
    """
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True
def aks_prime_test(n):
    """Return True if n is prime, using the Agrawal-Kayal-Saxena algorithm.

    Steps follow "PRIMES is in P" (logs are base 2):
      1. Reject perfect powers n == b**k (k >= 2).
      2. Find the smallest r with ord_r(n) > log2(n)^2.
      3. Reject n if some a with 2 <= a <= min(r, n - 1) divides n.
      4. If n <= r, step 3 was full trial division, so n is prime.
      5. For every 1 <= a <= floor(sqrt(phi(r)) * log2(n)), reject n unless
         (X + a)^n == X^n + a  mod (X^r - 1, n).
    """
    # Step 0: 0, 1 and negatives are not prime.
    if n < 2:
        return False
    # Step 1.
    if is_power(n):
        return False
    # Step 2.
    r = find_smallest_r(n)
    # Step 3: the paper checks every a <= r (capped at n - 1 so n itself
    # is never used as a divisor).
    for a in range(2, min(r, n - 1) + 1):
        if n % a == 0:
            return False
    # Step 4.
    if n <= r:
        return True
    # Step 5: the paper's upper bound is inclusive.
    limit = floor(sqrt(totient(r)) * log2(n))
    for a in range(1, limit + 1):
        # Defensive: when r satisfies step 2, limit < r, so step 3 already
        # ruled out any common factor with such an a.
        if not coprime(a, n):
            return False
        # (X + a)^n, computed mod (X^r - 1, n).
        lhs = Polynomial(r, n).add(0, a).add(1, 1) ** n
        # X^n + a; add() reduces the exponent n mod r.
        rhs = Polynomial(r, n).add(0, a).add(n, 1)
        if lhs != rhs:
            return False
    # Every step passed, so n is prime.
    return True
# Self-test: compare aks_prime_test with factoring_prime_test for every
# integer in 0..99999, in 1000 batches of 100. Prints a progress line per
# number, details for any disagreement, and "Passed!" for clean batches.
if __name__ == "__main__":
    batch_size = 100
    for batch in range(1000):
        start = batch * batch_size
        stop = start + batch_size
        print(f"Testing {start}...{stop - 1}: ")
        print(" - 0% Complete", end=" " * 10 + "\r")
        mismatches = 0
        for number in range(start, stop):
            # With batches of 100, the offset doubles as a percentage.
            print(f" - {number - start}% Complete", end=" " * 10 + "\r")
            expected = factoring_prime_test(number)
            actual = aks_prime_test(number)
            if expected == actual:
                continue
            # Useful debugging information
            mismatches += 1
            print(f" - Failed! [{number}] {' ' * 20}")
            print(f" - Factoring Prime Test: {expected}")
            print(f" - AKS Prime Test: {actual}")
        if not mismatches:
            print(f" - Passed! {' ' * 20}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment