Last active
October 30, 2022 06:11
-
-
Save MikuroXina/d5f593aea5aee1bd6ab9541e28b4b9fe to your computer and use it in GitHub Desktop.
An integer modulo 998244353 with Montgomery multiplication and the number-theoretic transform (NTT).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use num::{traits::Pow, One, Zero}; | |
| use serde::{Deserialize, Serialize}; | |
| const R: u64 = 1 << 32; | |
/// Finds `inv` which satisfies `modulo * inv ≡ -1 (mod 2^32)`.
///
/// `modulo` must be odd; an even modulus has no inverse modulo `2^32` and the
/// returned value is then meaningless.
///
/// The result is built one bit per round, lowest bit first. After `k` rounds
/// the loop maintains `modulo * inv + 1 == 2^k * (t + 1)`, so after 32 rounds
/// `modulo * inv + 1` is a multiple of `2^32`.
const fn find_neg_inv(modulo: u32) -> u32 {
    // `t` stays below `modulo`, so `t + modulo` needs up to 33 bits. Keeping it
    // in `u64` avoids an overflow for moduli of 2^31 and above.
    let modulo = modulo as u64;
    let mut inv = 0u32;
    let mut t = 0u64;
    let mut bit = 0;
    while bit < 32 {
        if t % 2 == 0 {
            t += modulo;
            inv |= 1u32 << bit;
        }
        t /= 2;
        bit += 1;
    }
    inv
}
| const fn find_r2(modulo: u32) -> u32 { | |
| let modulo = modulo as u64; | |
| let r = R % modulo; | |
| (r * r % modulo) as u32 | |
| } | |
| pub type ModInt998244353 = ModInt<998244353>; | |
| #[test] | |
| fn const_test_998244353() { | |
| assert_eq!(ModInt998244353::N, 0x3B800001); | |
| assert_eq!(ModInt998244353::N_PRIME, 0x3B7FFFFF); | |
| assert_eq!(ModInt998244353::R2, 0x378DFBC6); | |
| } | |
| /// The Montgomery Form modulo `MOD`, multiplied `2^32`. `MOD` is expected to a prime number. | |
| #[derive( | |
| Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, | |
| )] | |
| #[repr(transparent)] | |
| pub struct ModInt<const MOD: u32>(u32); | |
| impl<const MOD: u32> From<ModInt<MOD>> for u32 { | |
| #[inline] | |
| fn from(mod_int: ModInt<MOD>) -> Self { | |
| mod_int.as_u32() | |
| } | |
| } | |
| impl<const MOD: u32> ModInt<MOD> { | |
| #[inline] | |
| pub fn new(n: u32) -> Self { | |
| Self(Self::reduce(n as u64 * Self::R2 as u64)) | |
| } | |
| #[inline] | |
| pub fn as_u32(self) -> u32 { | |
| Self::reduce(self.0 as u64) | |
| } | |
| /// The modulo integer, which must be a prime number. | |
| pub const N: u32 = MOD; | |
| /// The modular inverse of `N`, which satisifies `N * N_PRIME ≡ -1`. | |
| pub const N_PRIME: u32 = find_neg_inv(MOD); | |
| /// The squared number of the multiplier `R ≡ 2^32` for Montgomery Form, which satisfies `R2 ≡ 2^64 (mod N)`. | |
| pub const R2: u32 = find_r2(MOD); | |
| #[inline] | |
| pub fn inv(self) -> Self { | |
| self.pow(MOD - 2) | |
| } | |
| #[inline] | |
| pub fn reduce(x: u64) -> u32 { | |
| let modulo = MOD as u64; | |
| debug_assert!(x < modulo * R as u64); | |
| let x_n_prime = (x as u32).wrapping_mul(Self::N_PRIME) as u64; | |
| let mul = (x + x_n_prime * modulo) / R; | |
| let ret = if modulo <= mul { mul - modulo } else { mul }; | |
| debug_assert!(ret < modulo); | |
| ret as u32 | |
| } | |
| } | |
| macro_rules! impl_from_for_mod_int { | |
| ($t:ty) => { | |
| impl<const MOD: u32> From<$t> for ModInt<MOD> { | |
| #[inline] | |
| fn from(x: $t) -> Self { | |
| Self::new(x as u32) | |
| } | |
| } | |
| impl<const MOD: u32> From<&'_ $t> for ModInt<MOD> { | |
| #[inline] | |
| fn from(&x: &'_ $t) -> Self { | |
| Self::new(x as u32) | |
| } | |
| } | |
| }; | |
| } | |
| impl_from_for_mod_int!(u64); | |
| impl_from_for_mod_int!(u32); | |
| impl_from_for_mod_int!(i32); | |
| impl<const MOD: u32> std::ops::Add for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn add(mut self, rhs: Self) -> Self::Output { | |
| self += rhs; | |
| self | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::AddAssign for ModInt<MOD> { | |
| #[inline] | |
| fn add_assign(&mut self, rhs: Self) { | |
| self.0 += rhs.0; | |
| if MOD <= self.0 { | |
| self.0 -= MOD; | |
| } | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::Sub for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn sub(mut self, rhs: Self) -> Self::Output { | |
| self -= rhs; | |
| self | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::SubAssign for ModInt<MOD> { | |
| #[inline] | |
| fn sub_assign(&mut self, rhs: Self) { | |
| if let Some(sub) = self.0.checked_sub(rhs.0) { | |
| self.0 = sub; | |
| } else { | |
| self.0 += MOD; | |
| self.0 -= rhs.0; | |
| } | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::Mul for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn mul(mut self, rhs: Self) -> Self::Output { | |
| self *= rhs; | |
| self | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::MulAssign for ModInt<MOD> { | |
| #[inline] | |
| fn mul_assign(&mut self, rhs: Self) { | |
| self.0 = Self::reduce(self.0 as u64 * rhs.0 as u64); | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::Div for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn div(mut self, rhs: Self) -> Self::Output { | |
| self /= rhs; | |
| self | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::DivAssign for ModInt<MOD> { | |
| #[inline] | |
| #[allow(clippy::suspicious_op_assign_impl)] | |
| fn div_assign(&mut self, rhs: Self) { | |
| *self *= rhs.inv(); | |
| } | |
| } | |
| impl<const MOD: u32> std::ops::Neg for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn neg(self) -> Self::Output { | |
| Self(0) - self | |
| } | |
| } | |
| impl<const MOD: u32> std::iter::Sum for ModInt<MOD> { | |
| #[inline] | |
| fn sum<I>(iter: I) -> Self | |
| where | |
| I: Iterator<Item = Self>, | |
| { | |
| iter.fold(Self(0), |a, b| a + b) | |
| } | |
| } | |
| impl<const MOD: u32> std::iter::Product for ModInt<MOD> { | |
| #[inline] | |
| fn product<I>(iter: I) -> Self | |
| where | |
| I: Iterator<Item = Self>, | |
| { | |
| iter.fold(Self::new(1), |a, b| a * b) | |
| } | |
| } | |
| impl<const MOD: u32> Zero for ModInt<MOD> { | |
| #[inline] | |
| fn zero() -> Self { | |
| Self(0) | |
| } | |
| #[inline] | |
| fn is_zero(&self) -> bool { | |
| self.0 == 0 | |
| } | |
| } | |
| impl<const MOD: u32> One for ModInt<MOD> { | |
| #[inline] | |
| fn one() -> Self { | |
| Self::new(1) | |
| } | |
| } | |
| impl<const MOD: u32> Pow<u32> for ModInt<MOD> { | |
| type Output = Self; | |
| #[inline] | |
| fn pow(mut self, mut exp: u32) -> Self::Output { | |
| if exp == 0 { | |
| return Self::new(1); | |
| } | |
| let mut y = Self::new(1); | |
| while 0 < exp { | |
| if exp % 2 == 1 { | |
| y *= self; | |
| } | |
| self *= self; | |
| exp /= 2; | |
| } | |
| y | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| use bytemuck::cast; | |
| use wide::{i32x4, i32x8, i64x2, i64x4, u32x4, u32x8, CmpGt}; | |
| #[inline] | |
| pub fn mul_u32x4(a: u32x4, b: u32x4, inv_n: u32x4, m1: u32x4) -> u32x4 { | |
| cast( | |
| cast::<_, i32x4>(mul_hi_u32x4(a, b)) + cast::<_, i32x4>(m1) | |
| - cast::<_, i32x4>(mul_hi_u32x4(a * b * inv_n, m1)), | |
| ) | |
| } | |
| #[inline] | |
| fn mul_hi_u32x4(a: u32x4, b: u32x4) -> u32x4 { | |
| let a_inner = cast::<_, i32x4>(a).to_array(); | |
| let a13 = i32x4::new([a_inner[1], a_inner[1], a_inner[3], a_inner[3]]); | |
| let b_inner = cast::<_, i32x4>(b).to_array(); | |
| let b13 = i32x4::new([b_inner[1], b_inner[1], b_inner[3], b_inner[3]]); | |
| let prod02 = cast::<_, i32x4>(a * b).to_array(); | |
| let prod13 = cast::<_, i32x4>(a13 * b13).to_array(); | |
| let prod_lo = | |
| cast::<_, i64x2>(i32x4::new([prod02[0], prod13[0], prod02[1], prod13[1]])).to_array(); | |
| let prod_hi = | |
| cast::<_, i64x2>(i32x4::new([prod02[2], prod13[2], prod02[3], prod13[3]])).to_array(); | |
| cast(i64x2::new([prod_lo[1], prod_hi[1]])) | |
| } | |
| #[inline] | |
| pub fn add_u32x4(a: u32x4, b: u32x4, m2: u32x4, m0: u32x4) -> u32x4 { | |
| let ret = cast::<_, i32x4>(a) + cast::<_, i32x4>(b) - cast::<_, i32x4>(m2); | |
| cast(cast::<_, i32x4>(cast::<_, u32x4>(cast::<_, i32x4>(m0).cmp_gt(ret)) & m2) + ret) | |
| } | |
| #[inline] | |
| pub fn sub_u32x4(a: u32x4, b: u32x4, m2: u32x4, m0: u32x4) -> u32x4 { | |
| let ret = cast::<_, i32x4>(a) - cast::<_, i32x4>(b); | |
| cast(cast::<_, i32x4>(cast::<_, u32x4>(cast::<_, i32x4>(m0).cmp_gt(ret)) & m2) + ret) | |
| } | |
| #[inline] | |
| pub fn mul_u32x8(a: u32x8, b: u32x8, inv_n: u32x8, m1: u32x8) -> u32x8 { | |
| cast( | |
| cast::<_, i32x8>(mul_hi_u32x8(a, b)) + cast::<_, i32x8>(m1) | |
| - cast::<_, i32x8>(mul_hi_u32x8(a * b * inv_n, m1)), | |
| ) | |
| } | |
| #[inline] | |
| fn mul_hi_u32x8(a: u32x8, b: u32x8) -> u32x8 { | |
| let a_inner = cast::<_, i32x8>(a).to_array(); | |
| let a13 = i32x8::new([ | |
| a_inner[1], a_inner[1], a_inner[3], a_inner[3], a_inner[5], a_inner[5], a_inner[7], | |
| a_inner[7], | |
| ]); | |
| let b_inner = cast::<_, i32x8>(b).to_array(); | |
| let b13 = i32x8::new([ | |
| b_inner[1], b_inner[1], b_inner[3], b_inner[3], b_inner[5], b_inner[5], b_inner[7], | |
| b_inner[7], | |
| ]); | |
| let prod02 = cast::<_, i32x8>(a * b).to_array(); | |
| let prod13 = cast::<_, i32x8>(a13 * b13).to_array(); | |
| let prod_lo = cast::<_, i64x4>(i32x8::new([ | |
| prod02[0], prod13[0], prod02[1], prod13[1], prod02[2], prod13[2], prod02[3], prod13[3], | |
| ])) | |
| .to_array(); | |
| let prod_hi = cast::<_, i64x4>(i32x8::new([ | |
| prod02[4], prod13[4], prod02[5], prod13[5], prod02[6], prod13[6], prod02[7], prod13[7], | |
| ])) | |
| .to_array(); | |
| cast(i64x4::new([prod_lo[1], prod_hi[1], prod_lo[3], prod_hi[3]])) | |
| } | |
| #[inline] | |
| pub fn add_u32x8(a: u32x8, b: u32x8, m2: u32x8, m0: u32x8) -> u32x8 { | |
| let ret = cast::<_, i32x8>(a) + cast::<_, i32x8>(b) - cast::<_, i32x8>(m2); | |
| cast(cast::<_, i32x8>(cast::<_, u32x8>(cast::<_, i32x8>(m0).cmp_gt(ret)) & m2) + ret) | |
| } | |
| #[inline] | |
| pub fn sub_u32x8(a: u32x8, b: u32x8, m2: u32x8, m0: u32x8) -> u32x8 { | |
| let ret = cast::<_, i32x8>(a) - cast::<_, i32x8>(b); | |
| cast(cast::<_, i32x8>(cast::<_, u32x8>(cast::<_, i32x8>(m0).cmp_gt(ret)) & m2) + ret) | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| /// The maximum n-ary of the component can be transformed. | |
| const LEVEL: usize = (ModInt998244353::N - 1).trailing_zeros() as usize; | |
| #[derive(Debug, Clone, PartialEq, Eq, Hash)] | |
| pub struct Ntt { | |
| primitive_root: ModInt998244353, | |
| d_w: [ModInt998244353; LEVEL], | |
| d_inv_w: [ModInt998244353; LEVEL], | |
| } | |
| impl Ntt { | |
| pub fn new() -> Self { | |
| let modulo = ModInt998244353::N; | |
| let primitive_root = primitive_root(modulo); | |
| let mut w = [ModInt998244353::default(); LEVEL]; | |
| let mut d_w = [ModInt998244353::default(); LEVEL]; | |
| let mut inv_w = [ModInt998244353::default(); LEVEL]; | |
| let mut d_inv_w = [ModInt998244353::default(); LEVEL]; | |
| w[LEVEL - 1] = primitive_root.pow((modulo - 1) / (1 << LEVEL)); | |
| inv_w[LEVEL - 1] = w[LEVEL - 1].inv(); | |
| for i in (0..LEVEL - 1).rev() { | |
| w[i] = w[i + 1] * w[i + 1]; | |
| inv_w[i] = inv_w[i + 1] * inv_w[i + 1]; | |
| } | |
| d_w[0] = w[1] * w[1]; | |
| d_inv_w[0] = d_w[0]; | |
| d_w[1] = w[1]; | |
| d_inv_w[1] = inv_w[1]; | |
| d_w[2] = w[2]; | |
| d_inv_w[2] = w[2]; | |
| for i in 3..LEVEL { | |
| d_w[i] = d_w[i - 1] * inv_w[i - 2] * w[i]; | |
| d_inv_w[i] = d_inv_w[i - 1] * w[i - 2] * inv_w[i]; | |
| } | |
| Self { | |
| primitive_root, | |
| d_w, | |
| d_inv_w, | |
| } | |
| } | |
| pub fn transform(&self, vec: &mut AudioVec) { | |
| if vec.is_empty() { | |
| return; | |
| } | |
| let k = vec.len().trailing_zeros(); | |
| if k == 1 { | |
| let a = vec[1]; | |
| vec[1] = vec[0] - vec[1]; | |
| vec[0] += a; | |
| return; | |
| } | |
| if k % 2 != 0 { | |
| let v = 1 << (k - 1); | |
| if v < 8 { | |
| for j in 0..v { | |
| let jv = vec[j + v]; | |
| vec[j + v] = vec[j] - jv; | |
| vec[j] += jv; | |
| } | |
| } else { | |
| let m0 = u32x8::default(); | |
| let m2 = u32x8::splat(2 * ModInt998244353::N); | |
| for j0 in (0..v).step_by(8) { | |
| let j1 = j0 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let naj = montgomery::add_u32x8(t0, t1, m2, m0); | |
| let naj_v = montgomery::sub_u32x8(t0, t1, m2, m0); | |
| vec.set_u32x8(j0, naj); | |
| vec.set_u32x8(j1, naj_v); | |
| } | |
| } | |
| } | |
| let mut u = 1 << (2 + (k % 2)); | |
| let mut v = 1 << (k - 2 - (k % 2)); | |
| let one = ModInt998244353::new(1); | |
| let im = self.d_w[1]; | |
| while v != 0 { | |
| if v == 1 { | |
| let mut xx = one; | |
| for jh in (0..u).step_by(4) { | |
| let ww = xx * xx; | |
| let wx = ww * xx; | |
| let t0 = vec[jh]; | |
| let t1 = vec[jh + 1] * xx; | |
| let t2 = vec[jh + 2] * ww; | |
| let t3 = vec[jh + 3] * wx; | |
| let t0p2 = t0 + t2; | |
| let t1p3 = t1 + t3; | |
| let t0m2 = t0 - t2; | |
| let t1m3 = (t1 - t3) * im; | |
| vec[jh] = t0p2 + t1p3; | |
| vec[jh + 1] = t0p2 - t1p3; | |
| vec[jh + 2] = t0m2 + t1m3; | |
| vec[jh + 3] = t0m2 - t1m3; | |
| xx *= self.d_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } else if v == 4 { | |
| let mut xx = one; | |
| let m0 = u32x4::default(); | |
| let m1 = u32x4::splat(ModInt998244353::N); | |
| let m2 = u32x4::splat(2 * ModInt998244353::N); | |
| let inv_mod = u32x4::splat(ModInt998244353::N_PRIME); | |
| let im = u32x4::splat(im.as_u32()); | |
| for jh in (0..u).step_by(4) { | |
| if jh == 0 { | |
| for j0 in (0..v).step_by(4) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x4(j0); | |
| let t1 = vec.get_u32x4(j1); | |
| let t2 = vec.get_u32x4(j2); | |
| let t3 = vec.get_u32x4(j3); | |
| let t0p2 = montgomery::add_u32x4(t0, t2, m2, m0); | |
| let t1p3 = montgomery::add_u32x4(t1, t3, m2, m0); | |
| let t0m2 = montgomery::sub_u32x4(t0, t2, m2, m0); | |
| let t1m3 = montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t1, t3, m2, m0), | |
| im, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x4(j0, montgomery::add_u32x4(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x4(j1, montgomery::sub_u32x4(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x4(j2, montgomery::add_u32x4(t0m2, t1m3, m2, m0)); | |
| vec.set_u32x4(j3, montgomery::sub_u32x4(t0m2, t1m3, m2, m0)); | |
| } | |
| } else { | |
| let ww = xx * xx; | |
| let wx = ww * xx; | |
| let ww = u32x4::splat(ww.as_u32()); | |
| let wx = u32x4::splat(wx.as_u32()); | |
| let xx = u32x4::splat(xx.as_u32()); | |
| for j0 in (jh * v..(jh + 1) * v).step_by(4) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x4(j0); | |
| let t1 = vec.get_u32x4(j1); | |
| let t2 = vec.get_u32x4(j2); | |
| let t3 = vec.get_u32x4(j3); | |
| let mt1 = montgomery::mul_u32x4(t1, xx, inv_mod, m1); | |
| let mt2 = montgomery::mul_u32x4(t2, ww, inv_mod, m1); | |
| let mt3 = montgomery::mul_u32x4(t3, wx, inv_mod, m1); | |
| let t0p2 = montgomery::add_u32x4(t0, mt2, m2, m0); | |
| let t1p3 = montgomery::add_u32x4(mt1, mt3, m2, m0); | |
| let t0m2 = montgomery::add_u32x4(t0, mt2, m2, m0); | |
| let t1m3 = montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(mt1, mt3, m2, m0), | |
| im, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x4(j0, montgomery::add_u32x4(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x4(j1, montgomery::sub_u32x4(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x4(j2, montgomery::add_u32x4(t0m2, t1m3, m2, m0)); | |
| vec.set_u32x4(j2, montgomery::sub_u32x4(t0m2, t1m3, m2, m0)); | |
| } | |
| } | |
| xx *= self.d_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } else { | |
| let m0 = u32x8::default(); | |
| let m1 = u32x8::splat(ModInt998244353::N); | |
| let m2 = u32x8::splat(2 * ModInt998244353::N); | |
| let inv_mod = u32x8::splat(ModInt998244353::N_PRIME); | |
| let im = u32x8::splat(im.as_u32()); | |
| let mut xx = one; | |
| for jh in (0..u).step_by(4) { | |
| if jh == 0 { | |
| for j0 in (0..v).step_by(8) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let t2 = vec.get_u32x8(j2); | |
| let t3 = vec.get_u32x8(j3); | |
| let t0p2 = montgomery::add_u32x8(t0, t1, m2, m0); | |
| let t1p3 = montgomery::add_u32x8(t1, t3, m2, m0); | |
| let t0m2 = montgomery::sub_u32x8(t0, t2, m2, m0); | |
| let t1m3 = montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t1, t3, m2, m0), | |
| im, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x8(j0, montgomery::add_u32x8(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x8(j1, montgomery::sub_u32x8(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x8(j2, montgomery::add_u32x8(t0m2, t1m3, m2, m0)); | |
| vec.set_u32x8(j3, montgomery::sub_u32x8(t0m2, t1m3, m2, m0)); | |
| } | |
| } else { | |
| let ww = xx * xx; | |
| let wx = ww * xx; | |
| let ww = u32x8::splat(ww.as_u32()); | |
| let wx = u32x8::splat(wx.as_u32()); | |
| let xx = u32x8::splat(xx.as_u32()); | |
| for j0 in (jh * v..(jh + 1) * v).step_by(8) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let t2 = vec.get_u32x8(j2); | |
| let t3 = vec.get_u32x8(j3); | |
| let mt1 = montgomery::mul_u32x8(t1, xx, inv_mod, m1); | |
| let mt2 = montgomery::mul_u32x8(t2, ww, inv_mod, m1); | |
| let mt3 = montgomery::mul_u32x8(t3, wx, inv_mod, m1); | |
| let t0p2 = montgomery::add_u32x8(t0, mt2, m2, m0); | |
| let t1p3 = montgomery::add_u32x8(mt1, mt3, m2, m0); | |
| let t0m2 = montgomery::sub_u32x8(t0, mt2, m2, m0); | |
| let t1m3 = montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(mt1, mt3, m2, m0), | |
| im, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x8(j0, montgomery::add_u32x8(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x8(j1, montgomery::sub_u32x8(t0p2, t1p3, m2, m0)); | |
| vec.set_u32x8(j2, montgomery::add_u32x8(t0m2, t1m3, m2, m0)); | |
| vec.set_u32x8(j3, montgomery::sub_u32x8(t0m2, t1m3, m2, m0)); | |
| } | |
| } | |
| xx *= self.d_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } | |
| u <<= 2; | |
| v >>= 2; | |
| } | |
| } | |
| pub fn inverse_transform(&self, vec: &mut AudioVec) { | |
| if vec.is_empty() { | |
| return; | |
| } | |
| let k = vec.len().trailing_zeros(); | |
| if k == 1 { | |
| let a1 = vec[1]; | |
| vec[1] = vec[0] - vec[1]; | |
| vec[0] += a1; | |
| vec[0] *= ModInt998244353::new(2).inv(); | |
| vec[1] *= ModInt998244353::new(2).inv(); | |
| return; | |
| } | |
| let mut u = 1 << (k - 2); | |
| let mut v = 1; | |
| let one = ModInt998244353::new(1); | |
| let im = self.d_inv_w[1]; | |
| while u != 0 { | |
| if v == 1 { | |
| let mut xx = one; | |
| u <<= 2; | |
| for jh in (0..u).step_by(4) { | |
| let ww = xx * xx; | |
| let yy = xx * im; | |
| let t0 = vec[jh]; | |
| let t1 = vec[jh + 1]; | |
| let t2 = vec[jh + 2]; | |
| let t3 = vec[jh + 3]; | |
| let t0p1 = t0 + t1; | |
| let t2p3 = t2 + t3; | |
| let t0m1 = (t0 - t1) * xx; | |
| let t2m3 = (t2 - t3) * yy; | |
| vec[jh] = t0p1 + t2p3; | |
| vec[jh + 1] = t0m1 + t2m3; | |
| vec[jh + 2] = (t0p1 - t2p3) * ww; | |
| vec[jh + 3] = (t0m1 - t2m3) * ww; | |
| xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } else if v == 4 { | |
| let m0 = u32x4::default(); | |
| let m1 = u32x4::splat(ModInt998244353::N); | |
| let m2 = u32x4::splat(2 * ModInt998244353::N); | |
| let inv_mod = u32x4::splat(ModInt998244353::N_PRIME); | |
| let mut xx = one; | |
| u <<= 2; | |
| for jh in (0..u).step_by(4) { | |
| if jh == 0 { | |
| let im = u32x4::splat(im.as_u32()); | |
| for j0 in (0..v).step_by(4) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x4(j0); | |
| let t1 = vec.get_u32x4(j1); | |
| let t2 = vec.get_u32x4(j2); | |
| let t3 = vec.get_u32x4(j3); | |
| let t0p1 = montgomery::add_u32x4(t0, t1, m2, m0); | |
| let t2p3 = montgomery::add_u32x4(t2, t3, m2, m0); | |
| let t0m1 = montgomery::sub_u32x4(t0, t1, m2, m0); | |
| let t2m3 = montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t2, t3, m2, m0), | |
| im, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x4(j0, montgomery::add_u32x4(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x4(j1, montgomery::add_u32x4(t0m1, t2m3, m2, m0)); | |
| vec.set_u32x4(j2, montgomery::sub_u32x4(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x4(j3, montgomery::sub_u32x4(t0m1, t2m3, m2, m0)); | |
| } | |
| } else { | |
| let ww = xx * xx; | |
| let yy = xx * im; | |
| let ww = u32x4::splat(ww.as_u32()); | |
| let xx = u32x4::splat(xx.as_u32()); | |
| let yy = u32x4::splat(yy.as_u32()); | |
| for j0 in (jh * v..(jh + 1) * v).step_by(4) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x4(j0); | |
| let t1 = vec.get_u32x4(j1); | |
| let t2 = vec.get_u32x4(j2); | |
| let t3 = vec.get_u32x4(j3); | |
| let t0p1 = montgomery::add_u32x4(t0, t1, m2, m0); | |
| let t2p3 = montgomery::add_u32x4(t2, t3, m2, m0); | |
| let t0m1 = montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t0, t1, m2, m0), | |
| xx, | |
| inv_mod, | |
| m1, | |
| ); | |
| let t2m3 = montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t2, t3, m2, m0), | |
| yy, | |
| inv_mod, | |
| m1, | |
| ); | |
| vec.set_u32x4(j0, montgomery::add_u32x4(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x4(j1, montgomery::add_u32x4(t0m1, t2m3, m2, m0)); | |
| vec.set_u32x4( | |
| j2, | |
| montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t0p1, t2p3, m2, m0), | |
| ww, | |
| inv_mod, | |
| m1, | |
| ), | |
| ); | |
| vec.set_u32x4( | |
| j3, | |
| montgomery::mul_u32x4( | |
| montgomery::sub_u32x4(t0m1, t2m3, m2, m0), | |
| ww, | |
| inv_mod, | |
| m1, | |
| ), | |
| ); | |
| } | |
| } | |
| xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } else { | |
| let m0 = u32x8::default(); | |
| let m1 = u32x8::splat(ModInt998244353::N); | |
| let m2 = u32x8::splat(ModInt998244353::N); | |
| let mod_inv = u32x8::splat(ModInt998244353::N_PRIME); | |
| let mut xx = one; | |
| u <<= 2; | |
| for jh in (0..u).step_by(4) { | |
| if jh == 0 { | |
| let im = u32x8::splat(im.as_u32()); | |
| for j0 in (0..v).step_by(8) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let t2 = vec.get_u32x8(j2); | |
| let t3 = vec.get_u32x8(j3); | |
| let t0p1 = montgomery::add_u32x8(t0, t1, m2, m0); | |
| let t2p3 = montgomery::add_u32x8(t2, t3, m2, m0); | |
| let t0m1 = montgomery::sub_u32x8(t0, t1, m2, m0); | |
| let t2m3 = montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t2, t3, m2, m0), | |
| im, | |
| mod_inv, | |
| m1, | |
| ); | |
| vec.set_u32x8(j0, montgomery::add_u32x8(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x8(j1, montgomery::add_u32x8(t0m1, t2m3, m2, m0)); | |
| vec.set_u32x8(j2, montgomery::sub_u32x8(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x8(j3, montgomery::sub_u32x8(t0m1, t2m3, m2, m0)); | |
| } | |
| } else { | |
| let ww = xx * xx; | |
| let yy = xx * im; | |
| let ww = u32x8::splat(ww.as_u32()); | |
| let xx = u32x8::splat(xx.as_u32()); | |
| let yy = u32x8::splat(yy.as_u32()); | |
| for j0 in (jh * v..(jh + 1) * v).step_by(8) { | |
| let j1 = j0 + v; | |
| let j2 = j1 + v; | |
| let j3 = j2 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let t2 = vec.get_u32x8(j2); | |
| let t3 = vec.get_u32x8(j3); | |
| let t0p1 = montgomery::add_u32x8(t0, t1, m2, m0); | |
| let t2p3 = montgomery::add_u32x8(t2, t3, m2, m0); | |
| let t0m1 = montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t0, t1, m2, m0), | |
| xx, | |
| mod_inv, | |
| m1, | |
| ); | |
| let t2m3 = montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t2, t3, m2, m0), | |
| yy, | |
| mod_inv, | |
| m1, | |
| ); | |
| vec.set_u32x8(j0, montgomery::add_u32x8(t0p1, t2p3, m2, m0)); | |
| vec.set_u32x8(j1, montgomery::add_u32x8(t0m1, t2m3, m2, m0)); | |
| vec.set_u32x8( | |
| j2, | |
| montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t0p1, t2p3, m2, m0), | |
| ww, | |
| mod_inv, | |
| m1, | |
| ), | |
| ); | |
| vec.set_u32x8( | |
| j3, | |
| montgomery::mul_u32x8( | |
| montgomery::sub_u32x8(t0m1, t2m3, m2, m0), | |
| ww, | |
| mod_inv, | |
| m1, | |
| ), | |
| ); | |
| } | |
| } | |
| xx *= self.d_inv_w[(jh + 4).trailing_zeros() as usize]; | |
| } | |
| } | |
| u >>= 4; | |
| v <<= 2; | |
| } | |
| if k % 2 == 1 { | |
| v = 1 << (k - 1); | |
| if v < 8 { | |
| for j in 0..v { | |
| let ajv = vec[j + v]; | |
| let aj_ajv = vec[j] - vec[j + v]; | |
| vec[j] += ajv; | |
| vec[j + v] = aj_ajv; | |
| } | |
| } else { | |
| let m0 = u32x8::default(); | |
| let m2 = u32x8::splat(2 * ModInt998244353::N); | |
| for j0 in (0..v).step_by(8) { | |
| let j1 = j0 + v; | |
| let t0 = vec.get_u32x8(j0); | |
| let t1 = vec.get_u32x8(j1); | |
| let naj = montgomery::add_u32x8(t0, t1, m2, m0); | |
| let naj_v = montgomery::sub_u32x8(t0, t1, m2, m0); | |
| vec.set_u32x8(j0, naj); | |
| vec.set_u32x8(j1, naj_v); | |
| } | |
| } | |
| } | |
| let inv_len = ModInt998244353::new(vec.len() as u32).inv(); | |
| for val in vec.iter_mut() { | |
| *val *= inv_len; | |
| } | |
| } | |
| } | |
| impl Default for Ntt { | |
| fn default() -> Self { | |
| Self::new() | |
| } | |
| } | |
| fn primitive_root(modulo: u32) -> ModInt998244353 { | |
| if modulo == 2 { | |
| return ModInt998244353::new(1); | |
| } | |
| let mut divisors = vec![]; | |
| let mut m = modulo - 1; | |
| for i in 2.. { | |
| if m < i * i { | |
| break; | |
| } | |
| if m % i == 0 { | |
| divisors.push(i as u64); | |
| while m % i == 0 { | |
| m /= i; | |
| } | |
| } | |
| } | |
| if m != 1 { | |
| divisors.push(m as u64); | |
| } | |
| 'find: for primitive_root in 2.. { | |
| for divisor in &divisors { | |
| let mut a: u64 = primitive_root; | |
| let mut b: u64 = (modulo as u64 - 1) / divisor; | |
| let mut r: u64 = 1; | |
| while b != 0 { | |
| if b % 2 != 0 { | |
| r *= a; | |
| r %= modulo as u64; | |
| } | |
| a *= a; | |
| a %= modulo as u64; | |
| b /= 2; | |
| } | |
| if r == 1 { | |
| continue 'find; | |
| } | |
| } | |
| return ModInt998244353::new(primitive_root as u32); | |
| } | |
| unreachable!() | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment