Skip to content

Instantly share code, notes, and snippets.

@kevinpostal
Last active March 7, 2025 22:01
Show Gist options
  • Save kevinpostal/008695a8ba35c18949043745352e7bf1 to your computer and use it in GitHub Desktop.
Save kevinpostal/008695a8ba35c18949043745352e7bf1 to your computer and use it in GitHub Desktop.
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <vector>
// Optional NEON support (only enabled when the target actually provides NEON)
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define USE_NEON
#endif
#ifdef USE_NEON
#include <arm_neon.h>
#endif
// Constants (replace const1.h dependency)
// Many <cmath> implementations (glibc, macOS, MSVC with _USE_MATH_DEFINES)
// already provide M_PI as a preprocessor macro. Declaring a variable with that
// name would then expand to `constexpr double 3.14... = ...` and fail to
// compile, so the constant is only defined when the macro is absent.
#ifndef M_PI
constexpr double M_PI = 3.14159265358979323846;
#endif
constexpr double SQ2_2 = 0.70710678118654752440; // sqrt(2)/2
// Complex type definition (replace cmplxT<T>)
// Minimal complex-number value type: a real part `re` and an imaginary part
// `im`, with component-wise addition/subtraction and scaling by a real factor.
template <typename T>
struct Complex {
    T re;
    T im;

    // Both parts default to T(); a single argument sets only the real part.
    Complex(T r = T(), T i = T()) : re(r), im(i) {}

    // Multiply both parts by the real factor `scalar`.
    Complex operator*(T scalar) const {
        return Complex(re * scalar, im * scalar);
    }

    // Component-wise sum.
    Complex operator+(const Complex& other) const {
        return Complex(re + other.re, im + other.im);
    }

    // Component-wise difference.
    Complex operator-(const Complex& other) const {
        return Complex(re - other.re, im - other.im);
    }
};
// FFTReal class
// Real-input FFT of compile-time length N (N must be a power of two).
//
// real_fft() turns N real samples into N/2 complex bins; real_ifft() performs
// the reverse mapping. The value held at index N/2 of the internal
// half-complex array (the Nyquist term in this layout) is not exported by
// real_fft() and is treated as zero by real_ifft(), so that term does not
// survive a round trip.
//
// Every call overwrites the scratch buffers owned by the instance, so a single
// instance must not be used by several callers at the same time.
template <typename T, int N>
class FFTReal {
    static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2");
    using ComplexT = Complex<T>;

    // floor(log2(x)) for x >= 1, evaluated at compile time.
    static constexpr unsigned floorlog2(unsigned x) {
        return (x == 1) ? 0 : 1 + floorlog2(x >> 1);
    }
    static constexpr int nbr_bits = floorlog2(N);  // log2(N)
    static constexpr int N2 = N >> 1;              // N / 2

public:
    // std::vector's size constructor value-initializes its elements, so both
    // scratch buffers start out zero-filled without an explicit std::fill.
    FFTReal() : buffer_ptr(N * 2 + 2), yy(N * 2 + 2) {}

    // Real FFT: time domain -> frequency domain.
    //   x        : N real input samples (must not be null).
    //   y        : receives bins 0 .. N/2-1; y[0] is always written and its
    //              imaginary part is set to zero (must not be null).
    //   do_scale : when true every output bin is multiplied by 1/N.
    // Throws std::invalid_argument if x or y is null.
    void real_fft(const T* x, ComplexT* y, bool do_scale = false) {
        if (!x || !y) throw std::invalid_argument("Null pointer passed to real_fft");
        const T mul = do_scale ? T(1.0 / N) : T(1.0);
#ifdef USE_NEON
        if constexpr (std::is_same_v<T, double>) {
            do_fft_neon_d8(x, yy.data());
        } else if constexpr (std::is_same_v<T, float>) {
            do_fft_neon_f8(x, yy.data());
        } else {
            do_fft(x, yy.data());
        }
#else
        do_fft(x, yy.data());
#endif
        // Half-complex layout: real parts in yy[0..N/2], imaginary parts of
        // bin i (1 <= i < N/2) in yy[i + N/2].
        for (int i = 1; i < N2; ++i) {
            y[i] = ComplexT(yy[i], yy[i + N2]) * mul;
        }
        y[0] = ComplexT(yy[0], T()) * mul;
    }

    // Inverse real FFT: frequency domain -> time domain.
    //   x        : bins 0 .. N/2-1 as produced by real_fft (must not be null);
    //              only the real part of x[0] is used.
    //   y        : receives N real samples (must not be null).
    //   do_scale : when true every output sample is multiplied by 1/N.
    // Throws std::invalid_argument if x or y is null.
    void real_ifft(const ComplexT* x, T* y, bool do_scale = false) {
        if (!x || !y) throw std::invalid_argument("Null pointer passed to real_ifft");
        for (int i = 1; i < N2; ++i) {
            yy[i] = x[i].re;
            yy[i + N2] = x[i].im;
        }
        yy[0] = x[0].re;
        // The N/2 term is not part of the input and is taken as zero. For
        // N == 1 index N/2 coincides with index 0, which must be left alone.
        if (N2 > 0) yy[N2] = T();
#ifdef USE_NEON
        // do_scale is forwarded so the NEON entry points honour it as well.
        if constexpr (std::is_same_v<T, double>) {
            do_ifft_neon_d8(yy.data(), y, do_scale);
        } else if constexpr (std::is_same_v<T, float>) {
            do_ifft_neon_f8(yy.data(), y, do_scale);
        } else {
            do_ifft(yy.data(), y, do_scale);
        }
#else
        do_ifft(yy.data(), y, do_scale);
#endif
    }

private:
    // Plain std::vector storage: the standard library has no
    // std::aligned_allocator, and the current code paths are scalar.
    std::vector<T> buffer_ptr;  // Working buffer
    std::vector<T> yy;          // Intermediate (half-complex) array

    // Forward transform of N real samples x into the half-complex array f.
    // Passes alternate between f and buffer_ptr; the initial source and
    // destination are chosen from the parity of nbr_bits so that the last
    // pass writes into f.
    void do_fft(const T* x, T* f) {
        if constexpr (nbr_bits > 2) {
            T* sf = (nbr_bits & 1) ? f : buffer_ptr.data();
            T* df = (nbr_bits & 1) ? buffer_ptr.data() : f;

            // First and second passes combined (input read in bit-reversed order)
            const int* lut = bit_rev_lut.get_ptr();
            for (int i = 0; i < N; i += 4) {
                const T x0 = x[lut[i]];
                const T x1 = x[lut[i + 1]];
                const T x2 = x[lut[i + 2]];
                const T x3 = x[lut[i + 3]];
                df[i] = x0 + x1 + x2 + x3;
                df[i + 1] = x0 - x1;
                df[i + 2] = x0 + x1 - x2 - x3;
                df[i + 3] = x2 - x3;
            }

            // Third pass
            for (int i = 0; i < N; i += 8) {
                sf[i] = df[i] + df[i + 4];
                sf[i + 4] = df[i] - df[i + 4];
                sf[i + 2] = df[i + 2];
                sf[i + 6] = df[i + 6];
                T v = (df[i + 5] - df[i + 7]) * SQ2_2;
                sf[i + 1] = df[i + 1] + v;
                sf[i + 3] = df[i + 1] - v;
                v = (df[i + 5] + df[i + 7]) * SQ2_2;
                sf[i + 5] = v + df[i + 3];
                sf[i + 7] = v - df[i + 3];
            }

            // Subsequent passes
            for (int pass = 3; pass < nbr_bits; ++pass) {
                const int nbr_coef = 1 << pass;
                const int h_nbr_coef = nbr_coef >> 1;
                const int d_nbr_coef = nbr_coef << 1;
                const T* cos_ptr = trigo_lut.get_ptr(pass);
                for (int i = 0; i < N; i += d_nbr_coef) {
                    const T* sf1r = sf + i;
                    const T* sf2r = sf1r + nbr_coef;
                    T* dfr = df + i;
                    T* dfi = dfr + nbr_coef;

                    // Extreme coefficients need no twiddle factor
                    dfr[0] = sf1r[0] + sf2r[0];
                    dfi[0] = sf1r[0] - sf2r[0];
                    dfr[h_nbr_coef] = sf1r[h_nbr_coef];
                    dfi[h_nbr_coef] = sf2r[h_nbr_coef];

                    for (int j = 1; j < h_nbr_coef; ++j) {
                        // The sine is read from the same cosine table, mirrored
                        const T c = cos_ptr[j];
                        const T s = cos_ptr[h_nbr_coef - j];
                        T v = sf2r[j] * c - sf2r[j + h_nbr_coef] * s;
                        dfr[j] = sf1r[j] + v;
                        dfi[-j] = sf1r[j] - v;  // lands in the upper part of the dfr half
                        v = sf2r[j] * s + sf2r[j + h_nbr_coef] * c;
                        dfi[j] = v + sf1r[j + h_nbr_coef];
                        dfi[nbr_coef - j] = v - sf1r[j + h_nbr_coef];
                    }
                }
                std::swap(sf, df);
            }
        } else {
            handle_special_cases_fft(x, f);
        }
    }

    // Inverse transform of the half-complex array f into N real samples x.
    // When do_scale is true the result is multiplied by 1/N.
    void do_ifft(const T* f, T* x, bool do_scale) {
        const T mul = do_scale ? T(1.0 / N) : T(1.0);
        if constexpr (nbr_bits > 2) {
            // f is only ever read through sf; the cast exists so that sf can
            // later be exchanged with the writable work buffers.
            T* sf = const_cast<T*>(f);
            T* df = (nbr_bits & 1) ? buffer_ptr.data() : x;
            T* df_temp = (nbr_bits & 1) ? x : buffer_ptr.data();

            for (int pass = nbr_bits - 1; pass >= 3; --pass) {
                const int nbr_coef = 1 << pass;
                const int h_nbr_coef = nbr_coef >> 1;
                const int d_nbr_coef = nbr_coef << 1;
                const T* cos_ptr = trigo_lut.get_ptr(pass);
                for (int i = 0; i < N; i += d_nbr_coef) {
                    const T* sfr = sf + i;
                    const T* sfi = sfr + nbr_coef;
                    T* df1r = df + i;
                    T* df2r = df1r + nbr_coef;

                    df1r[0] = sfr[0] + sfr[nbr_coef];
                    df2r[0] = sfr[0] - sfr[nbr_coef];
                    df1r[h_nbr_coef] = sfr[h_nbr_coef] * 2;
                    df2r[h_nbr_coef] = sfi[h_nbr_coef] * 2;

                    for (int j = 1; j < h_nbr_coef; ++j) {
                        const T c = cos_ptr[j];
                        const T s = cos_ptr[h_nbr_coef - j];
                        df1r[j] = sfr[j] + sfi[-j];
                        df1r[j + h_nbr_coef] = sfi[j] - sfi[nbr_coef - j];
                        const T vr = sfr[j] - sfi[-j];
                        const T vi = sfi[j] + sfi[nbr_coef - j];
                        df2r[j] = vr * c + vi * s;
                        df2r[j + h_nbr_coef] = vi * c - vr * s;
                    }
                }
                // After the first executed pass stop reading from f and
                // ping-pong between the two writable buffers instead.
                if (pass < nbr_bits - 1) {
                    std::swap(df, sf);
                } else {
                    sf = df;
                    df = df_temp;
                }
            }

            // Antepenultimate pass
            for (int i = 0; i < N; i += 8) {
                const T* sf2 = sf + i;
                T* df2 = df + i;
                const T vr = sf2[1] - sf2[3];
                const T vi = sf2[5] + sf2[7];
                df2[0] = sf2[0] + sf2[4];
                df2[1] = sf2[1] + sf2[3];
                df2[2] = sf2[2] * 2;
                df2[3] = sf2[5] - sf2[7];
                df2[4] = sf2[0] - sf2[4];
                df2[5] = (vr + vi) * SQ2_2;
                df2[6] = sf2[6] * 2;
                df2[7] = (vi - vr) * SQ2_2;
            }

            // Final passes (output written in bit-reversed order, scaled by mul)
            const int* lut = bit_rev_lut.get_ptr();
            for (int i = 0; i < N; i += 8) {
                const T* sf2 = df + i;
                T b_0 = sf2[0] + sf2[2];
                T b_1 = sf2[1] * 2;
                T b_2 = sf2[0] - sf2[2];
                T b_3 = sf2[3] * 2;
                x[lut[i]] = (b_0 + b_1) * mul;
                x[lut[i + 1]] = (b_0 - b_1) * mul;
                x[lut[i + 2]] = (b_2 + b_3) * mul;
                x[lut[i + 3]] = (b_2 - b_3) * mul;
                b_0 = sf2[4] + sf2[6];
                b_1 = sf2[5] * 2;
                b_2 = sf2[4] - sf2[6];
                b_3 = sf2[7] * 2;
                x[lut[i + 4]] = (b_0 + b_1) * mul;
                x[lut[i + 5]] = (b_0 - b_1) * mul;
                x[lut[i + 6]] = (b_2 + b_3) * mul;
                x[lut[i + 7]] = (b_2 - b_3) * mul;
            }
        } else {
            handle_special_cases_ifft(f, x, mul);
        }
    }

#ifdef USE_NEON
    // NEON entry points. They currently delegate to the scalar code; replace
    // the bodies with real NEON kernels if needed.
    void do_fft_neon_d8(const double* x, double* f) {
        do_fft(x, f);
    }
    void do_fft_neon_f8(const float* x, float* f) {
        do_fft(x, f);
    }
    void do_ifft_neon_d8(const double* f, double* x, bool do_scale) {
        do_ifft(f, x, do_scale);
    }
    void do_ifft_neon_f8(const float* f, float* x, bool do_scale) {
        do_ifft(f, x, do_scale);
    }
#endif

    // Bit-reversed LUT: entry k holds k with its nbr_bits low bits reversed.
    class BitReversedLUT {
    public:
        BitReversedLUT() : lut(N) {
            int br_index = 0;
            lut[0] = 0;
            for (int cnt = 1; cnt < N; ++cnt) {
                // Increment br_index as a bit-reversed counter: flip bits from
                // the top down until one flips from 0 to 1.
                int bit = N >> 1;
                while (((br_index ^= bit) & bit) == 0) bit >>= 1;
                lut[cnt] = br_index;
            }
        }
        const int* get_ptr() const { return lut.data(); }

    private:
        std::vector<int> lut;
    };

    // Trigonometric LUT: cosine tables for passes 3 .. nbr_bits-1, stored back
    // to back (pass `level` contributes 2^(level-1) entries).
    class TrigoLUT {
    public:
        TrigoLUT() {
            if (nbr_bits > 3) {
                const int total_len = (1 << (nbr_bits - 1)) - 4;
                lut.resize(total_len);
                int pos = 0;
                for (int level = 3; level < nbr_bits; ++level) {
                    const int level_len = 1 << (level - 1);
                    // Angles are evaluated in double and narrowed afterwards so
                    // a float table is not degraded by a float angle step.
                    const double step = M_PI / (level_len * 2);
                    for (int i = 0; i < level_len; ++i) {
                        lut[pos++] = static_cast<T>(std::cos(i * step));
                    }
                }
            }
        }
        // Start of the table for `level`; only valid for 3 <= level < nbr_bits.
        const T* get_ptr(int level) const {
            return lut.data() + (1 << (level - 1)) - 4;
        }

    private:
        std::vector<T> lut;
    };

    // Forward transform for N <= 4, where the general pass structure does not apply.
    void handle_special_cases_fft(const T* x, T* f) {
        if (nbr_bits == 2) {
            f[1] = x[0] - x[2];
            f[3] = x[1] - x[3];
            const T b_0 = x[0] + x[2];
            const T b_2 = x[1] + x[3];
            f[0] = b_0 + b_2;
            f[2] = b_0 - b_2;
        } else if (nbr_bits == 1) {
            f[0] = x[0] + x[1];
            f[1] = x[0] - x[1];
        } else {
            f[0] = x[0];
        }
    }

    // Inverse transform for N <= 4; every output sample is multiplied by mul.
    void handle_special_cases_ifft(const T* f, T* x, T mul) {
        if (nbr_bits == 2) {
            const T b_0 = f[0] + f[2];
            const T b_2 = f[0] - f[2];
            x[0] = (b_0 + f[1] * 2) * mul;
            x[2] = (b_0 - f[1] * 2) * mul;
            x[1] = (b_2 + f[3] * 2) * mul;
            x[3] = (b_2 - f[3] * 2) * mul;
        } else if (nbr_bits == 1) {
            x[0] = (f[0] + f[1]) * mul;
            x[1] = (f[0] - f[1]) * mul;
        } else {
            x[0] = f[0] * mul;
        }
    }

    BitReversedLUT bit_rev_lut;
    TrigoLUT trigo_lut;
};
// Example usage
int main() {
FFTReal<float, 8> fft;
std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<Complex<float>> output(4);
fft.real_fft(input.data(), output.data());
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment