Skip to content

Instantly share code, notes, and snippets.

@MikuroXina
Last active October 30, 2022 06:11
Show Gist options
  • Select an option

  • Save MikuroXina/d5f593aea5aee1bd6ab9541e28b4b9fe to your computer and use it in GitHub Desktop.

Select an option

Save MikuroXina/d5f593aea5aee1bd6ab9541e28b4b9fe to your computer and use it in GitHub Desktop.
An integer modulo 998244353 with Montgomery multiplication and the Number Theoretic Transform (NTT).
use num::{traits::Pow, One, Zero};
use serde::{Deserialize, Serialize};
/// The Montgomery radix `R = 2^32`, held in a `u64` because the value does not fit in a `u32`.
const R: u64 = 1 << 32;
/// Finds `modulo_inv` which satisfies `modulo * modulo_inv ≡ -1 (mod 2^32)`.
///
/// The result is assembled one bit at a time, least significant bit first. After `j` steps the
/// loop maintains `2^j * (t + 1) == modulo * inv_mod + 1`, so after 32 steps
/// `modulo * inv_mod + 1` is a multiple of `2^32`.
///
/// `modulo` must be odd: an even number has no inverse modulo a power of two, and the value
/// returned for it is meaningless.
const fn find_neg_inv(modulo: u32) -> u32 {
    // Work in `u64`: `t` stays below `modulo`, but `t + modulo` exceeds `u32::MAX` once
    // `modulo` reaches 2^31, which would abort constant evaluation.
    let modulo = modulo as u64;
    let mut inv_mod = 0u32;
    let mut t = 0u64;
    let mut bit = 0;
    while bit < 32 {
        if t % 2 == 0 {
            // Adding the (odd) modulus makes `t` odd; record that this bit of the inverse is set.
            t += modulo;
            inv_mod |= 1u32 << bit;
        }
        t /= 2;
        bit += 1;
    }
    inv_mod
}
/// Computes `R^2 mod modulo`, where `R` is the file-level constant `2^32`.
///
/// `R` is reduced first so that the square fits in a `u64`.
const fn find_r2(modulo: u32) -> u32 {
    let n = modulo as u64;
    let r_mod_n = R % n;
    let squared = r_mod_n * r_mod_n;
    (squared % n) as u32
}
/// [`ModInt`] specialised to the modulus `998244353`.
pub type ModInt998244353 = ModInt<998244353>;
/// Pins the compile-time Montgomery constants derived for the modulus `998244353`.
#[test]
fn const_test_998244353() {
    // (computed constant, expected value)
    let expectations = [
        (ModInt998244353::N, 0x3B800001),
        (ModInt998244353::N_PRIME, 0x3B7FFFFF),
        (ModInt998244353::R2, 0x378DFBC6),
    ];
    for &(computed, expected) in &expectations {
        assert_eq!(computed, expected);
    }
}
/// An integer modulo `MOD` kept in Montgomery form: the stored word is the value multiplied by
/// `2^32` and reduced modulo `MOD`. `MOD` is expected to be a prime number.
///
/// The derived comparison, ordering, hashing and serde implementations all operate on the stored
/// Montgomery-form word rather than on the value returned by `as_u32`.
#[derive(
Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
#[repr(transparent)]
pub struct ModInt<const MOD: u32>(u32);
impl<const MOD: u32> From<ModInt<MOD>> for u32 {
#[inline]
fn from(mod_int: ModInt<MOD>) -> Self {
mod_int.as_u32()
}
}
impl<const MOD: u32> ModInt<MOD> {
    /// The modulus, which must be a prime number.
    pub const N: u32 = MOD;
    /// The negated inverse of `N` modulo `2^32`, which satisfies `N * N_PRIME ≡ -1 (mod 2^32)`.
    pub const N_PRIME: u32 = find_neg_inv(MOD);
    /// The square of the Montgomery multiplier `R = 2^32`, reduced so that `R2 ≡ 2^64 (mod N)`.
    pub const R2: u32 = find_r2(MOD);

    /// Converts the plain integer `n` into Montgomery form.
    ///
    /// Reducing `n * R2` divides out one factor of `R`, leaving `n * R (mod N)`.
    #[inline]
    pub fn new(n: u32) -> Self {
        let scaled = u64::from(n) * u64::from(Self::R2);
        Self(Self::reduce(scaled))
    }

    /// Leaves Montgomery form and returns the residue in `0..N`.
    #[inline]
    pub fn as_u32(self) -> u32 {
        let stored = u64::from(self.0);
        Self::reduce(stored)
    }

    /// Returns `self` raised to the power `MOD - 2`, which is the multiplicative inverse when
    /// `MOD` is prime and `self` is non-zero.
    #[inline]
    pub fn inv(self) -> Self {
        let exponent = MOD - 2;
        self.pow(exponent)
    }

    /// Montgomery reduction: returns `x / R (mod N)` as a value in `0..N`.
    ///
    /// `x` must be smaller than `N * R`; this is checked in debug builds only.
    #[inline]
    pub fn reduce(x: u64) -> u32 {
        let modulo = MOD as u64;
        debug_assert!(x < modulo * R as u64);
        // `factor` is chosen so that `x + factor * N` is a multiple of `R`.
        let factor = u64::from((x as u32).wrapping_mul(Self::N_PRIME));
        let quotient = (x + factor * modulo) / R;
        // The quotient is below `2 * N`, so at most one subtraction is required.
        let ret = quotient.checked_sub(modulo).unwrap_or(quotient);
        debug_assert!(ret < modulo);
        ret as u32
    }
}
/// Implements `From<$t>` and `From<&$t>` for [`ModInt`] for a primitive integer type `$t`.
///
/// The input is reduced into `0..MOD` before it is handed to `ModInt::new`, so values that do
/// not fit in a `u32` and negative values map to the residue they are congruent to.
macro_rules! impl_from_for_mod_int {
    ($t:ty) => {
        impl<const MOD: u32> From<$t> for ModInt<MOD> {
            #[inline]
            fn from(x: $t) -> Self {
                // A bare `x as u32` would drop the high bits of a `u64` and turn a negative
                // `i32` into `2^32 + x`; neither is congruent to `x` modulo `MOD`. `i128` holds
                // every supported input type without loss.
                Self::new((x as i128).rem_euclid(MOD as i128) as u32)
            }
        }
        impl<const MOD: u32> From<&'_ $t> for ModInt<MOD> {
            #[inline]
            fn from(&x: &'_ $t) -> Self {
                <Self as From<$t>>::from(x)
            }
        }
    };
}
impl_from_for_mod_int!(u64);
impl_from_for_mod_int!(u32);
impl_from_for_mod_int!(i32);
/// Modular addition, implemented on top of `+=`.
impl<const MOD: u32> std::ops::Add for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self::Output {
        let mut sum = self;
        sum += rhs;
        sum
    }
}
/// In-place modular addition of two stored words that are both below `MOD`.
impl<const MOD: u32> std::ops::AddAssign for ModInt<MOD> {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        let sum = self.0 + rhs.0;
        // The sum is below `2 * MOD`, so one conditional subtraction brings it back in range.
        self.0 = if sum >= MOD { sum - MOD } else { sum };
    }
}
/// Modular subtraction, implemented on top of `-=`.
impl<const MOD: u32> std::ops::Sub for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self::Output {
        let mut difference = self;
        difference -= rhs;
        difference
    }
}
/// In-place modular subtraction of two stored words that are both below `MOD`.
impl<const MOD: u32> std::ops::SubAssign for ModInt<MOD> {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        self.0 = match self.0.checked_sub(rhs.0) {
            Some(difference) => difference,
            // Borrow one modulus first so that the subtraction cannot go below zero.
            None => self.0 + MOD - rhs.0,
        };
    }
}
/// Modular multiplication, implemented on top of `*=`.
impl<const MOD: u32> std::ops::Mul for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self::Output {
        let mut product = self;
        product *= rhs;
        product
    }
}
/// In-place Montgomery multiplication: the 64-bit product of the two stored words is passed
/// through `reduce`.
impl<const MOD: u32> std::ops::MulAssign for ModInt<MOD> {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        let product = u64::from(self.0) * u64::from(rhs.0);
        self.0 = Self::reduce(product);
    }
}
/// Modular division, implemented on top of `/=`.
impl<const MOD: u32> std::ops::Div for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn div(self, rhs: Self) -> Self::Output {
        let mut quotient = self;
        quotient /= rhs;
        quotient
    }
}
/// In-place modular division: multiplies by the inverse of the divisor.
impl<const MOD: u32> std::ops::DivAssign for ModInt<MOD> {
    #[inline]
    // Division is multiplication by the inverse, so using `*=` here is intentional.
    #[allow(clippy::suspicious_op_assign_impl)]
    fn div_assign(&mut self, rhs: Self) {
        let reciprocal = rhs.inv();
        *self *= reciprocal;
    }
}
/// Modular negation, computed as `0 - self`.
impl<const MOD: u32> std::ops::Neg for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self::Output {
        let zero = Self(0);
        zero - self
    }
}
/// Sums an iterator of [`ModInt`] values, starting from the stored word `0`.
impl<const MOD: u32> std::iter::Sum for ModInt<MOD> {
    #[inline]
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        let mut total = Self(0);
        for term in iter {
            total = total + term;
        }
        total
    }
}
/// Multiplies an iterator of [`ModInt`] values together, starting from `ModInt::new(1)`.
impl<const MOD: u32> std::iter::Product for ModInt<MOD> {
    #[inline]
    fn product<I>(iter: I) -> Self
    where
        I: Iterator<Item = Self>,
    {
        let mut total = Self::new(1);
        for factor in iter {
            total = total * factor;
        }
        total
    }
}
/// The additive identity: the stored word `0`.
impl<const MOD: u32> Zero for ModInt<MOD> {
    #[inline]
    fn zero() -> Self {
        // The derived `Default` also yields the stored word `0`.
        Self::default()
    }
    #[inline]
    fn is_zero(&self) -> bool {
        *self == Self(0)
    }
}
impl<const MOD: u32> One for ModInt<MOD> {
#[inline]
fn one() -> Self {
Self::new(1)
}
}
/// Exponentiation by repeated squaring.
impl<const MOD: u32> Pow<u32> for ModInt<MOD> {
    type Output = Self;
    #[inline]
    fn pow(self, exp: u32) -> Self::Output {
        let mut base = self;
        let mut remaining = exp;
        // With `exp == 0` the loop body never runs and the result stays at one.
        let mut result = Self::new(1);
        while remaining != 0 {
            if remaining & 1 == 1 {
                result *= base;
            }
            base *= base;
            remaining >>= 1;
        }
        result
    }
}
use bytemuck::cast;
use wide::{i32x4, i32x8, i64x2, i64x4, u32x4, u32x8, CmpGt};
/// Lane-wise `hi(a * b) + m1 - hi(lo(a * b) * inv_n * m1)`, where `hi`/`lo` are the upper and
/// lower 32 bits of the 64-bit product.
///
/// This is a Montgomery multiplication when `m1` holds the modulus in every lane and `inv_n`
/// holds its inverse modulo `2^32` (`m1 * inv_n ≡ 1`). No final conditional subtraction is done.
#[inline]
pub fn mul_u32x4(a: u32x4, b: u32x4, inv_n: u32x4, m1: u32x4) -> u32x4 {
    let product_high: i32x4 = cast(mul_hi_u32x4(a, b));
    let modulus: i32x4 = cast(m1);
    let quotient = a * b * inv_n;
    let correction: i32x4 = cast(mul_hi_u32x4(quotient, m1));
    cast(product_high + modulus - correction)
}
/// Returns, for every lane, the upper 32 bits of the 64-bit product `a[i] * b[i]`.
///
/// The vector `*` operator only keeps the low 32 bits of each lane product, so the high half
/// cannot be recovered from it by shuffling; each lane is therefore widened to `u64` before
/// multiplying.
#[inline]
fn mul_hi_u32x4(a: u32x4, b: u32x4) -> u32x4 {
    let lhs: [u32; 4] = cast(a);
    let rhs: [u32; 4] = cast(b);
    let mut high = [0u32; 4];
    for ((out, &x), &y) in high.iter_mut().zip(&lhs).zip(&rhs) {
        *out = ((u64::from(x) * u64::from(y)) >> 32) as u32;
    }
    cast(high)
}
/// Lane-wise `a + b - m2`, with `m2` added back in every lane where the signed result is below
/// `m0`.
///
/// With `m0 == 0` this keeps the sum in `0..m2` for inputs that are themselves in `0..m2`.
#[inline]
pub fn add_u32x4(a: u32x4, b: u32x4, m2: u32x4, m0: u32x4) -> u32x4 {
    let bound: i32x4 = cast(m2);
    let threshold: i32x4 = cast(m0);
    let ret = cast::<_, i32x4>(a) + cast::<_, i32x4>(b) - bound;
    // All-ones in the lanes that went below the threshold, zero elsewhere.
    let below: u32x4 = cast(threshold.cmp_gt(ret));
    let fixup: i32x4 = cast(below & m2);
    cast(fixup + ret)
}
/// Lane-wise `a - b`, with `m2` added back in every lane where the signed result is below `m0`.
///
/// With `m0 == 0` this keeps the difference in `0..m2` for inputs that are themselves in
/// `0..m2`.
#[inline]
pub fn sub_u32x4(a: u32x4, b: u32x4, m2: u32x4, m0: u32x4) -> u32x4 {
    let threshold: i32x4 = cast(m0);
    let ret = cast::<_, i32x4>(a) - cast::<_, i32x4>(b);
    // All-ones in the lanes that went below the threshold, zero elsewhere.
    let below: u32x4 = cast(threshold.cmp_gt(ret));
    let fixup: i32x4 = cast(below & m2);
    cast(fixup + ret)
}
/// Eight-lane counterpart of `mul_u32x4`: lane-wise
/// `hi(a * b) + m1 - hi(lo(a * b) * inv_n * m1)`.
///
/// This is a Montgomery multiplication when `m1` holds the modulus in every lane and `inv_n`
/// holds its inverse modulo `2^32` (`m1 * inv_n ≡ 1`). No final conditional subtraction is done.
#[inline]
pub fn mul_u32x8(a: u32x8, b: u32x8, inv_n: u32x8, m1: u32x8) -> u32x8 {
    let product_high: i32x8 = cast(mul_hi_u32x8(a, b));
    let modulus: i32x8 = cast(m1);
    let quotient = a * b * inv_n;
    let correction: i32x8 = cast(mul_hi_u32x8(quotient, m1));
    cast(product_high + modulus - correction)
}
/// Returns, for every lane, the upper 32 bits of the 64-bit product `a[i] * b[i]`.
///
/// The vector `*` operator only keeps the low 32 bits of each lane product, so the high half
/// cannot be recovered from it by shuffling; each lane is therefore widened to `u64` before
/// multiplying.
#[inline]
fn mul_hi_u32x8(a: u32x8, b: u32x8) -> u32x8 {
    let lhs: [u32; 8] = cast(a);
    let rhs: [u32; 8] = cast(b);
    let mut high = [0u32; 8];
    for ((out, &x), &y) in high.iter_mut().zip(&lhs).zip(&rhs) {
        *out = ((u64::from(x) * u64::from(y)) >> 32) as u32;
    }
    cast(high)
}
/// Eight-lane counterpart of `add_u32x4`: lane-wise `a + b - m2`, with `m2` added back in every
/// lane where the signed result is below `m0`.
#[inline]
pub fn add_u32x8(a: u32x8, b: u32x8, m2: u32x8, m0: u32x8) -> u32x8 {
    let bound: i32x8 = cast(m2);
    let threshold: i32x8 = cast(m0);
    let ret = cast::<_, i32x8>(a) + cast::<_, i32x8>(b) - bound;
    // All-ones in the lanes that went below the threshold, zero elsewhere.
    let below: u32x8 = cast(threshold.cmp_gt(ret));
    let fixup: i32x8 = cast(below & m2);
    cast(fixup + ret)
}
/// Eight-lane counterpart of `sub_u32x4`: lane-wise `a - b`, with `m2` added back in every lane
/// where the signed result is below `m0`.
#[inline]
pub fn sub_u32x8(a: u32x8, b: u32x8, m2: u32x8, m0: u32x8) -> u32x8 {
    let threshold: i32x8 = cast(m0);
    let ret = cast::<_, i32x8>(a) - cast::<_, i32x8>(b);
    // All-ones in the lanes that went below the threshold, zero elsewhere.
    let below: u32x8 = cast(threshold.cmp_gt(ret));
    let fixup: i32x8 = cast(below & m2);
    cast(fixup + ret)
}
/// The exponent of the largest power of two that divides `N - 1`, computed as
/// `(N - 1).trailing_zeros()`. It is also the length of the twiddle tables stored in `Ntt`.
const LEVEL: usize = (ModInt998244353::N - 1).trailing_zeros() as usize;
/// Precomputed tables for the number-theoretic transform modulo `998244353`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Ntt {
/// The value returned by `primitive_root(ModInt998244353::N)` when the tables were built.
primitive_root: ModInt998244353,
/// Twiddle-step table read by `transform`.
d_w: [ModInt998244353; LEVEL],
/// Twiddle-step table read by `inverse_transform`.
d_inv_w: [ModInt998244353; LEVEL],
}
impl Ntt {
/// Builds the twiddle-step tables for the modulus `998244353`.
///
/// `w[i]` is a root of unity of order `2^(i + 1)` and `inv_w[i]` is its inverse; both are
/// obtained by repeatedly squaring the root of order `2^LEVEL`. `d_w` and `d_inv_w` hold the
/// factors by which `transform` and `inverse_transform` advance their running twiddle.
pub fn new() -> Self {
    let modulo = ModInt998244353::N;
    let primitive_root = primitive_root(modulo);
    let mut w = [ModInt998244353::default(); LEVEL];
    let mut d_w = [ModInt998244353::default(); LEVEL];
    let mut inv_w = [ModInt998244353::default(); LEVEL];
    let mut d_inv_w = [ModInt998244353::default(); LEVEL];
    w[LEVEL - 1] = primitive_root.pow((modulo - 1) / (1 << LEVEL));
    inv_w[LEVEL - 1] = w[LEVEL - 1].inv();
    for i in (0..LEVEL - 1).rev() {
        w[i] = w[i + 1] * w[i + 1];
        inv_w[i] = inv_w[i + 1] * inv_w[i + 1];
    }
    d_w[0] = w[1] * w[1];
    d_inv_w[0] = d_w[0];
    d_w[1] = w[1];
    d_inv_w[1] = inv_w[1];
    d_w[2] = w[2];
    // The inverse table must be seeded from the inverse roots, mirroring `d_inv_w[1]`;
    // seeding it from `w[2]` would corrupt every later entry of the recurrence below.
    d_inv_w[2] = inv_w[2];
    for i in 3..LEVEL {
        d_w[i] = d_w[i - 1] * inv_w[i - 2] * w[i];
        d_inv_w[i] = d_inv_w[i - 1] * w[i - 2] * inv_w[i];
    }
    Self {
        primitive_root,
        d_w,
        d_inv_w,
    }
}
pub fn transform(&self, vec: &mut AudioVec) {
if vec.is_empty() {
return;
}
let k = vec.len().trailing_zeros();
if k == 1 {
let a = vec[1];
vec[1] = vec[0] - vec[1];
vec[0] += a;
return;
}
if k % 2 != 0 {
let v = 1 << (k - 1);
if v < 8 {
for j in 0..v {
let jv = vec[j + v];
vec[j + v] = vec[j] - jv;
vec[j] += jv;
}
} else {
let m0 = u32x8::default();
let m2 = u32x8::splat(2 * ModInt998244353::N);
for j0 in (0..v).step_by(8) {
let j1 = j0 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let naj = montgomery::add_u32x8(t0, t1, m2, m0);
let naj_v = montgomery::sub_u32x8(t0, t1, m2, m0);
vec.set_u32x8(j0, naj);
vec.set_u32x8(j1, naj_v);
}
}
}
let mut u = 1 << (2 + (k % 2));
let mut v = 1 << (k - 2 - (k % 2));
let one = ModInt998244353::new(1);
let im = self.d_w[1];
while v != 0 {
if v == 1 {
let mut xx = one;
for jh in (0..u).step_by(4) {
let ww = xx * xx;
let wx = ww * xx;
let t0 = vec[jh];
let t1 = vec[jh + 1] * xx;
let t2 = vec[jh + 2] * ww;
let t3 = vec[jh + 3] * wx;
let t0p2 = t0 + t2;
let t1p3 = t1 + t3;
let t0m2 = t0 - t2;
let t1m3 = (t1 - t3) * im;
vec[jh] = t0p2 + t1p3;
vec[jh + 1] = t0p2 - t1p3;
vec[jh + 2] = t0m2 + t1m3;
vec[jh + 3] = t0m2 - t1m3;
xx *= self.d_w[(jh + 4).trailing_zeros() as usize];
}
} else if v == 4 {
let mut xx = one;
let m0 = u32x4::default();
let m1 = u32x4::splat(ModInt998244353::N);
let m2 = u32x4::splat(2 * ModInt998244353::N);
let inv_mod = u32x4::splat(ModInt998244353::N_PRIME);
let im = u32x4::splat(im.as_u32());
for jh in (0..u).step_by(4) {
if jh == 0 {
for j0 in (0..v).step_by(4) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x4(j0);
let t1 = vec.get_u32x4(j1);
let t2 = vec.get_u32x4(j2);
let t3 = vec.get_u32x4(j3);
let t0p2 = montgomery::add_u32x4(t0, t2, m2, m0);
let t1p3 = montgomery::add_u32x4(t1, t3, m2, m0);
let t0m2 = montgomery::sub_u32x4(t0, t2, m2, m0);
let t1m3 = montgomery::mul_u32x4(
montgomery::sub_u32x4(t1, t3, m2, m0),
im,
inv_mod,
m1,
);
vec.set_u32x4(j0, montgomery::add_u32x4(t0p2, t1p3, m2, m0));
vec.set_u32x4(j1, montgomery::sub_u32x4(t0p2, t1p3, m2, m0));
vec.set_u32x4(j2, montgomery::add_u32x4(t0m2, t1m3, m2, m0));
vec.set_u32x4(j3, montgomery::sub_u32x4(t0m2, t1m3, m2, m0));
}
} else {
let ww = xx * xx;
let wx = ww * xx;
let ww = u32x4::splat(ww.as_u32());
let wx = u32x4::splat(wx.as_u32());
let xx = u32x4::splat(xx.as_u32());
for j0 in (jh * v..(jh + 1) * v).step_by(4) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x4(j0);
let t1 = vec.get_u32x4(j1);
let t2 = vec.get_u32x4(j2);
let t3 = vec.get_u32x4(j3);
let mt1 = montgomery::mul_u32x4(t1, xx, inv_mod, m1);
let mt2 = montgomery::mul_u32x4(t2, ww, inv_mod, m1);
let mt3 = montgomery::mul_u32x4(t3, wx, inv_mod, m1);
let t0p2 = montgomery::add_u32x4(t0, mt2, m2, m0);
let t1p3 = montgomery::add_u32x4(mt1, mt3, m2, m0);
let t0m2 = montgomery::add_u32x4(t0, mt2, m2, m0);
let t1m3 = montgomery::mul_u32x4(
montgomery::sub_u32x4(mt1, mt3, m2, m0),
im,
inv_mod,
m1,
);
vec.set_u32x4(j0, montgomery::add_u32x4(t0p2, t1p3, m2, m0));
vec.set_u32x4(j1, montgomery::sub_u32x4(t0p2, t1p3, m2, m0));
vec.set_u32x4(j2, montgomery::add_u32x4(t0m2, t1m3, m2, m0));
vec.set_u32x4(j2, montgomery::sub_u32x4(t0m2, t1m3, m2, m0));
}
}
xx *= self.d_w[(jh + 4).trailing_zeros() as usize];
}
} else {
let m0 = u32x8::default();
let m1 = u32x8::splat(ModInt998244353::N);
let m2 = u32x8::splat(2 * ModInt998244353::N);
let inv_mod = u32x8::splat(ModInt998244353::N_PRIME);
let im = u32x8::splat(im.as_u32());
let mut xx = one;
for jh in (0..u).step_by(4) {
if jh == 0 {
for j0 in (0..v).step_by(8) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let t2 = vec.get_u32x8(j2);
let t3 = vec.get_u32x8(j3);
let t0p2 = montgomery::add_u32x8(t0, t1, m2, m0);
let t1p3 = montgomery::add_u32x8(t1, t3, m2, m0);
let t0m2 = montgomery::sub_u32x8(t0, t2, m2, m0);
let t1m3 = montgomery::mul_u32x8(
montgomery::sub_u32x8(t1, t3, m2, m0),
im,
inv_mod,
m1,
);
vec.set_u32x8(j0, montgomery::add_u32x8(t0p2, t1p3, m2, m0));
vec.set_u32x8(j1, montgomery::sub_u32x8(t0p2, t1p3, m2, m0));
vec.set_u32x8(j2, montgomery::add_u32x8(t0m2, t1m3, m2, m0));
vec.set_u32x8(j3, montgomery::sub_u32x8(t0m2, t1m3, m2, m0));
}
} else {
let ww = xx * xx;
let wx = ww * xx;
let ww = u32x8::splat(ww.as_u32());
let wx = u32x8::splat(wx.as_u32());
let xx = u32x8::splat(xx.as_u32());
for j0 in (jh * v..(jh + 1) * v).step_by(8) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let t2 = vec.get_u32x8(j2);
let t3 = vec.get_u32x8(j3);
let mt1 = montgomery::mul_u32x8(t1, xx, inv_mod, m1);
let mt2 = montgomery::mul_u32x8(t2, ww, inv_mod, m1);
let mt3 = montgomery::mul_u32x8(t3, wx, inv_mod, m1);
let t0p2 = montgomery::add_u32x8(t0, mt2, m2, m0);
let t1p3 = montgomery::add_u32x8(mt1, mt3, m2, m0);
let t0m2 = montgomery::sub_u32x8(t0, mt2, m2, m0);
let t1m3 = montgomery::mul_u32x8(
montgomery::sub_u32x8(mt1, mt3, m2, m0),
im,
inv_mod,
m1,
);
vec.set_u32x8(j0, montgomery::add_u32x8(t0p2, t1p3, m2, m0));
vec.set_u32x8(j1, montgomery::sub_u32x8(t0p2, t1p3, m2, m0));
vec.set_u32x8(j2, montgomery::add_u32x8(t0m2, t1m3, m2, m0));
vec.set_u32x8(j3, montgomery::sub_u32x8(t0m2, t1m3, m2, m0));
}
}
xx *= self.d_w[(jh + 4).trailing_zeros() as usize];
}
}
u <<= 2;
v >>= 2;
}
}
pub fn inverse_transform(&self, vec: &mut AudioVec) {
if vec.is_empty() {
return;
}
let k = vec.len().trailing_zeros();
if k == 1 {
let a1 = vec[1];
vec[1] = vec[0] - vec[1];
vec[0] += a1;
vec[0] *= ModInt998244353::new(2).inv();
vec[1] *= ModInt998244353::new(2).inv();
return;
}
let mut u = 1 << (k - 2);
let mut v = 1;
let one = ModInt998244353::new(1);
let im = self.d_inv_w[1];
while u != 0 {
if v == 1 {
let mut xx = one;
u <<= 2;
for jh in (0..u).step_by(4) {
let ww = xx * xx;
let yy = xx * im;
let t0 = vec[jh];
let t1 = vec[jh + 1];
let t2 = vec[jh + 2];
let t3 = vec[jh + 3];
let t0p1 = t0 + t1;
let t2p3 = t2 + t3;
let t0m1 = (t0 - t1) * xx;
let t2m3 = (t2 - t3) * yy;
vec[jh] = t0p1 + t2p3;
vec[jh + 1] = t0m1 + t2m3;
vec[jh + 2] = (t0p1 - t2p3) * ww;
vec[jh + 3] = (t0m1 - t2m3) * ww;
xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize];
}
} else if v == 4 {
let m0 = u32x4::default();
let m1 = u32x4::splat(ModInt998244353::N);
let m2 = u32x4::splat(2 * ModInt998244353::N);
let inv_mod = u32x4::splat(ModInt998244353::N_PRIME);
let mut xx = one;
u <<= 2;
for jh in (0..u).step_by(4) {
if jh == 0 {
let im = u32x4::splat(im.as_u32());
for j0 in (0..v).step_by(4) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x4(j0);
let t1 = vec.get_u32x4(j1);
let t2 = vec.get_u32x4(j2);
let t3 = vec.get_u32x4(j3);
let t0p1 = montgomery::add_u32x4(t0, t1, m2, m0);
let t2p3 = montgomery::add_u32x4(t2, t3, m2, m0);
let t0m1 = montgomery::sub_u32x4(t0, t1, m2, m0);
let t2m3 = montgomery::mul_u32x4(
montgomery::sub_u32x4(t2, t3, m2, m0),
im,
inv_mod,
m1,
);
vec.set_u32x4(j0, montgomery::add_u32x4(t0p1, t2p3, m2, m0));
vec.set_u32x4(j1, montgomery::add_u32x4(t0m1, t2m3, m2, m0));
vec.set_u32x4(j2, montgomery::sub_u32x4(t0p1, t2p3, m2, m0));
vec.set_u32x4(j3, montgomery::sub_u32x4(t0m1, t2m3, m2, m0));
}
} else {
let ww = xx * xx;
let yy = xx * im;
let ww = u32x4::splat(ww.as_u32());
let xx = u32x4::splat(xx.as_u32());
let yy = u32x4::splat(yy.as_u32());
for j0 in (jh * v..(jh + 1) * v).step_by(4) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x4(j0);
let t1 = vec.get_u32x4(j1);
let t2 = vec.get_u32x4(j2);
let t3 = vec.get_u32x4(j3);
let t0p1 = montgomery::add_u32x4(t0, t1, m2, m0);
let t2p3 = montgomery::add_u32x4(t2, t3, m2, m0);
let t0m1 = montgomery::mul_u32x4(
montgomery::sub_u32x4(t0, t1, m2, m0),
xx,
inv_mod,
m1,
);
let t2m3 = montgomery::mul_u32x4(
montgomery::sub_u32x4(t2, t3, m2, m0),
yy,
inv_mod,
m1,
);
vec.set_u32x4(j0, montgomery::add_u32x4(t0p1, t2p3, m2, m0));
vec.set_u32x4(j1, montgomery::add_u32x4(t0m1, t2m3, m2, m0));
vec.set_u32x4(
j2,
montgomery::mul_u32x4(
montgomery::sub_u32x4(t0p1, t2p3, m2, m0),
ww,
inv_mod,
m1,
),
);
vec.set_u32x4(
j3,
montgomery::mul_u32x4(
montgomery::sub_u32x4(t0m1, t2m3, m2, m0),
ww,
inv_mod,
m1,
),
);
}
}
xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize];
}
} else {
let m0 = u32x8::default();
let m1 = u32x8::splat(ModInt998244353::N);
let m2 = u32x8::splat(ModInt998244353::N);
let mod_inv = u32x8::splat(ModInt998244353::N_PRIME);
let mut xx = one;
u <<= 2;
for jh in (0..u).step_by(4) {
if jh == 0 {
let im = u32x8::splat(im.as_u32());
for j0 in (0..v).step_by(8) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let t2 = vec.get_u32x8(j2);
let t3 = vec.get_u32x8(j3);
let t0p1 = montgomery::add_u32x8(t0, t1, m2, m0);
let t2p3 = montgomery::add_u32x8(t2, t3, m2, m0);
let t0m1 = montgomery::sub_u32x8(t0, t1, m2, m0);
let t2m3 = montgomery::mul_u32x8(
montgomery::sub_u32x8(t2, t3, m2, m0),
im,
mod_inv,
m1,
);
vec.set_u32x8(j0, montgomery::add_u32x8(t0p1, t2p3, m2, m0));
vec.set_u32x8(j1, montgomery::add_u32x8(t0m1, t2m3, m2, m0));
vec.set_u32x8(j2, montgomery::sub_u32x8(t0p1, t2p3, m2, m0));
vec.set_u32x8(j3, montgomery::sub_u32x8(t0m1, t2m3, m2, m0));
}
} else {
let ww = xx * xx;
let yy = xx * im;
let ww = u32x8::splat(ww.as_u32());
let xx = u32x8::splat(xx.as_u32());
let yy = u32x8::splat(yy.as_u32());
for j0 in (jh * v..(jh + 1) * v).step_by(8) {
let j1 = j0 + v;
let j2 = j1 + v;
let j3 = j2 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let t2 = vec.get_u32x8(j2);
let t3 = vec.get_u32x8(j3);
let t0p1 = montgomery::add_u32x8(t0, t1, m2, m0);
let t2p3 = montgomery::add_u32x8(t2, t3, m2, m0);
let t0m1 = montgomery::mul_u32x8(
montgomery::sub_u32x8(t0, t1, m2, m0),
xx,
mod_inv,
m1,
);
let t2m3 = montgomery::mul_u32x8(
montgomery::sub_u32x8(t2, t3, m2, m0),
yy,
mod_inv,
m1,
);
vec.set_u32x8(j0, montgomery::add_u32x8(t0p1, t2p3, m2, m0));
vec.set_u32x8(j1, montgomery::add_u32x8(t0m1, t2m3, m2, m0));
vec.set_u32x8(
j2,
montgomery::mul_u32x8(
montgomery::sub_u32x8(t0p1, t2p3, m2, m0),
ww,
mod_inv,
m1,
),
);
vec.set_u32x8(
j3,
montgomery::mul_u32x8(
montgomery::sub_u32x8(t0m1, t2m3, m2, m0),
ww,
mod_inv,
m1,
),
);
}
}
xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize];
}
}
u >>= 4;
v <<= 2;
}
if k % 2 == 1 {
v = 1 << (k - 1);
if v < 8 {
for j in 0..v {
let ajv = vec[j + v];
let aj_ajv = vec[j] - vec[j + v];
vec[j] += ajv;
vec[j + v] = aj_ajv;
}
} else {
let m0 = u32x8::default();
let m2 = u32x8::splat(2 * ModInt998244353::N);
for j0 in (0..v).step_by(8) {
let j1 = j0 + v;
let t0 = vec.get_u32x8(j0);
let t1 = vec.get_u32x8(j1);
let naj = montgomery::add_u32x8(t0, t1, m2, m0);
let naj_v = montgomery::sub_u32x8(t0, t1, m2, m0);
vec.set_u32x8(j0, naj);
vec.set_u32x8(j1, naj_v);
}
}
}
let inv_len = ModInt998244353::new(vec.len() as u32).inv();
for val in vec.iter_mut() {
*val *= inv_len;
}
}
}
impl Default for Ntt {
fn default() -> Self {
Self::new()
}
}
/// Searches for the smallest `g >= 2` whose order modulo `modulo` is `modulo - 1`, and returns
/// it wrapped in `ModInt998244353`. For `modulo == 2` the answer is `1`.
///
/// A candidate is rejected as soon as `g^((modulo - 1) / q) ≡ 1` for some prime factor `q` of
/// `modulo - 1`. `modulo` is expected to be prime.
fn primitive_root(modulo: u32) -> ModInt998244353 {
    if modulo == 2 {
        return ModInt998244353::new(1);
    }

    // Collect the distinct prime factors of `modulo - 1` by trial division.
    let mut prime_factors: Vec<u64> = Vec::new();
    let mut rest = modulo - 1;
    let mut p = 2;
    while p * p <= rest {
        if rest % p == 0 {
            prime_factors.push(p as u64);
            while rest % p == 0 {
                rest /= p;
            }
        }
        p += 1;
    }
    if rest != 1 {
        // Whatever is left over is itself a prime factor.
        prime_factors.push(rest as u64);
    }

    let m = modulo as u64;
    // Square-and-multiply exponentiation modulo `m`.
    let pow_mod = |mut base: u64, mut exp: u64| -> u64 {
        let mut acc: u64 = 1;
        while exp != 0 {
            if exp % 2 != 0 {
                acc *= base;
                acc %= m;
            }
            base *= base;
            base %= m;
            exp /= 2;
        }
        acc
    };

    let found = (2u64..).find(|&candidate| {
        prime_factors
            .iter()
            .all(|&q| pow_mod(candidate, (m - 1) / q) != 1)
    });
    match found {
        Some(root) => ModInt998244353::new(root as u32),
        None => unreachable!(),
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment