// Last active
// March 7, 2025 22:01
// -
// -
// Save kevinpostal/008695a8ba35c18949043745352e7bf1 to your computer and use it in GitHub Desktop.
// This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters
#pragma once
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
// Optional NEON support.
// USE_NEON is switched on automatically only when the compiler reports a NEON
// capable ARM target; forcing it everywhere made the <arm_neon.h> include fail
// on every other architecture. Defining USE_NEON on the command line still works.
#if !defined(USE_NEON) && (defined(__ARM_NEON) || defined(__ARM_NEON__))
#define USE_NEON
#endif
#ifdef USE_NEON
#include <arm_neon.h>
#endif
// Constants (replace const1.h dependency)
// <cmath> already provides M_PI as a preprocessor macro on many toolchains
// (glibc, macOS, MSVC with _USE_MATH_DEFINES). Declaring a variable with that
// name would then expand to "constexpr double 3.14... = ..." and fail to
// compile, so the variable is only declared when no such macro exists.
#ifndef M_PI
constexpr double M_PI = 3.14159265358979323846;
#endif
constexpr double SQ2_2 = 0.70710678118654752440;  // sqrt(2)/2
// Complex type definition (replace cmplxT<T>)
/// Minimal complex value: a real/imaginary pair with the few arithmetic
/// operations the FFT wrapper below needs.
template <typename T>
struct Complex {
    T re;  ///< Real part.
    T im;  ///< Imaginary part.

    /// Builds re + i*im; both parts default to T(). Left implicit on purpose so
    /// a plain scalar keeps converting to a purely real value.
    constexpr Complex(T r = T(), T i = T()) : re(r), im(i) {}

    /// Scales both parts by a real factor.
    constexpr Complex operator*(T scalar) const { return {re * scalar, im * scalar}; }
    /// Component-wise sum.
    constexpr Complex operator+(const Complex& other) const { return {re + other.re, im + other.im}; }
    /// Component-wise difference.
    constexpr Complex operator-(const Complex& other) const { return {re - other.re, im - other.im}; }
};

namespace fft_real_detail {

/// floor(log2(x)) for x >= 1 (returns 0 for x == 0 instead of recursing forever).
/// Kept outside FFTReal: a static constexpr member function cannot be called
/// from a static data member initializer of the same, still incomplete, class.
constexpr int floor_log2(unsigned x) {
    return x <= 1u ? 0 : 1 + floor_log2(x >> 1);
}

/// Stateless allocator returning storage aligned to Alignment bytes through the
/// C++17 aligned operator new. Stands in for the non-existent
/// std::aligned_allocator the buffers were declared with.
template <typename U, std::size_t Alignment>
struct AlignedAllocator {
    using value_type = U;

    // Explicit rebind: allocator_traits cannot deduce it when the template has
    // a non-type parameter.
    template <typename V>
    struct rebind {
        using other = AlignedAllocator<V, Alignment>;
    };

    AlignedAllocator() noexcept = default;
    template <typename V>
    AlignedAllocator(const AlignedAllocator<V, Alignment>&) noexcept {}

    /// @throws std::bad_alloc if count * sizeof(U) overflows or allocation fails.
    U* allocate(std::size_t count) {
        if (count > std::numeric_limits<std::size_t>::max() / sizeof(U)) {
            throw std::bad_alloc();
        }
        return static_cast<U*>(
            ::operator new(count * sizeof(U), static_cast<std::align_val_t>(Alignment)));
    }

    void deallocate(U* ptr, std::size_t) noexcept {
        ::operator delete(ptr, static_cast<std::align_val_t>(Alignment));
    }
};

template <typename U, typename V, std::size_t A>
constexpr bool operator==(const AlignedAllocator<U, A>&, const AlignedAllocator<V, A>&) noexcept {
    return true;
}

template <typename U, typename V, std::size_t A>
constexpr bool operator!=(const AlignedAllocator<U, A>&, const AlignedAllocator<V, A>&) noexcept {
    return false;
}

}  // namespace fft_real_detail

// FFTReal class
/// Fixed-size FFT of N real samples (N a power of two).
///
/// The internal spectrum layout is an array of N values: indices 0..N/2 hold
/// real parts of bins 0..N/2, indices N/2+1..N-1 hold imaginary parts of bins
/// 1..N/2-1. The public API exposes bins 0..N/2-1 as Complex values.
///
/// Every call overwrites the internal work buffers, so a single instance must
/// not be used from several threads at once.
template <typename T, int N>
class FFTReal {
    static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2");

    using ComplexT = Complex<T>;
    using Buffer = std::vector<T, fft_real_detail::AlignedAllocator<T, 64>>;

    static constexpr int nbr_bits = fft_real_detail::floor_log2(static_cast<unsigned>(N));
    static constexpr int N2 = N >> 1;
    // Private copies so the class does not depend on file-level constants
    // (a global named M_PI clashes with the <cmath> macro on many platforms).
    static constexpr double kPi = 3.14159265358979323846;
    static constexpr double kSqrt2Over2 = 0.70710678118654752440;  // sqrt(2)/2
    // Same length as before (2N + 2); the scalar passes only touch the first N
    // elements.
    static constexpr std::size_t kWorkLen = static_cast<std::size_t>(N) * 2 + 2;

public:
    /// Allocates the zero-initialised work buffers and builds the lookup tables.
    FFTReal() : buffer_ptr(kWorkLen), yy(kWorkLen) {}

    /// Real FFT: time domain -> frequency domain.
    ///
    /// @param x        N input samples.
    /// @param y        Receives N/2 bins (at least one): y[0] is the DC term
    ///                 with zero imaginary part, y[1..N/2-1] the following bins.
    ///                 The Nyquist bin is computed internally but not written,
    ///                 so y never needs more than N/2 elements.
    /// @param do_scale When true every bin is multiplied by 1/N.
    /// @throws std::invalid_argument if x or y is null.
    ///
    /// NOTE(review): for a unit impulse at n = 1 the imaginary parts come out
    /// as +sin, i.e. the sign looks opposite to the usual e^{-i...} DFT
    /// convention - confirm before mixing with other FFT libraries.
    void real_fft(const T* x, ComplexT* y, bool do_scale = false) {
        if (!x || !y) throw std::invalid_argument("Null pointer passed to real_fft");
        const T mul = do_scale ? T(1.0 / N) : T(1.0);
#ifdef USE_NEON
        if constexpr (std::is_same_v<T, double>) {
            do_fft_neon_d8(x, yy.data());
        } else if constexpr (std::is_same_v<T, float>) {
            do_fft_neon_f8(x, yy.data());
        } else {
            do_fft(x, yy.data());
        }
#else
        do_fft(x, yy.data());
#endif
        y[0] = ComplexT(yy[0], T()) * mul;
        for (int i = 1; i < N2; ++i) {
            y[i] = ComplexT(yy[i], yy[i + N2]) * mul;
        }
    }

    /// Inverse real FFT: frequency domain -> time domain.
    ///
    /// @param x        N/2 bins in the layout produced by real_fft; only the
    ///                 real part of x[0] is read. The Nyquist bin is taken as 0.
    /// @param y        Receives N samples.
    /// @param do_scale When true the output is multiplied by 1/N, which makes
    ///                 real_fft followed by real_ifft return the input (for
    ///                 signals without Nyquist content). Without it the round
    ///                 trip is N times the input.
    /// @throws std::invalid_argument if x or y is null.
    void real_ifft(const ComplexT* x, T* y, bool do_scale = false) {
        if (!x || !y) throw std::invalid_argument("Null pointer passed to real_ifft");
        yy[0] = x[0].re;
        // For N == 1 index N2 is 0, so clearing it would wipe the only bin.
        if constexpr (N2 > 0) {
            yy[N2] = T();
        }
        for (int i = 1; i < N2; ++i) {
            yy[i] = x[i].re;
            yy[i + N2] = x[i].im;
        }
#ifdef USE_NEON
        // do_scale is forwarded; it used to be silently dropped on these paths.
        if constexpr (std::is_same_v<T, double>) {
            do_ifft_neon_d8(yy.data(), y, do_scale);
        } else if constexpr (std::is_same_v<T, float>) {
            do_ifft_neon_f8(yy.data(), y, do_scale);
        } else {
            do_ifft(yy.data(), y, do_scale);
        }
#else
        do_ifft(yy.data(), y, do_scale);
#endif
    }

private:
    /// Table mapping natural index -> bit-reversed index for N entries.
    class BitReversedLUT {
    public:
        BitReversedLUT() : lut(N) {
            int br_index = 0;
            lut[0] = 0;
            for (int cnt = 1; cnt < N; ++cnt) {
                // Bit-reversed increment: flip bits from the top until one
                // turns from 0 to 1.
                int bit = N >> 1;
                while (((br_index ^= bit) & bit) == 0) bit >>= 1;
                lut[cnt] = br_index;
            }
        }
        const int* get_ptr() const { return lut.data(); }

    private:
        std::vector<int> lut;
    };

    /// Cosine tables for passes 3..nbr_bits-1, stored back to back. The table
    /// of a given level holds cos(i * pi / 2^level) for i in [0, 2^(level-1)).
    class TrigoLUT {
    public:
        TrigoLUT() {
            if constexpr (nbr_bits > 3) {
                lut.resize((std::size_t{1} << (nbr_bits - 1)) - 4);
                std::size_t pos = 0;
                for (int level = 3; level < nbr_bits; ++level) {
                    const int level_len = 1 << (level - 1);
                    // Angles are evaluated in double even for T = float so the
                    // twiddle factors are rounded only once.
                    const double step = kPi / (level_len * 2);
                    for (int i = 0; i < level_len; ++i) {
                        lut[pos++] = static_cast<T>(std::cos(i * step));
                    }
                }
            }
        }
        /// Start of the table for the given level (level >= 3, nbr_bits > 3).
        const T* get_ptr(int level) const {
            return lut.data() + ((1 << (level - 1)) - 4);
        }

    private:
        std::vector<T> lut;
    };

    /// Forward transform of N samples x into the internal spectrum layout f.
    void do_fft(const T* x, T* f) {
        if constexpr (nbr_bits > 2) {
            // Two buffers alternate as source and destination; the starting
            // roles follow the parity of nbr_bits so the last pass lands in f.
            T* sf = (nbr_bits & 1) ? f : buffer_ptr.data();
            T* df = (nbr_bits & 1) ? buffer_ptr.data() : f;

            // First and second passes combined (bit-reversed gather).
            const int* lut = bit_rev_lut.get_ptr();
            for (int i = 0; i < N; i += 4) {
                const T x0 = x[lut[i]];
                const T x1 = x[lut[i + 1]];
                const T x2 = x[lut[i + 2]];
                const T x3 = x[lut[i + 3]];
                df[i] = x0 + x1 + x2 + x3;
                df[i + 1] = x0 - x1;
                df[i + 2] = x0 + x1 - x2 - x3;
                df[i + 3] = x2 - x3;
            }

            // Third pass: the only twiddle factor needed is sqrt(2)/2.
            for (int i = 0; i < N; i += 8) {
                sf[i] = df[i] + df[i + 4];
                sf[i + 4] = df[i] - df[i + 4];
                sf[i + 2] = df[i + 2];
                sf[i + 6] = df[i + 6];
                T v = static_cast<T>((df[i + 5] - df[i + 7]) * kSqrt2Over2);
                sf[i + 1] = df[i + 1] + v;
                sf[i + 3] = df[i + 1] - v;
                v = static_cast<T>((df[i + 5] + df[i + 7]) * kSqrt2Over2);
                sf[i + 5] = v + df[i + 3];
                sf[i + 7] = v - df[i + 3];
            }

            // Subsequent passes.
            for (int pass = 3; pass < nbr_bits; ++pass) {
                const int nbr_coef = 1 << pass;
                const int h_nbr_coef = nbr_coef >> 1;
                const int d_nbr_coef = nbr_coef << 1;
                const T* cos_ptr = trigo_lut.get_ptr(pass);
                for (int i = 0; i < N; i += d_nbr_coef) {
                    const T* sf1r = sf + i;
                    const T* sf2r = sf1r + nbr_coef;
                    T* dfr = df + i;
                    T* dfi = dfr + nbr_coef;

                    // Extreme coefficients are purely real.
                    dfr[0] = sf1r[0] + sf2r[0];
                    dfi[0] = sf1r[0] - sf2r[0];
                    dfr[h_nbr_coef] = sf1r[h_nbr_coef];
                    dfi[h_nbr_coef] = sf2r[h_nbr_coef];

                    for (int j = 1; j < h_nbr_coef; ++j) {
                        const T c = cos_ptr[j];
                        const T s = cos_ptr[h_nbr_coef - j];  // sine via the mirrored cosine
                        T v = sf2r[j] * c - sf2r[j + h_nbr_coef] * s;
                        dfr[j] = sf1r[j] + v;
                        dfi[-j] = sf1r[j] - v;  // i.e. dfr[nbr_coef - j]
                        v = sf2r[j] * s + sf2r[j + h_nbr_coef] * c;
                        dfi[j] = v + sf1r[j + h_nbr_coef];
                        dfi[nbr_coef - j] = v - sf1r[j + h_nbr_coef];
                    }
                }
                std::swap(sf, df);
            }
        } else {
            handle_special_cases_fft(x, f);
        }
    }

    /// Inverse transform of the internal spectrum layout f into N samples x.
    /// f is only read. Output is multiplied by 1/N when do_scale is true.
    void do_ifft(const T* f, T* x, bool do_scale) {
        const T mul = do_scale ? T(1.0 / N) : T(1.0);
        if constexpr (nbr_bits > 2) {
            // src starts at the caller's spectrum; afterwards the two work
            // areas (x and buffer_ptr) alternate. The starting roles follow the
            // parity of nbr_bits so the last stage reads buffer_ptr, not x.
            const T* src = f;
            T* dst = (nbr_bits & 1) ? buffer_ptr.data() : x;
            T* spare = (nbr_bits & 1) ? x : buffer_ptr.data();

            for (int pass = nbr_bits - 1; pass >= 3; --pass) {
                const int nbr_coef = 1 << pass;
                const int h_nbr_coef = nbr_coef >> 1;
                const int d_nbr_coef = nbr_coef << 1;
                const T* cos_ptr = trigo_lut.get_ptr(pass);
                for (int i = 0; i < N; i += d_nbr_coef) {
                    const T* sfr = src + i;
                    const T* sfi = sfr + nbr_coef;
                    T* df1r = dst + i;
                    T* df2r = df1r + nbr_coef;

                    // Extreme coefficients are purely real.
                    df1r[0] = sfr[0] + sfi[0];
                    df2r[0] = sfr[0] - sfi[0];
                    df1r[h_nbr_coef] = sfr[h_nbr_coef] * 2;
                    df2r[h_nbr_coef] = sfi[h_nbr_coef] * 2;

                    for (int j = 1; j < h_nbr_coef; ++j) {
                        const T c = cos_ptr[j];
                        const T s = cos_ptr[h_nbr_coef - j];
                        df1r[j] = sfr[j] + sfi[-j];
                        df1r[j + h_nbr_coef] = sfi[j] - sfi[nbr_coef - j];
                        const T vr = sfr[j] - sfi[-j];
                        const T vi = sfi[j] + sfi[nbr_coef - j];
                        df2r[j] = vr * c + vi * s;
                        df2r[j + h_nbr_coef] = vi * c - vr * s;
                    }
                }
                // What was just written becomes the next source; the other
                // work area becomes the next destination.
                T* const written = dst;
                dst = spare;
                spare = written;
                src = written;
            }

            // Antepenultimate pass.
            for (int i = 0; i < N; i += 8) {
                const T* in = src + i;
                T* out = dst + i;
                const T vr = in[1] - in[3];
                const T vi = in[5] + in[7];
                out[0] = in[0] + in[4];
                out[1] = in[1] + in[3];
                out[2] = in[2] * 2;
                out[3] = in[5] - in[7];
                out[4] = in[0] - in[4];
                out[5] = static_cast<T>((vr + vi) * kSqrt2Over2);
                out[6] = in[6] * 2;
                out[7] = static_cast<T>((vi - vr) * kSqrt2Over2);
            }

            // Final two passes fused, scattering through the bit-reversal table.
            const int* lut = bit_rev_lut.get_ptr();
            for (int i = 0; i < N; i += 8) {
                const T* in = dst + i;
                T b_0 = in[0] + in[2];
                T b_1 = in[1] * 2;
                T b_2 = in[0] - in[2];
                T b_3 = in[3] * 2;
                x[lut[i]] = (b_0 + b_1) * mul;
                x[lut[i + 1]] = (b_0 - b_1) * mul;
                x[lut[i + 2]] = (b_2 + b_3) * mul;
                x[lut[i + 3]] = (b_2 - b_3) * mul;
                b_0 = in[4] + in[6];
                b_1 = in[5] * 2;
                b_2 = in[4] - in[6];
                b_3 = in[7] * 2;
                x[lut[i + 4]] = (b_0 + b_1) * mul;
                x[lut[i + 5]] = (b_0 - b_1) * mul;
                x[lut[i + 6]] = (b_2 + b_3) * mul;
                x[lut[i + 7]] = (b_2 - b_3) * mul;
            }
        } else {
            handle_special_cases_ifft(f, x, mul);
        }
    }

#ifdef USE_NEON
    // NEON entry points. They currently forward to the scalar code; replace the
    // bodies with intrinsics when a vectorised version is available.
    void do_fft_neon_d8(const double* x, double* f) { do_fft(x, f); }
    void do_fft_neon_f8(const float* x, float* f) { do_fft(x, f); }
    void do_ifft_neon_d8(const double* f, double* x, bool do_scale) { do_ifft(f, x, do_scale); }
    void do_ifft_neon_f8(const float* f, float* x, bool do_scale) { do_ifft(f, x, do_scale); }
#endif

    /// Forward transform for N = 4, 2 and 1.
    void handle_special_cases_fft(const T* x, T* f) {
        if constexpr (nbr_bits == 2) {
            f[1] = x[0] - x[2];
            f[3] = x[1] - x[3];
            const T b_0 = x[0] + x[2];
            const T b_2 = x[1] + x[3];
            f[0] = b_0 + b_2;
            f[2] = b_0 - b_2;
        } else if constexpr (nbr_bits == 1) {
            f[0] = x[0] + x[1];
            f[1] = x[0] - x[1];
        } else {
            f[0] = x[0];
        }
    }

    /// Inverse transform for N = 4, 2 and 1; mul is the output scale factor.
    void handle_special_cases_ifft(const T* f, T* x, T mul) {
        if constexpr (nbr_bits == 2) {
            const T b_0 = f[0] + f[2];
            const T b_2 = f[0] - f[2];
            x[0] = (b_0 + f[1] * 2) * mul;
            x[2] = (b_0 - f[1] * 2) * mul;
            x[1] = (b_2 + f[3] * 2) * mul;
            x[3] = (b_2 - f[3] * 2) * mul;
        } else if constexpr (nbr_bits == 1) {
            x[0] = (f[0] + f[1]) * mul;
            x[1] = (f[0] - f[1]) * mul;
        } else {
            x[0] = f[0] * mul;
        }
    }

    Buffer buffer_ptr;  // Working buffer
    Buffer yy;          // Intermediate array (internal spectrum layout)
    BitReversedLUT bit_rev_lut;
    TrigoLUT trigo_lut;
};
// Example usage | |
int main() { | |
FFTReal<float, 8> fft; | |
std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8}; | |
std::vector<Complex<float>> output(4); | |
fft.real_fft(input.data(), output.data()); | |
return 0; | |
} |
// Sign up for free
// to join this conversation on GitHub.
// Already have an account?
// Sign in to comment