Skip to content

Instantly share code, notes, and snippets.

@kevinpostal
Created March 7, 2025 22:13
Show Gist options
  • Save kevinpostal/5a26ced23e21ee41639b811dd2c6049a to your computer and use it in GitHub Desktop.
Save kevinpostal/5a26ced23e21ee41639b811dd2c6049a to your computer and use it in GitHub Desktop.
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include <stdexcept>
#include <utility>
#include <vector>
// NEON intrinsics exist only when targeting ARM with NEON enabled.
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h> // Include NEON header
#endif
// Constants
// M_PI is not ISO C++, but POSIX <cmath> commonly defines it as a macro.
// Declaring a variable with that name would then expand to
// `constexpr double 3.14159... = ...` and fail to compile, so the fallback is
// only provided when no macro exists (e.g. MSVC without _USE_MATH_DEFINES).
#ifndef M_PI
inline constexpr double M_PI = 3.14159265358979323846;
#endif
inline constexpr double SQ2_2 = 0.70710678118654752440; // sqrt(2)/2
// Complex type definition
//
// Minimal complex value with public parts `re` (real) and `im` (imaginary).
// It supports scaling by a real factor and component-wise addition/subtraction.
template <typename T>
struct Complex {
    T re;  // real part
    T im;  // imaginary part

    // Both parts default to a value-initialised T (zero for arithmetic types).
    Complex(T real_part = T(), T imag_part = T()) : re(real_part), im(imag_part) {}

    // Multiply both parts by the same real factor.
    Complex operator*(T factor) const {
        return Complex(re * factor, im * factor);
    }

    // Component-wise sum.
    Complex operator+(const Complex& rhs) const {
        return Complex(re + rhs.re, im + rhs.im);
    }

    // Component-wise difference.
    Complex operator-(const Complex& rhs) const {
        return Complex(re - rhs.re, im - rhs.im);
    }
};
// Aligned allocator for SIMD compatibility
template <typename T, size_t Alignment>
class AlignedAllocator {
public:
using value_type = T;
T* allocate(size_t n) {
return static_cast<T*>(std::aligned_alloc(Alignment, n * sizeof(T)));
}
void deallocate(T* p, size_t) {
std::free(p);
}
};
// FFTReal class
template <typename T, int N>
class FFTReal {
static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of 2");
using ComplexT = Complex<T>;
using AlignedVec = std::vector<T, AlignedAllocator<T, 64>>;
static constexpr unsigned floorlog2(unsigned x) {
return (x == 1) ? 0 : 1 + floorlog2(x >> 1);
}
static constexpr int nbr_bits = floorlog2(N);
static constexpr int N2 = N >> 1;
public:
FFTReal() : buffer_ptr(N * 2 + 2), yy(N * 2 + 2) {
std::fill(buffer_ptr.begin(), buffer_ptr.end(), T());
std::fill(yy.begin(), yy.end(), T());
}
// Real FFT: time domain -> frequency domain
void real_fft(const T* x, ComplexT* y, bool do_scale = false) {
if (!x || !y) throw std::invalid_argument("Null pointer passed to real_fft");
T mul = do_scale ? T(1.0 / N) : T(1.0);
do_fft(x, yy.data());
for (int i = 1; i < N2; ++i) {
y[i] = ComplexT(yy[i], yy[i + N2]) * mul;
}
y[0] = ComplexT(yy[0], T()) * mul;
}
// Inverse Real FFT: frequency domain -> time domain
void real_ifft(const ComplexT* x, T* y, bool do_scale = false) {
if (!x || !y) throw std::invalid_argument("Null pointer passed to real_ifft");
for (int i = 1; i < N2; ++i) {
yy[i] = x[i].re;
yy[i + N2] = x[i].im;
}
yy[0] = x[0].re;
yy[N2] = T();
do_ifft(yy.data(), y, do_scale);
}
private:
AlignedVec buffer_ptr; // Working buffer
AlignedVec yy; // Intermediate array
void do_fft(const T* x, T* f) {
if (nbr_bits > 2) {
T* sf = (nbr_bits & 1) ? f : buffer_ptr.data();
T* df = (nbr_bits & 1) ? buffer_ptr.data() : f;
// First and second passes combined
const int* lut = bit_rev_lut.get_ptr();
for (int i = 0; i < N; i += 4) {
T x0 = x[lut[i]], x1 = x[lut[i + 1]], x2 = x[lut[i + 2]], x3 = x[lut[i + 3]];
df[i] = x0 + x1 + x2 + x3;
df[i + 1] = x0 - x1;
df[i + 2] = x0 + x1 - x2 - x3;
df[i + 3] = x2 - x3;
}
// Third pass with NEON optimizations
for (int i = 0; i < N; i += 8) {
float32x4_t df0_3 = vld1q_f32(df + i);
float32x4_t df4_7 = vld1q_f32(df + i + 4);
float32x4_t sum = vaddq_f32(df0_3, df4_7);
float32x4_t diff = vsubq_f32(df0_3, df4_7);
vst1q_f32(sf + i, sum);
vst1q_f32(sf + i + 4, diff);
// Handle elements 2 and 6 (copy)
sf[i + 2] = df[i + 2];
sf[i + 6] = df[i + 6];
// Process elements 1, 3, 5, 7 using NEON
float32x2_t df5_7 = vld1_f32(df + i + 5); // Load df[5], df[7]
float32x2_t d_rev = vrev64_f32(df5_7);
float32x2_t d_sub = vsub_f32(df5_7, d_rev);
d_sub = vmul_n_f32(d_sub, SQ2_2);
float32x2_t df1 = vld1_dup_f32(df + i + 1);
float32x2_t sf1 = vadd_f32(df1, d_sub);
float32x2_t sf3 = vsub_f32(df1, d_sub);
vst1_lane_f32(sf + i + 1, sf1, 0);
vst1_lane_f32(sf + i + 3, sf3, 0);
float32x2_t d_add = vadd_f32(df5_7, d_rev);
d_add = vmul_n_f32(d_add, SQ2_2);
float32x2_t df3 = vld1_dup_f32(df + i + 3);
float32x2_t sf5 = vadd_f32(d_add, df3);
float32x2_t sf7 = vsub_f32(d_add, df3);
vst1_lane_f32(sf + i + 5, sf5, 0);
vst1_lane_f32(sf + i + 7, sf7, 0);
}
// Subsequent passes with NEON optimizations
for (int pass = 3; pass < nbr_bits; ++pass) {
int nbr_coef = 1 << pass;
int h_nbr_coef = nbr_coef >> 1;
int d_nbr_coef = nbr_coef << 1;
const T* cos_ptr = trigo_lut.get_ptr(pass);
for (int i = 0; i < N; i += d_nbr_coef) {
T* sf1r = sf + i;
T* sf2r = sf1r + nbr_coef;
T* dfr = df + i;
T* dfi = dfr + nbr_coef;
dfr[0] = sf1r[0] + sf2r[0];
dfi[0] = sf1r[0] - sf2r[0];
dfr[h_nbr_coef] = sf1r[h_nbr_coef];
dfi[h_nbr_coef] = sf2r[h_nbr_coef];
for (int j = 1; j < h_nbr_coef; j += 2) {
// Load two consecutive c and s values
float32x2_t c = vld1_f32(cos_ptr + j);
float32x2_t s = vld1_f32(cos_ptr + h_nbr_coef - j - 1);
// Load sf2r[j] and sf2r[j + h_nbr_coef] for two j's
float32x2_t sf2r_j = {sf2r[j], sf2r[j + 1]};
float32x2_t sf2r_jh = {sf2r[j + h_nbr_coef], sf2r[j + h_nbr_coef + 1]};
// Compute v = sf2r[j] * c - sf2r_jh * s
float32x2_t v_re = vsub_f32(vmul_f32(sf2r_j, c), vmul_f32(sf2r_jh, s));
// Load sf1r[j] and add/subtract v_re
float32x2_t sf1r_j = {sf1r[j], sf1r[j + 1]};
float32x2_t dfr_j = vadd_f32(sf1r_j, v_re);
float32x2_t dfi_negj = vsub_f32(sf1r_j, v_re);
vst1_f32(dfr + j, dfr_j);
vst1_f32(dfi - j, dfi_negj);
// Compute v = sf2r[j] * s + sf2r_jh * c
float32x2_t v_im = vadd_f32(vmul_f32(sf2r_j, s), vmul_f32(sf2r_jh, c));
// Load sf1r[j + h_nbr_coef] and add/subtract v_im
float32x2_t sf1r_jh = {sf1r[j + h_nbr_coef], sf1r[j + h_nbr_coef + 1]};
float32x2_t dfi_j = vadd_f32(v_im, sf1r_jh);
float32x2_t dfi_nbrj = vsub_f32(v_im, sf1r_jh);
vst1_f32(dfi + j, dfi_j);
vst1_f32(dfi + nbr_coef - j - 1, dfi_nbrj);
}
}
std::swap(sf, df);
}
} else {
handle_special_cases_fft(x, f);
}
}
void do_ifft(const T* f, T* x, bool do_scale) {
T mul = do_scale ? T(1.0 / N) : T(1.0);
if (nbr_bits > 2) {
T* sf = const_cast<T*>(f);
T* df = (nbr_bits & 1) ? buffer_ptr.data() : x;
T* df_temp = (nbr_bits & 1) ? x : buffer_ptr.data();
for (int pass = nbr_bits - 1; pass >= 3; --pass) {
int nbr_coef = 1 << pass;
int h_nbr_coef = nbr_coef >> 1;
int d_nbr_coef = nbr_coef << 1;
const T* cos_ptr = trigo_lut.get_ptr(pass);
for (int i = 0; i < N; i += d_nbr_coef) {
T* sfr = sf + i;
T* sfi = sfr + nbr_coef;
T* df1r = df + i;
T* df2r = df1r + nbr_coef;
df1r[0] = sfr[0] + sfr[nbr_coef];
df2r[0] = sfr[0] - sfr[nbr_coef];
df1r[h_nbr_coef] = sfr[h_nbr_coef] * 2;
df2r[h_nbr_coef] = sfi[h_nbr_coef] * 2;
for (int j = 1; j < h_nbr_coef; j += 2) {
float32x2_t c = vld1_f32(cos_ptr + j);
float32x2_t s = vld1_f32(cos_ptr + h_nbr_coef - j - 1);
float32x2_t sfr_j = {sfr[j], sfr[j + 1]};
float32x2_t sfi_negj = {sfi[-j], sfi[-j - 1]};
float32x2_t sfi_j = {sfi[j], sfi[j + 1]};
float32x2_t sfi_nbrj = {sfi[nbr_coef - j], sfi[nbr_coef - j - 1]};
// Compute df1r[j] and df1r[j + h_nbr_coef]
float32x2_t df1r_j = vadd_f32(sfr_j, sfi_negj);
float32x2_t df1r_jh = vsub_f32(sfi_j, sfi_nbrj);
vst1_f32(df1r + j, df1r_j);
vst1_f32(df1r + j + h_nbr_coef, df1r_jh);
// Compute vr and vi
float32x2_t vr = vsub_f32(sfr_j, sfi_negj);
float32x2_t vi = vadd_f32(sfi_j, sfi_nbrj);
// Compute df2r[j] and df2r[j + h_nbr_coef]
float32x2_t df2r_j = vadd_f32(vmul_f32(vr, c), vmul_f32(vi, s));
float32x2_t df2r_jh = vsub_f32(vmul_f32(vi, c), vmul_f32(vr, s));
vst1_f32(df2r + j, df2r_j);
vst1_f32(df2r + j + h_nbr_coef, df2r_jh);
}
}
if (pass < nbr_bits - 1) std::swap(df, sf);
else { sf = df; df = df_temp; }
}
// Antepenultimate pass with NEON
for (int i = 0; i < N; i += 8) {
float32x4_t sf2 = vld1q_f32(sf + i);
float32x4_t df2 = vld1q_f32(df + i);
// Process using NEON intrinsics...
// (Implementation depends on specific operations, similar to do_fft)
}
// Final passes with NEON
const int* lut = bit_rev_lut.get_ptr();
for (int i = 0; i < N; i += 8) {
float32x4_t sf2 = vld1q_f32(df + i);
// Process using NEON intrinsics...
}
} else {
handle_special_cases_ifft(f, x, mul);
}
}
// Bit-reversed LUT
class BitReversedLUT {
public:
BitReversedLUT() : lut(N) {
int br_index = 0;
lut[0] = 0;
for (int cnt = 1; cnt < N; ++cnt) {
int bit = N >> 1;
while (((br_index ^= bit) & bit) == 0) bit >>= 1;
lut[cnt] = br_index;
}
}
const int* get_ptr() const { return lut.data(); }
private:
std::vector<int> lut;
};
// Trigonometric LUT
class TrigoLUT {
public:
TrigoLUT() {
if (nbr_bits > 3) {
int total_len = (1 << (nbr_bits - 1)) - 4;
lut.resize(total_len);
int pos = 0;
for (int level = 3; level < nbr_bits; ++level) {
int level_len = 1 << (level - 1);
T mul = M_PI / (level_len * 2);
for (int i = 0; i < level_len; ++i) {
lut[pos++] = std::cos(i * mul);
}
}
}
}
const T* get_ptr(int level) const {
return lut.data() + (1 << (level - 1)) - 4;
}
private:
std::vector<T> lut;
};
void handle_special_cases_fft(const T* x, T* f) {
if (nbr_bits == 2) {
f[1] = x[0] - x[2];
f[3] = x[1] - x[3];
T b_0 = x[0] + x[2], b_2 = x[1] + x[3];
f[0] = b_0 + b_2;
f[2] = b_0 - b_2;
} else if (nbr_bits == 1) {
f[0] = x[0] + x[1];
f[1] = x[0] - x[1];
} else {
f[0] = x[0];
}
}
void handle_special_cases_ifft(const T* f, T* x, T mul) {
if (nbr_bits == 2) {
T b_0 = f[0] + f[2], b_2 = f[0] - f[2];
x[0] = (b_0 + f[1] * 2) * mul;
x[2] = (b_0 - f[1] * 2) * mul;
x[1] = (b_2 + f[3] * 2) * mul;
x[3] = (b_2 - f[3] * 2) * mul;
} else if (nbr_bits == 1) {
x[0] = (f[0] + f[1]) * mul;
x[1] = (f[0] - f[1]) * mul;
} else {
x[0] = f[0] * mul;
}
}
BitReversedLUT bit_rev_lut;
TrigoLUT trigo_lut;
};
// Example usage
int main() {
FFTReal<float, 8> fft;
std::vector<float> input = {1, 2, 3, 4, 5, 6, 7, 8};
std::vector<Complex<float>> output(4);
fft.real_fft(input.data(), output.data());
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment