Created
March 7, 2025 22:13
-
-
Save kevinpostal/5a26ced23e21ee41639b811dd2c6049a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <utility>
#include <vector>
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h> // Include NEON header
#endif
// Constants
// Many C libraries already expose M_PI as a preprocessor macro through
// <cmath>. Declaring a variable with that name would then expand to
// `constexpr double 3.14... = ...` and fail to compile, so only define it
// when the macro is absent. Either way, `M_PI` is usable as a double.
#ifndef M_PI
inline constexpr double M_PI = 3.14159265358979323846;
#endif
inline constexpr double SQ2_2 = 0.70710678118654752440; // sqrt(2)/2
// Complex type definition
/// Minimal complex-number value type: a real part `re` and an imaginary
/// part `im`, with just the arithmetic the FFT code needs.
template <typename T>
struct Complex {
    T re;
    T im;

    /// Builds `r + i*j`; both parts default to a value-initialised T.
    Complex(T r = T(), T i = T()) : re(r), im(i) {}

    /// Scales both parts by a real factor.
    Complex operator*(T scalar) const {
        return Complex(re * scalar, im * scalar);
    }

    /// Component-wise sum.
    Complex operator+(const Complex& other) const {
        return Complex(re + other.re, im + other.im);
    }

    /// Component-wise difference.
    Complex operator-(const Complex& other) const {
        return Complex(re - other.re, im - other.im);
    }
};
// Aligned allocator for SIMD compatibility
/// Stateless allocator that hands out storage aligned to at least
/// `Alignment` bytes, for use with standard containers.
///
/// @tparam T          element type
/// @tparam Alignment  requested alignment in bytes; must be a power of two
template <typename T, std::size_t Alignment>
class AlignedAllocator {
    static_assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0,
                  "Alignment must be a non-zero power of 2");

    // Never hand out storage that is under-aligned for T itself.
    static constexpr std::size_t kAlign =
        Alignment > alignof(T) ? Alignment : alignof(T);

public:
    using value_type = T;

    // allocator_traits can only rebind automatically when every template
    // parameter is a type; the non-type Alignment parameter makes an
    // explicit rebind mandatory for use in containers.
    template <typename U>
    struct rebind {
        using other = AlignedAllocator<U, Alignment>;
    };

    AlignedAllocator() noexcept = default;

    /// Converting constructor required of allocators (there is no state to copy).
    template <typename U>
    AlignedAllocator(const AlignedAllocator<U, Alignment>&) noexcept {}

    /// Allocates uninitialised storage for `n` objects of type T.
    /// @throws std::bad_alloc if the byte count overflows or the allocation fails.
    T* allocate(std::size_t n) {
        constexpr std::size_t kMax = std::numeric_limits<std::size_t>::max();
        if (n > (kMax - (kAlign - 1)) / sizeof(T)) {
            throw std::bad_alloc();
        }
        // std::aligned_alloc wants a size that is a multiple of the
        // alignment; round up, and request one full unit when n == 0.
        std::size_t bytes = (n * sizeof(T) + kAlign - 1) / kAlign * kAlign;
        if (bytes == 0) {
            bytes = kAlign;
        }
        void* raw = std::aligned_alloc(kAlign, bytes);
        if (raw == nullptr) {
            throw std::bad_alloc();
        }
        return static_cast<T*>(raw);
    }

    /// Releases storage previously obtained from allocate().
    void deallocate(T* p, std::size_t) noexcept {
        std::free(p);
    }
};

/// All instances are interchangeable: the allocator carries no state.
template <typename T, typename U, std::size_t Alignment>
bool operator==(const AlignedAllocator<T, Alignment>&,
                const AlignedAllocator<U, Alignment>&) noexcept {
    return true;
}

template <typename T, typename U, std::size_t Alignment>
bool operator!=(const AlignedAllocator<T, Alignment>&,
                const AlignedAllocator<U, Alignment>&) noexcept {
    return false;
}
// FFTReal class | |
template <typename T, int N> | |
class FFTReal { | |
static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2"); | |
using ComplexT = Complex<T>; | |
using AlignedVec = std::vector<T, AlignedAllocator<T, 64>>; | |
static constexpr unsigned floorlog2(unsigned x) { | |
return (x == 1) ? 0 : 1 + floorlog2(x >> 1); | |
} | |
static constexpr int nbr_bits = floorlog2(N); | |
static constexpr int N2 = N >> 1; | |
public: | |
FFTReal() : buffer_ptr(N * 2 + 2), yy(N * 2 + 2) { | |
std::fill(buffer_ptr.begin(), buffer_ptr.end(), T()); | |
std::fill(yy.begin(), yy.end(), T()); | |
} | |
// Real FFT: time domain -> frequency domain | |
void real_fft(const T* x, ComplexT* y, bool do_scale = false) { | |
if (!x || !y) throw std::invalid_argument("Null pointer passed to real_fft"); | |
T mul = do_scale ? T(1.0 / N) : T(1.0); | |
do_fft(x, yy.data()); | |
for (int i = 1; i < N2; ++i) { | |
y[i] = ComplexT(yy[i], yy[i + N2]) * mul; | |
} | |
y[0] = ComplexT(yy[0], T()) * mul; | |
} | |
// Inverse Real FFT: frequency domain -> time domain | |
void real_ifft(const ComplexT* x, T* y, bool do_scale = false) { | |
if (!x || !y) throw std::invalid_argument("Null pointer passed to real_ifft"); | |
for (int i = 1; i < N2; ++i) { | |
yy[i] = x[i].re; | |
yy[i + N2] = x[i].im; | |
} | |
yy[0] = x[0].re; | |
yy[N2] = T(); | |
do_ifft(yy.data(), y, do_scale); | |
} | |
private: | |
AlignedVec buffer_ptr; // Working buffer | |
AlignedVec yy; // Intermediate array | |
void do_fft(const T* x, T* f) { | |
if (nbr_bits > 2) { | |
T* sf = (nbr_bits & 1) ? f : buffer_ptr.data(); | |
T* df = (nbr_bits & 1) ? buffer_ptr.data() : f; | |
// First and second passes combined | |
const int* lut = bit_rev_lut.get_ptr(); | |
for (int i = 0; i < N; i += 4) { | |
T x0 = x[lut[i]], x1 = x[lut[i + 1]], x2 = x[lut[i + 2]], x3 = x[lut[i + 3]]; | |
df[i] = x0 + x1 + x2 + x3; | |
df[i + 1] = x0 - x1; | |
df[i + 2] = x0 + x1 - x2 - x3; | |
df[i + 3] = x2 - x3; | |
} | |
// Third pass with NEON optimizations | |
for (int i = 0; i < N; i += 8) { | |
float32x4_t df0_3 = vld1q_f32(df + i); | |
float32x4_t df4_7 = vld1q_f32(df + i + 4); | |
float32x4_t sum = vaddq_f32(df0_3, df4_7); | |
float32x4_t diff = vsubq_f32(df0_3, df4_7); | |
vst1q_f32(sf + i, sum); | |
vst1q_f32(sf + i + 4, diff); | |
// Handle elements 2 and 6 (copy) | |
sf[i + 2] = df[i + 2]; | |
sf[i + 6] = df[i + 6]; | |
// Process elements 1, 3, 5, 7 using NEON | |
float32x2_t df5_7 = vld1_f32(df + i + 5); // Load df[5], df[7] | |
float32x2_t d_rev = vrev64_f32(df5_7); | |
float32x2_t d_sub = vsub_f32(df5_7, d_rev); | |
d_sub = vmul_n_f32(d_sub, SQ2_2); | |
float32x2_t df1 = vld1_dup_f32(df + i + 1); | |
float32x2_t sf1 = vadd_f32(df1, d_sub); | |
float32x2_t sf3 = vsub_f32(df1, d_sub); | |
vst1_lane_f32(sf + i + 1, sf1, 0); | |
vst1_lane_f32(sf + i + 3, sf3, 0); | |
float32x2_t d_add = vadd_f32(df5_7, d_rev); | |
d_add = vmul_n_f32(d_add, SQ2_2); | |
float32x2_t df3 = vld1_dup_f32(df + i + 3); | |
float32x2_t sf5 = vadd_f32(d_add, df3); | |
float32x2_t sf7 = vsub_f32(d_add, df3); | |
vst1_lane_f32(sf + i + 5, sf5, 0); | |
vst1_lane_f32(sf + i + 7, sf7, 0); | |
} | |
// Subsequent passes with NEON optimizations | |
for (int pass = 3; pass < nbr_bits; ++pass) { | |
int nbr_coef = 1 << pass; | |
int h_nbr_coef = nbr_coef >> 1; | |
int d_nbr_coef = nbr_coef << 1; | |
const T* cos_ptr = trigo_lut.get_ptr(pass); | |
for (int i = 0; i < N; i += d_nbr_coef) { | |
T* sf1r = sf + i; | |
T* sf2r = sf1r + nbr_coef; | |
T* dfr = df + i; | |
T* dfi = dfr + nbr_coef; | |
dfr[0] = sf1r[0] + sf2r[0]; | |
dfi[0] = sf1r[0] - sf2r[0]; | |
dfr[h_nbr_coef] = sf1r[h_nbr_coef]; | |
dfi[h_nbr_coef] = sf2r[h_nbr_coef]; | |
for (int j = 1; j < h_nbr_coef; j += 2) { | |
// Load two consecutive c and s values | |
float32x2_t c = vld1_f32(cos_ptr + j); | |
float32x2_t s = vld1_f32(cos_ptr + h_nbr_coef - j - 1); | |
// Load sf2r[j] and sf2r[j + h_nbr_coef] for two j's | |
float32x2_t sf2r_j = {sf2r[j], sf2r[j + 1]}; | |
float32x2_t sf2r_jh = {sf2r[j + h_nbr_coef], sf2r[j + h_nbr_coef + 1]}; | |
// Compute v = sf2r[j] * c - sf2r_jh * s | |
float32x2_t v_re = vsub_f32(vmul_f32(sf2r_j, c), vmul_f32(sf2r_jh, s)); | |
// Load sf1r[j] and add/subtract v_re | |
float32x2_t sf1r_j = {sf1r[j], sf1r[j + 1]}; | |
float32x2_t dfr_j = vadd_f32(sf1r_j, v_re); | |
float32x2_t dfi_negj = vsub_f32(sf1r_j, v_re); | |
vst1_f32(dfr + j, dfr_j); | |
vst1_f32(dfi - j, dfi_negj); | |
// Compute v = sf2r[j] * s + sf2r_jh * c | |
float32x2_t v_im = vadd_f32(vmul_f32(sf2r_j, s), vmul_f32(sf2r_jh, c)); | |
// Load sf1r[j + h_nbr_coef] and add/subtract v_im | |
float32x2_t sf1r_jh = {sf1r[j + h_nbr_coef], sf1r[j + h_nbr_coef + 1]}; | |
float32x2_t dfi_j = vadd_f32(v_im, sf1r_jh); | |
float32x2_t dfi_nbrj = vsub_f32(v_im, sf1r_jh); | |
vst1_f32(dfi + j, dfi_j); | |
vst1_f32(dfi + nbr_coef - j - 1, dfi_nbrj); | |
} | |
} | |
std::swap(sf, df); | |
} | |
} else { | |
handle_special_cases_fft(x, f); | |
} | |
} | |
void do_ifft(const T* f, T* x, bool do_scale) { | |
T mul = do_scale ? T(1.0 / N) : T(1.0); | |
if (nbr_bits > 2) { | |
T* sf = const_cast<T*>(f); | |
T* df = (nbr_bits & 1) ? buffer_ptr.data() : x; | |
T* df_temp = (nbr_bits & 1) ? x : buffer_ptr.data(); | |
for (int pass = nbr_bits - 1; pass >= 3; --pass) { | |
int nbr_coef = 1 << pass; | |
int h_nbr_coef = nbr_coef >> 1; | |
int d_nbr_coef = nbr_coef << 1; | |
const T* cos_ptr = trigo_lut.get_ptr(pass); | |
for (int i = 0; i < N; i += d_nbr_coef) { | |
T* sfr = sf + i; | |
T* sfi = sfr + nbr_coef; | |
T* df1r = df + i; | |
T* df2r = df1r + nbr_coef; | |
df1r[0] = sfr[0] + sfr[nbr_coef]; | |
df2r[0] = sfr[0] - sfr[nbr_coef]; | |
df1r[h_nbr_coef] = sfr[h_nbr_coef] * 2; | |
df2r[h_nbr_coef] = sfi[h_nbr_coef] * 2; | |
for (int j = 1; j < h_nbr_coef; j += 2) { | |
float32x2_t c = vld1_f32(cos_ptr + j); | |
float32x2_t s = vld1_f32(cos_ptr + h_nbr_coef - j - 1); | |
float32x2_t sfr_j = {sfr[j], sfr[j + 1]}; | |
float32x2_t sfi_negj = {sfi[-j], sfi[-j - 1]}; | |
float32x2_t sfi_j = {sfi[j], sfi[j + 1]}; | |
float32x2_t sfi_nbrj = {sfi[nbr_coef - j], sfi[nbr_coef - j - 1]}; | |
// Compute df1r[j] and df1r[j + h_nbr_coef] | |
float32x2_t df1r_j = vadd_f32(sfr_j, sfi_negj); | |
float32x2_t df1r_jh = vsub_f32(sfi_j, sfi_nbrj); | |
vst1_f32(df1r + j, df1r_j); | |
vst1_f32(df1r + j + h_nbr_coef, df1r_jh); | |
// Compute vr and vi | |
float32x2_t vr = vsub_f32(sfr_j, sfi_negj); | |
float32x2_t vi = vadd_f32(sfi_j, sfi_nbrj); | |
// Compute df2r[j] and df2r[j + h_nbr_coef] | |
float32x2_t df2r_j = vadd_f32(vmul_f32(vr, c), vmul_f32(vi, s)); | |
float32x2_t df2r_jh = vsub_f32(vmul_f32(vi, c), vmul_f32(vr, s)); | |
vst1_f32(df2r + j, df2r_j); | |
vst1_f32(df2r + j + h_nbr_coef, df2r_jh); | |
} | |
} | |
if (pass < nbr_bits - 1) std::swap(df, sf); | |
else { sf = df; df = df_temp; } | |
} | |
// Antepenultimate pass with NEON | |
for (int i = 0; i < N; i += 8) { | |
float32x4_t sf2 = vld1q_f32(sf + i); | |
float32x4_t df2 = vld1q_f32(df + i); | |
// Process using NEON intrinsics... | |
// (Implementation depends on specific operations, similar to do_fft) | |
} | |
// Final passes with NEON | |
const int* lut = bit_rev_lut.get_ptr(); | |
for (int i = 0; i < N; i += 8) { | |
float32x4_t sf2 = vld1q_f32(df + i); | |
// Process using NEON intrinsics... | |
} | |
} else { | |
handle_special_cases_ifft(f, x, mul); | |
} | |
} | |
// Bit-reversed LUT | |
class BitReversedLUT { | |
public: | |
BitReversedLUT() : lut(N) { | |
int br_index = 0; | |
lut[0] = 0; | |
for (int cnt = 1; cnt < N; ++cnt) { | |
int bit = N >> 1; | |
while (((br_index ^= bit) & bit) == 0) bit >>= 1; | |
lut[cnt] = br_index; | |
} | |
} | |
const int* get_ptr() const { return lut.data(); } | |
private: | |
std::vector<int> lut; | |
}; | |
// Trigonometric LUT | |
class TrigoLUT { | |
public: | |
TrigoLUT() { | |
if (nbr_bits > 3) { | |
int total_len = (1 << (nbr_bits - 1)) - 4; | |
lut.resize(total_len); | |
int pos = 0; | |
for (int level = 3; level < nbr_bits; ++level) { | |
int level_len = 1 << (level - 1); | |
T mul = M_PI / (level_len * 2); | |
for (int i = 0; i < level_len; ++i) { | |
lut[pos++] = std::cos(i * mul); | |
} | |
} | |
} | |
} | |
const T* get_ptr(int level) const { | |
return lut.data() + (1 << (level - 1)) - 4; | |
} | |
private: | |
std::vector<T> lut; | |
}; | |
void handle_special_cases_fft(const T* x, T* f) { | |
if (nbr_bits == 2) { | |
f[1] = x[0] - x[2]; | |
f[3] = x[1] - x[3]; | |
T b_0 = x[0] + x[2], b_2 = x[1] + x[3]; | |
f[0] = b_0 + b_2; | |
f[2] = b_0 - b_2; | |
} else if (nbr_bits == 1) { | |
f[0] = x[0] + x[1]; | |
f[1] = x[0] - x[1]; | |
} else { | |
f[0] = x[0]; | |
} | |
} | |
void handle_special_cases_ifft(const T* f, T* x, T mul) { | |
if (nbr_bits == 2) { | |
T b_0 = f[0] + f[2], b_2 = f[0] - f[2]; | |
x[0] = (b_0 + f[1] * 2) * mul; | |
x[2] = (b_0 - f[1] * 2) * mul; | |
x[1] = (b_2 + f[3] * 2) * mul; | |
x[3] = (b_2 - f[3] * 2) * mul; | |
} else if (nbr_bits == 1) { | |
x[0] = (f[0] + f[1]) * mul; | |
x[1] = (f[0] - f[1]) * mul; | |
} else { | |
x[0] = f[0] * mul; | |
} | |
} | |
BitReversedLUT bit_rev_lut; | |
TrigoLUT trigo_lut; | |
}; | |
// Example usage | |
int main() { | |
FFTReal<float, 8> fft; | |
std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8}; | |
std::vector<Complex<float>> output(4); | |
fft.real_fft(input.data(), output.data()); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment