// Gist page header (not part of the program):
// Last active September 21, 2020 05:05
// Save math314/6a08301b8b75b8172798 to your computer and use it in GitHub Desktop.
// This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters
#include <cassert>
#include <cstdio>
#include <utility>
#include <vector>
using namespace std;

// 64-bit signed integer used for all modular arithmetic in this file.
typedef long long ll;
// (modulus, residue) pair as consumed by garner().
typedef pair<int, int> Pii;

// FOR(i, n): loop with a fresh int counter i running over 0, 1, ..., n - 1.
#define FOR(i,n) for(int i = 0; i < (n); i++)
// sz(c): size of container c converted to a signed int.
#define sz(c) ((int)(c).size())
// ten(x): pastes x onto "1e" and casts to int, e.g. ten(9) is (int)1e9.
#define ten(x) ((int)1e##x)
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores in x, y a solution of a * x + b * y == g.
// When a == 0 the loop never runs and the result is g == b with x == 0, y == 1.
// T must be a signed integer type; the sign of g follows the remainder sequence,
// so pass non-negative arguments when a non-negative gcd is required.
template<class T> T extgcd(T a, T b, T& x, T& y) {
    // Loop invariants, with a0 / b0 denoting the original arguments:
    //   a0 * u + b0 * v == a
    //   a0 * x + b0 * y == b
    T u = 1, v = 0;
    x = 0;
    y = 1;
    while (a) {
        const T q = b / a;
        // (b, a) <- (a, b - q * a), and the same step for both coefficient pairs.
        x -= q * u; std::swap(x, u);
        y -= q * v; std::swap(y, v);
        b -= q * a; std::swap(b, a);
    }
    return b;
}
// Modular inverse: returns the value r in [0, m) with a * r == 1 (mod m).
// Requires m > 0 and gcd(a, m) == 1; both are checked with assert.
// a may be negative or >= m: it is reduced into [0, m) first, because extgcd can
// return gcd == -1 for a negative argument (e.g. a == -1), which would otherwise
// yield the negated inverse.
template<class T> T mod_inv(T a, T m) {
    assert(m > 0);
    a %= m;
    if (a < 0) a += m;
    T x, y;
    const T g = extgcd(a, m, x, y);
    assert(g == 1);  // otherwise a has no inverse modulo m
    (void)g;         // silence unused-variable warnings under NDEBUG
    // x may be negative; fold it into [0, m).
    return (m + x % m) % m;
}
// Modular exponentiation by repeated squaring: returns a^n mod `mod` in [0, mod).
// Requires n >= 0 and mod > 0 (asserted). a may be negative; it is reduced into
// [0, mod) first so the result is never negative. mod_pow(a, 0, 1) is 0.
// Intermediate products are formed in 64 bits, so (mod - 1)^2 must fit in ll.
ll mod_pow(ll a, ll n, ll mod) {
    assert(n >= 0);  // a negative exponent would never shift down to 0
    assert(mod > 0);
    ll base = a % mod;
    if (base < 0) base += mod;
    ll result = 1 % mod;  // 0 when mod == 1
    while (n > 0) {
        if (n & 1) result = result * base % mod;
        base = base * base % mod;
        n >>= 1;
    }
    return result;
}
template<int mod, int primitive_root> | |
class NTT { | |
public: | |
int get_mod() const { return mod; } | |
void _ntt(vector<ll>& a, int sign) { | |
const int n = sz(a); | |
assert((n ^ (n&-n)) == 0); //n = 2^k | |
const int g = 3; //g is primitive root of mod | |
int h = (int)mod_pow(g, (mod - 1) / n, mod); // h^n = 1 | |
if (sign == -1) h = (int)mod_inv(h, mod); //h = h^-1 % mod | |
//bit reverse | |
int i = 0; | |
for (int j = 1; j < n - 1; ++j) { | |
for (int k = n >> 1; k >(i ^= k); k >>= 1); | |
if (j < i) swap(a[i], a[j]); | |
} | |
for (int m = 1; m < n; m *= 2) { | |
const int m2 = 2 * m; | |
const ll base = mod_pow(h, n / m2, mod); | |
ll w = 1; | |
FOR(x, m) { | |
for (int s = x; s < n; s += m2) { | |
ll u = a[s]; | |
ll d = a[s + m] * w % mod; | |
a[s] = u + d; | |
if (a[s] >= mod) a[s] -= mod; | |
a[s + m] = u - d; | |
if (a[s + m] < 0) a[s + m] += mod; | |
} | |
w = w * base % mod; | |
} | |
} | |
for (auto& x : a) if (x < 0) x += mod; | |
} | |
void ntt(vector<ll>& input) { | |
_ntt(input, 1); | |
} | |
void intt(vector<ll>& input) { | |
_ntt(input, -1); | |
const int n_inv = mod_inv(sz(input), mod); | |
for (auto& x : input) x = x * n_inv % mod; | |
} | |
// 畳み込み演算を行う | |
vector<ll> convolution(const vector<ll>& a, const vector<ll>& b){ | |
int ntt_size = 1; | |
while (ntt_size < sz(a) + sz(b)) ntt_size *= 2; | |
vector<ll> _a = a, _b = b; | |
_a.resize(ntt_size); _b.resize(ntt_size); | |
ntt(_a); | |
ntt(_b); | |
FOR(i, ntt_size){ | |
(_a[i] *= _b[i]) %= mod; | |
} | |
intt(_a); | |
return _a; | |
} | |
}; | |
ll garner(vector<Pii> mr, int mod){ | |
mr.emplace_back(mod, 0); | |
vector<ll> coffs(sz(mr), 1); | |
vector<ll> constants(sz(mr), 0); | |
FOR(i, sz(mr) - 1){ | |
// coffs[i] * v + constants[i] == mr[i].second (mod mr[i].first) を解く | |
ll v = (mr[i].second - constants[i]) * mod_inv<ll>(coffs[i], mr[i].first) % mr[i].first; | |
if (v < 0) v += mr[i].first; | |
for (int j = i + 1; j < sz(mr); j++) { | |
(constants[j] += coffs[j] * v) %= mr[j].first; | |
(coffs[j] *= mr[i].first) %= mr[j].first; | |
} | |
} | |
return constants[sz(mr) - 1]; | |
} | |
// Three NTT moduli of the form c * 2^k + 1, each instantiated with 3 as the
// primitive root. Their product (about 9.6e25) bounds the exact convolution
// coefficients that the CRT-based routines can reconstruct.
typedef NTT<167772161, 3> NTT_1;   // 5 * 2^25 + 1
typedef NTT<469762049, 3> NTT_2;   // 7 * 2^26 + 1
typedef NTT<1224736769, 3> NTT_3;  // 73 * 2^24 + 1
//任意のmodで畳み込み演算 O(n log n) | |
vector<ll> int32mod_convolution(vector<ll> a, vector<ll> b,int mod){ | |
for (auto& x : a) x %= mod; | |
for (auto& x : b) x %= mod; | |
NTT_1 ntt1; NTT_2 ntt2; NTT_3 ntt3; | |
auto x = ntt1.convolution(a, b); | |
auto y = ntt2.convolution(a, b); | |
auto z = ntt3.convolution(a, b); | |
vector<ll> ret(sz(x)); | |
vector<Pii> mr(3); | |
FOR(i, sz(x)){ | |
mr[0].first = ntt1.get_mod(), mr[0].second = (int)x[i]; | |
mr[1].first = ntt2.get_mod(), mr[1].second = (int)y[i]; | |
mr[2].first = ntt3.get_mod(), mr[2].second = (int)z[i]; | |
ret[i] = garner(mr, mod); | |
} | |
return ret; | |
} | |
// garnerのアルゴリズムを直書きしたversion,速い | |
vector<ll> fast_int32mod_convolution(vector<ll> a, vector<ll> b,int mod){ | |
for (auto& x : a) x %= mod; | |
for (auto& x : b) x %= mod; | |
NTT_1 ntt1; NTT_2 ntt2; NTT_3 ntt3; | |
assert(ntt1.get_mod() < ntt2.get_mod() && ntt2.get_mod() < ntt3.get_mod()); | |
auto x = ntt1.convolution(a, b); | |
auto y = ntt2.convolution(a, b); | |
auto z = ntt3.convolution(a, b); | |
// garnerのアルゴリズムを極力高速化した | |
const ll m1 = ntt1.get_mod(), m2 = ntt2.get_mod(), m3 = ntt3.get_mod(); | |
const ll m1_inv_m2 = mod_inv<ll>(m1, m2); | |
const ll m12_inv_m3 = mod_inv<ll>(m1 * m2, m3); | |
const ll m12_mod = m1 * m2 % mod; | |
vector<ll> ret(sz(x)); | |
FOR(i, sz(x)){ | |
ll v1 = (y[i] - x[i]) * m1_inv_m2 % m2; | |
if (v1 < 0) v1 += m2; | |
ll v2 = (z[i] - (x[i] + m1 * v1) % m3) * m12_inv_m3 % m3; | |
if (v2 < 0) v2 += m3; | |
ll constants3 = (x[i] + m1 * v1 + m12_mod * v2) % mod; | |
if (constants3 < 0) constants3 += mod; | |
ret[i] = constants3; | |
} | |
return ret; | |
} | |
// Primes larger than 2^23 that have 3 as a primitive root.
// const int mods[] = { 1224736769, 469762049, 167772161, 595591169, 645922817, 897581057, 998244353 };
// Round-trip check for NTT_1: the inverse transform of the forward transform
// must reproduce the input, and bin 0 of the forward transform must equal the
// sum of the inputs.
void ntt_test() {
    NTT_1 ntt;
    std::vector<ll> original;
    for (int i = 0; i < 16; ++i) original.push_back(10 + i);

    std::vector<ll> transformed = original;
    ntt.ntt(transformed);
    assert(transformed[0] == 280);  // 10 + 11 + ... + 25

    std::vector<ll> restored = transformed;
    ntt.intt(restored);
    assert(original == restored);
}
void comvolution_test() { | |
NTT_1 ntt1; | |
vector<ll> v = { 1, 2, 3 }; | |
vector<ll> u = { 4, 5, 6 }; | |
auto vu = ntt1.convolution(v, u); | |
vector<ll> vu2 = { 1 * 4, 1 * 5 + 2 * 4, 1 * 6 + 2 * 5 + 3 * 4, 2 * 6 + 3 * 5, 3 * 6, 0, 0, 0 }; | |
assert(vu == vu2); | |
} | |
void int32mod_convolution_test(){ | |
vector<ll> x , y; | |
FOR(i, 10) x.push_back(ten(8) + i); | |
y = x; | |
auto z = int32mod_convolution(x, y, ten(9) + 7); | |
z.resize(sz(x) + sz(y) - 1); | |
vector<ll> z2 = { | |
930000007, 60000000, 390000001, 920000004, | |
650000003, 580000006, 710000014, 40000021, | |
570000042, 300000064, 370000109, 240000144, | |
910000175, 380000187, 650000193, 720000185, | |
590000162, 260000123, 730000074 }; | |
assert(z == z2); | |
} | |
// Runs every self-check in this file; each one reports failure through assert.
void test(){
    ntt_test();
    comvolution_test();
    int32mod_convolution_test();
}
// Entry point: runs the self-checks and exits with status 0 if none of the
// assertions fire.
int main(){
    test();
    return 0;
}
// Gist page footer (not part of the program):
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment