Skip to content

Instantly share code, notes, and snippets.

@youdie323323
Last active September 11, 2025 20:18
Show Gist options
  • Select an option

  • Save youdie323323/d49e76cdfff8b4d47ee917ed34252c6f to your computer and use it in GitHub Desktop.

Select an option

Save youdie323323/d49e76cdfff8b4d47ee917ed34252c6f to your computer and use it in GitHub Desktop.
The "Twiddless" Fast Fourier Transform Algorithm in Zig
100000 iterations for every measurement
For N=4:
Total call time: 2772800 ns
Average time per call: 0.000 ms (27 ns)
For N=16:
Total call time: 11139000 ns
Average time per call: 0.000 ms (111 ns)
For N=64:
Total call time: 46906900 ns
Average time per call: 0.000 ms (469 ns)
For N=256:
Total call time: 224283200 ns
Average time per call: 0.002 ms (2242 ns)
For N=1024:
Total call time: 1074442100 ns
Average time per call: 0.011 ms (10744 ns)
const std = @import("std");
const math = std.math;
const complex = math.complex;
const Complex = math.Complex;
const debug = std.debug;
const time = std.time;
const heap = std.heap;
const mem = std.mem;
const Random = std.Random;
const posix = std.posix;
/// Computes the discrete Fourier transform of `x[0..capital_N]` directly from its definition:
///     X[k] = sum_{n=0}^{N-1} x[n] * exp(-j*2*pi*k*n/N).
///
/// Both loops are unrolled at compile time and every exponential factor is folded into a
/// constant. The generated code therefore grows as O(N^2). This is intended for the tiny
/// lengths (N <= 5) that `tfftRadix4` hands over as its base case, where a direct DFT is
/// cheaper than recursing further.
///
/// `x` must hold at least `capital_N` elements and is only read.
/// Returns a newly allocated slice of `capital_N` coefficients owned by the caller. Free it
/// with `arena`, or release the whole arena.
fn dft(arena: mem.Allocator, comptime T: type, comptime capital_N: usize, x: []const Complex(T)) mem.Allocator.Error![]Complex(T) {
    @setEvalBranchQuota(100000);

    const capital_N_T: T = comptime @floatFromInt(capital_N);

    // Nothing after this allocation can fail, so no errdefer is needed.
    const capital_X = try arena.alloc(Complex(T), capital_N);

    inline for (0..capital_N) |k| {
        const k_T: T = comptime @floatFromInt(k);

        // Accumulate X[k] in a local, starting from 0 + 0j, in increasing n order.
        var sum: Complex(T) = .init(0, 0);
        inline for (0..capital_N) |n| {
            const n_T: T = comptime @floatFromInt(n);
            // exp(-j*2*pi*k*n/N), evaluated entirely at compile time.
            const twiddle = comptime complex.exp(Complex(T).init(0, -math.tau * k_T * n_T / capital_N_T));
            sum = sum.add(x[n].mul(twiddle));
        }
        capital_X[k] = sum;
    }

    return capital_X;
}
/// Computes the discrete Fourier transform of `x[0..capital_N]` with a recursive radix-4
/// decimation-in-frequency scheme in the style of tfft ("N/4 compression").
/// Based on https://github.com/sauloqueiroz/tfft/blob/main/tfftradix4.m.
///
/// Each level splits the input into four length-N/4 sequences
///     xhat_r[p] = W_N^(r*p) * sum_{q=0}^{3} (-j)^(r*q) * x[p + q*N/4],   r = 0..3,
/// transforms them recursively, and interleaves the results: X[4m + r] = DFT(xhat_r)[m].
/// All twiddle factors W_N^(r*p) = exp(-j*2*pi*r*p/N) are folded at compile time.
///
/// Supported lengths are N = c * 4^k with c <= 5 (e.g. 4, 8, 12, 16, 20, 64, 1024).
/// The recursion stops once N <= 5 and passes that length to `dft`. Any other length is a
/// compile error.
///
/// `x` must hold at least `capital_N` elements and is only read.
/// Returns a newly allocated slice of `capital_N` coefficients owned by the caller.
/// Every intermediate buffer is freed before returning, so `arena` may be any allocator.
/// With an `ArenaAllocator`, resetting or deinitializing the arena also releases the result.
fn tfftRadix4(arena: mem.Allocator, comptime T: type, comptime capital_N: usize, x: []Complex(T)) mem.Allocator.Error![]Complex(T) {
    @setEvalBranchQuota(100000);

    // The base case must be checked BEFORE the divisibility requirement. Otherwise
    // mixed-radix lengths such as 8, 12 or 20 recurse into N = 2, 3 or 5 and are
    // rejected instead of being handled by the direct DFT.
    if (comptime (capital_N <= 5))
        return dft(arena, T, capital_N, x);

    comptime if (capital_N % 4 != 0)
        @compileError("N must be of the form c*4^k with c <= 5");

    const capital_N_T: T = comptime @floatFromInt(capital_N);
    // Length of each of the four sub-sequences.
    const capital_N4: usize = comptime (capital_N / 4);

    const xhat_0 = try arena.alloc(Complex(T), capital_N4);
    defer arena.free(xhat_0);
    const xhat_1 = try arena.alloc(Complex(T), capital_N4);
    defer arena.free(xhat_1);
    const xhat_2 = try arena.alloc(Complex(T), capital_N4);
    defer arena.free(xhat_2);
    const xhat_3 = try arena.alloc(Complex(T), capital_N4);
    defer arena.free(xhat_3);

    // Radix-4 DIF butterflies with the W_N^(r*p) twiddles applied in the same pass.
    inline for (0..capital_N4) |p| {
        const x_0p = x[p];
        const x_1p = x[p + capital_N4];
        const x_2p = x[p + 2 * capital_N4];
        const x_3p = x[p + 3 * capital_N4];

        const sum_02 = x_0p.add(x_2p);
        const sum_13 = x_1p.add(x_3p);
        const diff_02 = x_0p.sub(x_2p);
        const j_diff_13 = x_1p.sub(x_3p).mulbyi(); // j * (x1 - x3)

        // W_N^p, W_N^2p and W_N^3p as compile-time constants.
        const p_T: T = comptime @floatFromInt(p);
        const angle: T = comptime (-math.tau * p_T / capital_N_T);
        const w_1 = comptime complex.exp(Complex(T).init(0, angle));
        const w_2 = comptime complex.exp(Complex(T).init(0, 2 * angle));
        const w_3 = comptime complex.exp(Complex(T).init(0, 3 * angle));

        // zig fmt: off
        xhat_0[p] = sum_02.add(sum_13);              // x0 + x1 + x2 + x3           (W_N^0 = 1)
        xhat_1[p] = diff_02.sub(j_diff_13).mul(w_1); // (x0 - jx1 - x2 + jx3) * W_N^p
        xhat_2[p] = sum_02.sub(sum_13).mul(w_2);     // (x0 - x1 + x2 - x3)   * W_N^2p
        xhat_3[p] = diff_02.add(j_diff_13).mul(w_3); // (x0 + jx1 - x2 - jx3) * W_N^3p
        // zig fmt: on
    }

    // Length-N/4 transforms of the four sub-sequences.
    const capital_Y_0 = try tfftRadix4(arena, T, capital_N4, xhat_0);
    defer arena.free(capital_Y_0);
    const capital_Y_1 = try tfftRadix4(arena, T, capital_N4, xhat_1);
    defer arena.free(capital_Y_1);
    const capital_Y_2 = try tfftRadix4(arena, T, capital_N4, xhat_2);
    defer arena.free(capital_Y_2);
    const capital_Y_3 = try tfftRadix4(arena, T, capital_N4, xhat_3);
    defer arena.free(capital_Y_3);

    // Nothing after this allocation can fail, so the result needs no errdefer.
    const capital_X = try arena.alloc(Complex(T), capital_N);

    // Interleave: sub-transform r supplies the outputs X[4m + r].
    inline for (0..capital_N4) |m| {
        // zig fmt: off
        capital_X[4 * m + 0] = capital_Y_0[m]; // X_k for k = 4m
        capital_X[4 * m + 1] = capital_Y_1[m]; // X_k for k = 4m + 1
        capital_X[4 * m + 2] = capital_Y_2[m]; // X_k for k = 4m + 2
        capital_X[4 * m + 3] = capital_Y_3[m]; // X_k for k = 4m + 3
        // zig fmt: on
    }

    return capital_X;
}
/// Benchmarks `tfftRadix4` on random single-precision input for N = 4^1 .. 4^5.
///
/// Each length is timed over `iterations` calls. Only the transform itself is inside the
/// timed region; input generation and the arena reset are excluded. The total and the
/// per-call average are printed to stderr.
pub fn main() !void {
    // Seed the PRNG from the OS so every run uses different input data.
    var prng: Random.DefaultPrng = .init(blk: {
        var seed: u64 = undefined;
        try posix.getrandom(mem.asBytes(&seed));
        break :blk seed;
    });
    const rand = prng.random();

    // One arena backs every allocation. It is reset (keeping its capacity) after each
    // iteration, so steady-state iterations reuse the same memory.
    var arena_allocator: heap.ArenaAllocator = .init(heap.c_allocator);
    defer arena_allocator.deinit();
    const arena = arena_allocator.allocator();

    const iterations: usize = 100000;
    debug.print("{d} iterations for every measurement\n", .{iterations});

    inline for (1..6) |k| {
        const capital_N = comptime math.pow(usize, 4, k);

        var timer: time.Timer = try .start();
        var total_time_ns: u64 = 0;
        for (0..iterations) |_| {
            const x = try arena.alloc(Complex(f32), capital_N);
            for (x) |*c| {
                c.* = .init(
                    rand.float(f32),
                    rand.float(f32),
                );
            }

            timer.reset();
            const capital_X = try tfftRadix4(arena, f32, capital_N, x);
            total_time_ns += timer.read();
            // The result is otherwise unused. Without this barrier an optimizing build
            // could discard the transform and report a meaningless time.
            mem.doNotOptimizeAway(capital_X.ptr);

            _ = arena_allocator.reset(.retain_capacity);
        }

        const avg_time_ns = total_time_ns / iterations;
        const avg_time_ms = @as(f64, @floatFromInt(avg_time_ns)) / @as(f64, 1_000_000);
        debug.print("\n", .{});
        debug.print("For N={d}:\n", .{capital_N});
        debug.print("Total call time: {d} ns\n", .{total_time_ns});
        debug.print("Average time per call: {d:.3} ms ({d} ns)\n", .{ avg_time_ms, avg_time_ns });
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment