Skip to content

Instantly share code, notes, and snippets.

@j2kun
Last active November 29, 2022 17:01
Show Gist options
  • Save j2kun/49a52731e05f3247ab6f9519ac193aa3 to your computer and use it in GitHub Desktop.
Save j2kun/49a52731e05f3247ab6f9519ac193aa3 to your computer and use it in GitHub Desktop.
Broken negacyclic polynomial multiplication based on nufhe docs https://github.com/nucypher/nufhe/blob/master/doc/source/implementation_details.rst
import numpy as np
import math
def primitive_nth_root(n):
    """Return exp(2*pi*i / n) as a Python complex number.

    This is a primitive n-th root of unity: its n-th power is 1 and no
    smaller positive power is.
    """
    angle = 2 * math.pi / n
    return complex(math.cos(angle), math.sin(angle))
def poly_mul(a, b):
    """Return ``a * b mod (x^n + 1)`` computed with a length-``n/2`` complex FFT.

    Coefficients are stored lowest degree first (``a[j]`` multiplies ``x^j``).

    Method ("folding" with a twist): let ``psi = exp(i*pi/n)`` be a primitive
    ``2n``-th root of unity and ``h = n // 2``. Because ``psi^((1 - 4k) * h) = i``,

        a(psi^(1 - 4k)) = sum_{j<h} (a[j] + i*a[j+h]) * psi^j * exp(-2*pi*i*j*k/h),

    which is exactly ``np.fft.fft`` of the folded, twisted vector. The points
    ``psi^(1 - 4k)`` (k = 0..h-1) are roots of ``x^n + 1`` containing one root
    from each complex-conjugate pair, which suffices for real coefficients.
    Pointwise products of these evaluations are evaluations of the negacyclic
    product ``c``, and inverting the same map yields ``c[j] + i*c[j+h]``.

    Args:
        a: 1-D array of ``n`` real coefficients; ``n`` must be even.
        b: 1-D array with the same shape as ``a``.

    Returns:
        The ``n`` coefficients of the negacyclic product, rounded to the nearest
        integer and cast to ``a.dtype`` (note: negative coefficients are not
        representable if ``a.dtype`` is unsigned).

    Raises:
        ValueError: if ``a`` and ``b`` are not 1-D arrays of the same even length.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    if a.ndim != 1 or a.shape != b.shape:
        raise ValueError(
            f"expected two 1-D arrays of equal length, got shapes {a.shape} and {b.shape}"
        )
    n = a.shape[0]
    if n % 2:
        raise ValueError(f"polynomial length must be even, got {n}")
    if n == 0:
        return np.zeros(0, dtype=a.dtype)

    half = n // 2
    # psi^j for j < n/2, evaluated directly rather than as repeated powers of
    # psi so rounding error does not accumulate for large n.
    twist = np.exp(1j * np.pi * np.arange(half) / n)

    def fold_and_transform(p):
        # The upper half enters with +i because psi^(h*(1-4k)) = +i; with -i
        # the FFT output would not be an evaluation of p at any root.
        return np.fft.fft((p[:half] + 1j * p[half:]) * twist)

    product_values = fold_and_transform(a) * fold_and_transform(b)
    # ifft undoes the DFT, leaving (c[j] + i*c[j+h]) * psi^j; multiplying by
    # conj(psi^j) = psi^-j (|psi| == 1) removes the twist.
    unfolded = np.fft.ifft(product_values) * np.conj(twist)
    coeffs = np.concatenate([unfolded.real, unfolded.imag])
    return np.round(coeffs).astype(a.dtype)
def _np_polymul(poly1, poly2):
# poly_mod represents the polynomial to divide by: x^N + 1, N = len(a)
poly_mod = np.zeros(len(poly1) + 1, np.uint32)
poly_mod[0] = 1
poly_mod[len(poly1)] = 1
# Reversing the list order because numpy polymul interprets the polynomial
# with higher-order coefficients first, whereas our code does the opposite
np_mul = np.polymul(list(reversed(poly1)), list(reversed(poly2)))
(_, np_poly_mod) = np.polydiv(np_mul, poly_mod)
np_pad = np.pad(
np_poly_mod, (len(poly1) - len(np_poly_mod), 0),
"constant",
constant_values=(0, 0))
return np.array(list(reversed(np_pad)), dtype=int)
if __name__ == "__main__":
    # Hand-checkable case:
    # (1 + 2x + 3x^2 + 4x^3)(2 + 3x + 4x^2 + 5x^3) mod (x^4 + 1)
    #   = -32 - 24x - 4x^2 + 30x^3
    a = np.array([1, 2, 3, 4])
    b = np.array([2, 3, 4, 5])
    output = poly_mul(a, b)
    expected = _np_polymul(a, b)
    abs_diff = np.abs(output - expected)
    print(f"output=\t\t{output}")
    print(f"expected=\t{expected}")
    print(f"max_abs_diff=\t{np.max(abs_diff)}")

    # Larger randomized comparison; seeded so runs are reproducible.
    rng = np.random.default_rng(0)
    a = rng.integers(low=0, high=2**16 - 1, size=512)
    b = rng.integers(low=0, high=2**16 - 1, size=512)
    abs_diff = np.abs(poly_mul(a, b) - _np_polymul(a, b))
    print(f"random n=512 max_abs_diff=\t{np.max(abs_diff)}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment