Skip to content

Instantly share code, notes, and snippets.

@OmarElawady
Created March 28, 2020 12:37
Show Gist options
  • Select an option

  • Save OmarElawady/637ec3f0f3bf1c97482f8b54a0957257 to your computer and use it in GitHub Desktop.

Select an option

Save OmarElawady/637ec3f0f3bf1c97482f8b54a0957257 to your computer and use it in GitHub Desktop.
The mysterious minus
#include <bits/stdc++.h>
#include <complex>
using namespace std;
// Shorthand aliases used throughout the file: ll is a long long, Complex is a
// double-precision complex number (the element type of every transform below).
using ll = long long;
using Complex = complex<double>;
// J is the imaginary unit, 0 + 1i.
const Complex J{0.0, 1.0};
// Returns ceil(log2(n)): the smallest lg >= 0 with 2^lg >= n (0 for n <= 1).
// The power of two is formed as a long long: with an int, n > 2^30 would
// shift 1 into the sign bit and then past the width of int (undefined
// behavior), so the loop would never terminate correctly.
int logint(int n){
    int lg = 0;
    while ((1LL << lg) < n)
        lg++;
    return lg;
}
// Reverses the lowest `lgn` bits of x, e.g. rev(0b110, 3) == 0b011.
// FFT uses it to build the bit-reversal permutation of its input indices.
int rev(int x, int lgn){
    int reversed = 0;
    for (; lgn != 0; --lgn) {
        const int lowest_bit = x & 1;
        reversed = reversed * 2 + lowest_bit;  // append the bit on the right
        x = x / 2;                             // drop it from the source
    }
    return reversed;
}
vector<Complex> FFT(const vector<Complex>& v){
int n = v.size();
int lgn = logint(n);
assert((n & (n - 1)) == 0);
vector<Complex> perm(v.size());
for(int i = 0;i < n;i++){
perm[i] = v[rev(i, lgn)];
}
for(int s = 1;s <= lgn;s++){
int m = (1 << s);
Complex wm = exp(-2 * M_PI * J / (double)m);
for(int k = 0;k < n;k += m){
Complex w = 1;
for(int j = 0;j < m / 2;j++){
Complex t = w * perm[k + j + m / 2];
Complex u = perm[k + j];
perm[k + j] = u + t;
perm[k + j + m / 2] = u - t;
w = w * wm;
}
}
}
return perm;
}
vector<Complex> general_iFFT(const vector<Complex>& v, vector<Complex> (*f)(const vector<Complex>&)){
vector<Complex> cp = v;
for(auto& el : cp)
el = Complex(el.real(), -el.imag());
cp = (*f)(cp);
for(auto& el : cp)
el = Complex(el.real(), -el.imag()) / (double)v.size();
return cp;
}
// Inverse of FFT: the iterative transform plugged into general_iFFT.
vector<Complex> iFFT(const vector<Complex>& v){
    vector<Complex> (*const forward)(const vector<Complex>&) = FFT;
    return general_iFFT(v, forward);
}
vector<Complex> recursive_FFT(const vector<Complex>& v){
int n = v.size();
if(n == 1)
return v;
assert((n & (n - 1)) == 0);
Complex wn = exp(-2 * M_PI * J / (double)n);
Complex w = 1;
vector<Complex> a0(n / 2), a1(n / 2);
for(int i = 0;i < n;i += 2)
a0[i / 2] = v[i], a1[i / 2] = v[i + 1];
vector<Complex> y0, y1;
y0 = recursive_FFT(a0);
y1 = recursive_FFT(a1);
vector<Complex> y(n);
for(int k = 0;k < n / 2;k++){
y[k] = y0[k] + w * y1[k];
y[k + n / 2] = y0[k] - w * y1[k];
w = w * wn;
}
return y;
}
// Inverse of recursive_FFT via general_iFFT's conjugation identity.
// Takes its input by const reference, like iFFT and ola_iFFT: it never
// modifies v, and a non-const reference would reject temporaries and const
// vectors.
vector<Complex> recursive_iFFT(const vector<Complex>& v){
    return general_iFFT(v, &recursive_FFT);
}
// Evaluates the polynomial pol[0] + pol[1]*x + ... + pol[n-1]*x^(n-1) at
// x = point using Horner's rule, consuming coefficients from the highest
// degree down. Returns 0 for an empty coefficient vector.
// `point` is only read, so it is taken by const reference; the previous
// non-const reference rejected temporaries such as eval(p, Complex(1, 0)).
Complex eval(const vector<Complex>& pol, const Complex& point){
    Complex result = 0;
    for (auto it = pol.rbegin(); it != pol.rend(); ++it)
        result = result * point + *it;
    return result;
}
vector<Complex> ola_FFT(const vector<Complex>& v){
int n = v.size();
Complex wn = exp(-2 * M_PI * J / (double)n);
Complex w = 1;
vector<Complex> result(n);
for(int i = 0;i < n;i++){
result[i] = eval(v, w);
w *= wn;
}
return result;
}
// Inverse of the naive ola_FFT, via general_iFFT's conjugation identity.
vector<Complex> ola_iFFT(const vector<Complex>& v){
    vector<Complex> (*const forward)(const vector<Complex>&) = ola_FFT;
    return general_iFFT(v, forward);
}
// Absolute tolerance for comparing floating-point results.
constexpr double EPS = 1e-7;
// True when a and b differ by strictly less than EPS (an absolute, not
// relative, tolerance). A NaN operand never compares equal.
bool double_equal(double a, double b){
    const double diff = fabs(a - b);
    return diff < EPS;
}
bool vector_equal(vector<Complex>& a, vector<Complex>& b){
if(a.size() != b.size())
return false;
for(size_t i = 0;i < a.size();i++){
if(!double_equal(a[i].real(), b[i].real())
|| !double_equal(a[i].imag(), b[i].imag()))
return false;
}
return true;
}
ostream& operator<<(ostream& out, const vector<Complex>& v){
for(auto el : v)
out << fixed << setprecision(2) << el.real() << ' ' << setprecision(2) << el.imag() << endl;
return out;
}
// Sanity checks for the FFT bit helpers: reversing 101001 over 6 bits gives
// 100101, and logint(8) == 3.
// (Assert expressions are left verbatim: their text is the failure message.)
void test_helpers(){
    assert(rev(0b101001, 6) == 0b100101);
    assert(logint(0b1000) == 3);
}
void test_fft(){
cout << "==============================================================\n"
<< "Iterative FFT\n";
vector<Complex> orig;
vector<Complex> dft;
vector<Complex> idft;
orig = {1, 2, 3, 4};
dft = FFT(orig);
idft = iFFT(dft);
cout << "Original\n" << orig;
cout << "DFT\n" << dft;
cout << "IDFT\n" << idft;
assert(vector_equal(idft, orig));
}
void test_recursive_fft(){
cout << "==============================================================\n"
<< "Recursive FFT\n";
vector<Complex> orig;
vector<Complex> dft;
vector<Complex> idft;
orig = {1, 2, 3, 4};
dft = recursive_FFT(orig);
idft = recursive_iFFT(dft);
cout << "Original\n" << orig;
cout << "DFT\n" << dft;
cout << "IDFT\n" << idft;
assert(vector_equal(idft, orig));
}
void test_ola_fft(){
cout << "==============================================================\n"
<< "Ola FFT\n";
vector<Complex> orig;
vector<Complex> dft;
vector<Complex> idft;
orig = {1, 2, 3, 4};
dft = ola_FFT(orig);
idft = ola_iFFT(dft);
cout << "Original\n" << orig;
cout << "DFT\n" << dft;
cout << "IDFT\n" << idft;
assert(vector_equal(idft, orig));
}
// Runs every self-test in order; any failure aborts via assert.
int main()
{
    // freopen("in.txt", "r", stdin);
    // freopen("out.txt", "w", stdout);
    void (*const suites[])() = {test_helpers, test_fft, test_recursive_fft, test_ola_fft};
    for (void (*suite)() : suites)
        suite();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment