Last active
September 11, 2025 20:18
-
-
Save youdie323323/d49e76cdfff8b4d47ee917ed34252c6f to your computer and use it in GitHub Desktop.
The "Twiddless" Fast Fourier Transform Algorithm in Zig
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| 100000 iterations for every measurement | |
| For N=4: | |
| Total call time: 2772800 ns | |
| Average time per call: 0.000 ms (27 ns) | |
| For N=16: | |
| Total call time: 11139000 ns | |
| Average time per call: 0.000 ms (111 ns) | |
| For N=64: | |
| Total call time: 46906900 ns | |
| Average time per call: 0.000 ms (469 ns) | |
| For N=256: | |
| Total call time: 224283200 ns | |
| Average time per call: 0.002 ms (2242 ns) | |
| For N=1024: | |
| Total call time: 1074442100 ns | |
| Average time per call: 0.011 ms (10744 ns) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| const std = @import("std"); | |
| const math = std.math; | |
| const complex = math.complex; | |
| const Complex = math.Complex; | |
| const debug = std.debug; | |
| const time = std.time; | |
| const heap = std.heap; | |
| const mem = std.mem; | |
| const Random = std.Random; | |
| const posix = std.posix; | |
/// Evaluates the discrete Fourier transform of the first `capital_N` samples of `x`
/// straight from its definition: X[k] = sum over n of x[n] * exp(-i * tau * k * n / N).
///
/// Both loops are unrolled at compile time and every complex exponential is
/// computed at comptime, so the emitted code is a flat sequence of complex
/// multiply-adds. The cost is quadratic in `capital_N`; `tfftRadix4` only calls
/// this for its base case (`capital_N <= 5`), where avoiding further recursion
/// is preferred over an FFT.
///
/// The returned slice has `capital_N` elements and is allocated from `arena`.
/// Nothing is freed here on success; the caller releases the arena.
fn dft(arena: mem.Allocator, comptime T: type, comptime capital_N: usize, x: []Complex(T)) mem.Allocator.Error![]Complex(T) {
    @setEvalBranchQuota(100000);

    const length_as_float: T = comptime @floatFromInt(capital_N);

    const spectrum = try arena.alloc(Complex(T), capital_N);
    errdefer arena.free(spectrum);

    inline for (0..capital_N) |k| {
        // Accumulate bin k locally, starting from zero, then store it once.
        var bin: Complex(T) = comptime .init(0, 0);

        inline for (0..capital_N) |n| {
            // Phase of the (k, n) DFT kernel entry, known at compile time.
            const phase: T = comptime blk: {
                const k_T: T = @floatFromInt(k);
                const n_T: T = @floatFromInt(n);
                break :blk -math.tau * k_T * n_T / length_as_float;
            };
            const kernel = comptime complex.exp(Complex(T).init(0, phase));

            bin = bin.add(x[n].mul(kernel));
        }

        spectrum[k] = bin;
    }

    return spectrum;
}
/// Computes the length-`capital_N` DFT of `x` with a recursive radix-4 scheme in the
/// style of TFFT: the input is split into four contiguous quarters, combined with
/// radix-4 butterflies, pre-multiplied by the twiddle factors W_N^(r*p), and each
/// of the four resulting length-N/4 sequences is transformed recursively. The
/// sub-transform outputs are only interleaved (X[4m + r] = Y_r[m]); no arithmetic
/// is performed when combining.
///
/// Supported lengths are N = c * 4^k with c in 1..5 and k >= 0: lengths up to 5
/// are handed to `dft`, every longer length must be divisible by 4 (checked at
/// compile time at each recursion level).
/// Based on https://github.com/sauloqueiroz/tfft/blob/main/tfftradix4.m.
///
/// `x` must hold at least `capital_N` samples. The returned slice has `capital_N`
/// elements. All scratch buffers and the result are allocated from `arena` and
/// are not freed on success; the caller owns the arena and must free it.
fn tfftRadix4(arena: mem.Allocator, comptime T: type, comptime capital_N: usize, x: []Complex(T)) mem.Allocator.Error![]Complex(T) {
    @setEvalBranchQuota(100000);

    // Base case. This must be tested BEFORE the divisibility check below:
    // otherwise a recursion that bottoms out at a length of 1, 2, 3 or 5
    // (e.g. N = 8 -> 2, N = 12 -> 3, N = 20 -> 5) is rejected at compile time
    // even though `dft` handles those lengths.
    if (comptime (capital_N <= 5))
        return dft(arena, T, capital_N, x);

    comptime if (capital_N % 4 != 0)
        @compileError("N must be divisible by 4");

    const capital_N_T: T = comptime @floatFromInt(capital_N);

    // Length of sub-sequences
    const capital_N4: usize = comptime (capital_N / 4);

    // Inputs of the four recursive length-N/4 transforms.
    var xhat_0 = try arena.alloc(Complex(T), capital_N4);
    errdefer arena.free(xhat_0);
    var xhat_1 = try arena.alloc(Complex(T), capital_N4);
    errdefer arena.free(xhat_1);
    var xhat_2 = try arena.alloc(Complex(T), capital_N4);
    errdefer arena.free(xhat_2);
    var xhat_3 = try arena.alloc(Complex(T), capital_N4);
    errdefer arena.free(xhat_3);

    inline for (0..capital_N4) |p| {
        // One sample from each contiguous quarter of the input.
        const x_0p = x[p];
        const x_1p = x[comptime (p + capital_N4)];
        const x_2p = x[comptime (p + 2 * capital_N4)];
        const x_3p = x[comptime (p + 3 * capital_N4)];

        // Radix-4 butterfly (decimation in frequency: contiguous input
        // quarters, interleaved outputs). These are the inputs of the
        // sub-DFTs before multiplication by the W_N^(r*p) factors.
        const sum_1 = x_0p.add(x_2p);
        const sum_2 = x_1p.add(x_3p);
        const diff_1 = x_0p.sub(x_2p);
        const i_diff_2 = x_1p.sub(x_3p).mulbyi();

        // Twiddle factors W_N^p, W_N^2p and W_N^3p, evaluated at compile time.
        const p_T: T = comptime @floatFromInt(p);
        const minus_tau_p_over_N: T = comptime (-math.tau * p_T / capital_N_T);
        const w_p1 = comptime complex.exp(Complex(T).init(0, minus_tau_p_over_N));
        const w_p2 = comptime complex.exp(Complex(T).init(0, 2 * minus_tau_p_over_N));
        const w_p3 = comptime complex.exp(Complex(T).init(0, 3 * minus_tau_p_over_N));

        // xhat_0 needs no twiddle factor (W_N^0 = 1).
        // zig fmt: off
        xhat_0[p] = sum_1.add(sum_2);                 // x0 +  x1 + x2 +  x3
        xhat_1[p] = diff_1.sub(i_diff_2).mul(w_p1);   // (x0 - jx1 - x2 + jx3) * W_N^p
        xhat_2[p] = sum_1.sub(sum_2).mul(w_p2);       // (x0 -  x1 + x2 -  x3) * W_N^2p
        xhat_3[p] = diff_1.add(i_diff_2).mul(w_p3);   // (x0 + jx1 - x2 - jx3) * W_N^3p
        // zig fmt: on
    }

    // Recursive calls for N/4 length DFTs
    const capital_Y_0 = try tfftRadix4(arena, T, capital_N4, xhat_0);
    const capital_Y_1 = try tfftRadix4(arena, T, capital_N4, xhat_1);
    const capital_Y_2 = try tfftRadix4(arena, T, capital_N4, xhat_2);
    const capital_Y_3 = try tfftRadix4(arena, T, capital_N4, xhat_3);

    var capital_X = try arena.alloc(Complex(T), capital_N);
    errdefer arena.free(capital_X);

    // Combine results: pure interleaving, X[4m + r] = Y_r[m].
    inline for (0..capital_N4) |m| {
        const four_times_m = comptime (4 * m);

        capital_X[four_times_m] = capital_Y_0[m];
        capital_X[comptime (four_times_m + 1)] = capital_Y_1[m];
        capital_X[comptime (four_times_m + 2)] = capital_Y_2[m];
        capital_X[comptime (four_times_m + 3)] = capital_Y_3[m];
    }

    return capital_X;
}
/// Benchmark driver: for N = 4, 16, 64, 256 and 1024, runs `tfftRadix4` on
/// freshly generated random `f32` input `iterations` times and prints the total
/// and average time spent inside the transform to stderr. Only the transform
/// call itself is timed; input generation and the arena reset are excluded.
pub fn main() !void {
    // Seed the PRNG from the operating system's random source.
    var seed: u64 = undefined;
    try posix.getrandom(mem.asBytes(&seed));
    var prng = Random.DefaultPrng.init(seed);
    const rand = prng.random();

    var arena_allocator = heap.ArenaAllocator.init(heap.c_allocator);
    defer arena_allocator.deinit();
    const arena = arena_allocator.allocator();

    const iterations: usize = 100000;
    debug.print("{d} iterations for every measurement\n", .{iterations});

    // The transform length is a comptime parameter, so the sizes are unrolled.
    inline for (1..6) |k| {
        const capital_N = comptime math.pow(usize, 4, k);

        var timer = try time.Timer.start();
        var total_time_ns: u64 = 0;

        var run: usize = 0;
        while (run < iterations) : (run += 1) {
            const x = try arena.alloc(Complex(f32), capital_N);
            for (x) |*sample| {
                const re = rand.float(f32);
                const im = rand.float(f32);
                sample.* = Complex(f32).init(re, im);
            }

            // Time the transform only.
            timer.reset();
            _ = try tfftRadix4(arena, f32, capital_N, x);
            total_time_ns += timer.read();

            // Drop everything allocated during this run but keep the arena's
            // capacity for the next one.
            _ = arena_allocator.reset(.retain_capacity);
        }

        const avg_time_ns = total_time_ns / iterations;
        const avg_time_ms = @as(f64, @floatFromInt(avg_time_ns)) / 1_000_000.0;

        debug.print("\n", .{});
        debug.print("For N={d}:\n", .{capital_N});
        debug.print("Total call time: {d} ns\n", .{total_time_ns});
        debug.print("Average time per call: {d:.3} ms ({d} ns)\n", .{ avg_time_ms, avg_time_ns });
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment