Created
October 17, 2025 05:50
-
-
Save msuiche/5f566515512220af2e0e0922d9d3a102 to your computer and use it in GitHub Desktop.
Fast Fourier Transform in Mojo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from gpu.host import DeviceContext | |
| from memory import UnsafePointer | |
| from math import ceildiv, fma | |
| from gpu import global_idx | |
# Double-precision multiples and reciprocals of π.
alias M_PI: Float64 = 3.14159265358979323846  # π
alias M_PI_2: Float64 = 1.57079632679489661923  # π/2
alias M_PI_4: Float64 = 0.785398163397448309616  # π/4
alias M_2_PI: Float64 = 0.636619772367581343076  # 2/π
# Single-precision π.
alias M_PI_F32: Float32 = 3.14159265358979323846
# Software sin/cos modeled on the CUDA libdevice implementation.
# Uses Cody-Waite style range reduction (subtracting k * π/2 split across
# several constants with FMAs) followed by a minimax polynomial approximation.
@always_inline
fn round_to_int(x: Float32) -> Int:
    """Round `x` to the nearest integer, resolving exact halves to the even neighbour.

    Intended to mirror PTX `cvt.rni.s32.f32` (per the original author).

    Args:
        x: Value to round.

    Returns:
        The nearest integer to `x`; a fractional part of exactly +/-0.5 goes
        to whichever neighbour is even.
    """
    # Int() truncates toward zero, so `frac` carries the sign of x.
    var whole = Int(x)
    var frac = x - Float32(whole)
    var whole_is_odd = (whole & 1) != 0

    # Move away from zero when past the halfway point, or when exactly on it
    # and the truncated value is odd (the even neighbour is one step away).
    if frac > 0.5 or (frac == 0.5 and whole_is_odd):
        return whole + 1
    if frac < -0.5 or (frac == -0.5 and whole_is_odd):
        return whole - 1
    return whole
@always_inline
fn libdevice_sinf(x: Float32) -> Float32:
    """Single-precision software sine.

    Follows the sequence the original author attributes to the CUDA PTX:
    scale by 2/π, round to the nearest integer quadrant index, subtract that
    many quarter-turns with three FMAs, then evaluate a short polynomial.

    Args:
        x: Angle in radians.

    Returns:
        Approximation of sin(x).
    """
    # Quadrant index q = rint(x * 2/π).
    var q = round_to_int(x * Float32(0.6366197723675814))
    var qf = Float32(q)

    # r = x - q * π/2, with π/2 split over three constants.
    var r = fma(qf, -1.5707963705062866, x)
    r = fma(qf, -4.3711388286738386e-08, r)
    r = fma(qf, -1.2560587133447677e-15, r)
    var r2 = r * r

    # Odd quadrants (q + 1 even) evaluate the cosine polynomial.
    var use_cosine_poly = ((q + 1) & 1) == 0
    # Bit 1 of q marks the quadrants where the result is negated.
    var negate = (q & 2) != 0

    var value: Float32
    if use_cosine_poly:
        var acc = fma(Float32(2.44331570e-05), r2, Float32(-1.38873163e-03))
        acc = fma(acc, r2, Float32(4.16666418e-02))
        acc = fma(acc, r2, Float32(-0.5))
        value = fma(acc, r2, Float32(1.0))
    else:
        var acc = fma(Float32(-1.95152959e-04), r2, Float32(8.33216030e-03))
        acc = fma(acc, r2, Float32(-1.66666552e-01))
        value = fma(acc, r2 * r, r)

    if negate:
        return -value
    return value
@always_inline
fn libdevice_sin(x: Float64) -> Float64:
    """Double-precision software sine.

    Reduces `x` to `xr = x - k * π/2` with `k` the nearest integer to
    `x * 2/π`, then evaluates a sine or cosine polynomial on `xr` depending on
    the quadrant. The polynomial coefficients are the double-precision minimax
    coefficients used by the fdlibm sine/cosine kernels on |xr| <= π/4; the
    previous single-precision coefficients capped the accuracy of this
    Float64 routine at roughly Float32 level.

    Args:
        x: Angle in radians.

    Returns:
        Approximation of sin(x).
    """
    var temp = x * 0.6366197723675814
    # Round half away from zero to pick the nearest quadrant index.
    var k = Int(temp + (0.5 if temp >= 0.0 else -0.5))
    var kf = Float64(k)

    # Two-constant reduction: π/2 = 1.5707963267948966 + 6.123233995736766e-17.
    var xr = fma(kf, -1.5707963267948966, x)
    xr = fma(kf, -6.123233995736766e-17, xr)

    # Odd k selects the cosine polynomial; bit 1 of k selects the sign.
    var poly_bit = (k + 1) & 1
    var sign_bit = k & 2
    var x2 = xr * xr
    var poly: Float64
    if poly_bit == 0:
        # cos(xr) = 1 - x2/2 + x2^2 * (C1 + C2*x2 + ... + C6*x2^5)
        var c = fma(
            Float64(-1.13596475577881948265e-11),
            x2,
            Float64(2.08757232129817482790e-09),
        )
        c = fma(c, x2, -2.75573143513906633035e-07)
        c = fma(c, x2, 2.48015872894767294178e-05)
        c = fma(c, x2, -1.38888888888741095749e-03)
        c = fma(c, x2, 4.16666666666666019037e-02)
        c = fma(c, x2, -0.5)
        poly = fma(c, x2, 1.0)
    else:
        # sin(xr) = xr + xr^3 * (S1 + S2*x2 + ... + S6*x2^5)
        var s = fma(
            Float64(1.58969099521155010221e-10),
            x2,
            Float64(-2.50507602534068634195e-08),
        )
        s = fma(s, x2, 2.75573137070700676789e-06)
        s = fma(s, x2, -1.98412698298579493134e-04)
        s = fma(s, x2, 8.33333333332248946124e-03)
        s = fma(s, x2, -1.66666666666666324348e-01)
        poly = fma(s, x2 * xr, xr)
    return -poly if sign_bit != 0 else poly
@always_inline
fn libdevice_cos(x: Float64) -> Float64:
    """Double-precision software cosine.

    Reduces `x` to `xr = x - k * π/2` with `k` the nearest integer to
    `x * 2/π`, then evaluates a cosine or sine polynomial on `xr` depending on
    the quadrant. The polynomial coefficients are the double-precision minimax
    coefficients used by the fdlibm sine/cosine kernels on |xr| <= π/4; the
    previous single-precision coefficients capped the accuracy of this
    Float64 routine at roughly Float32 level.

    Args:
        x: Angle in radians.

    Returns:
        Approximation of cos(x).
    """
    var temp = x * 0.6366197723675814
    # Round half away from zero to pick the nearest quadrant index.
    var k = Int(temp + (0.5 if temp >= 0.0 else -0.5))
    var kf = Float64(k)

    # Two-constant reduction: π/2 = 1.5707963267948966 + 6.123233995736766e-17.
    var xr = fma(kf, -1.5707963267948966, x)
    xr = fma(kf, -6.123233995736766e-17, xr)

    # Even k selects the cosine polynomial; bit 1 of (k + 1) selects the sign.
    var poly_bit = k & 1
    var sign_bit = (k + 1) & 2
    var x2 = xr * xr
    var poly: Float64
    if poly_bit == 0:
        # cos(xr) = 1 - x2/2 + x2^2 * (C1 + C2*x2 + ... + C6*x2^5)
        var c = fma(
            Float64(-1.13596475577881948265e-11),
            x2,
            Float64(2.08757232129817482790e-09),
        )
        c = fma(c, x2, -2.75573143513906633035e-07)
        c = fma(c, x2, 2.48015872894767294178e-05)
        c = fma(c, x2, -1.38888888888741095749e-03)
        c = fma(c, x2, 4.16666666666666019037e-02)
        c = fma(c, x2, -0.5)
        poly = fma(c, x2, 1.0)
    else:
        # sin(xr) = xr + xr^3 * (S1 + S2*x2 + ... + S6*x2^5)
        var s = fma(
            Float64(1.58969099521155010221e-10),
            x2,
            Float64(-2.50507602534068634195e-08),
        )
        s = fma(s, x2, 2.75573137070700676789e-06)
        s = fma(s, x2, -1.98412698298579493134e-04)
        s = fma(s, x2, 8.33333333332248946124e-03)
        s = fma(s, x2, -1.66666666666666324348e-01)
        poly = fma(s, x2 * xr, xr)
    return -poly if sign_bit != 0 else poly
@always_inline
fn libdevice_cosf(x: Float32) -> Float32:
    """Single-precision software cosine.

    Same reduction as the single-precision sine: scale by 2/π, round to the
    nearest integer quadrant index, subtract that many quarter-turns with
    three FMAs. The polynomial is chosen from the index itself and the sign
    from the index shifted by one quadrant.

    Args:
        x: Angle in radians.

    Returns:
        Approximation of cos(x).
    """
    # Quadrant index q = rint(x * 2/π).
    var q = round_to_int(x * Float32(0.6366197723675814))
    var qf = Float32(q)

    # r = x - q * π/2, with π/2 split over three constants.
    var r = fma(qf, -1.5707963705062866, x)
    r = fma(qf, -4.3711388286738386e-08, r)
    r = fma(qf, -1.2560587133447677e-15, r)
    var r2 = r * r

    # Even quadrants evaluate the cosine polynomial.
    var use_cosine_poly = (q & 1) == 0
    # cos(x) = sin(x + π/2): the sign comes from bit 1 of q + 1.
    var negate = ((q + 1) & 2) != 0

    var value: Float32
    if use_cosine_poly:
        var acc = fma(Float32(2.44331570e-05), r2, Float32(-1.38873163e-03))
        acc = fma(acc, r2, Float32(4.16666418e-02))
        acc = fma(acc, r2, Float32(-0.5))
        value = fma(acc, r2, Float32(1.0))
    else:
        var acc = fma(Float32(-1.95152959e-04), r2, Float32(8.33216030e-03))
        acc = fma(acc, r2, Float32(-1.66666552e-01))
        value = fma(acc, r2 * r, r)

    if negate:
        return -value
    return value
fn complex_multiply(
    ar: Float32, ai: Float32, br: Float32, bi: Float32
) -> (Float32, Float32):
    """Return the product (ar + i*ai) * (br + i*bi) as a (real, imag) pair.

    Args:
        ar: Real part of the first factor.
        ai: Imaginary part of the first factor.
        br: Real part of the second factor.
        bi: Imaginary part of the second factor.

    Returns:
        Tuple of (real part, imaginary part) of the product.
    """
    return (ar * br - ai * bi, ar * bi + ai * br)
fn dft_kernel(
    signal: UnsafePointer[Float32],
    spectrum: UnsafePointer[Float32],
    N: Int32,
):
    """Direct O(N) per-bin DFT: each thread computes one output bin.

    Thread `k` (from `global_idx.x`) accumulates
    sum_n x[n] * exp(-2πi * k * n / N) in Float64 with Kahan-compensated
    summation and writes the result as Float32.

    Args:
        signal: Input samples, read as interleaved (real, imag) Float32 pairs
            at indices 2*n and 2*n + 1 for n in [0, N).
        spectrum: Output bins, written as interleaved (real, imag) Float32
            pairs at indices 2*k and 2*k + 1.
        N: Number of complex samples / output bins.
    """
    var k = Int(global_idx.x)
    var count = Int(N)
    if k >= count:
        return

    # Loop-invariant conversion hoisted out of the accumulation loop.
    var N_f = Float64(count)

    var real_sum: Float64 = 0.0
    var imag_sum: Float64 = 0.0
    var real_c: Float64 = 0.0  # Running Kahan compensation, real part
    var imag_c: Float64 = 0.0  # Running Kahan compensation, imaginary part

    for n in range(count):
        # exp(-2πi*k*n/N) is periodic in k*n with period N, so reduce the
        # product modulo N in exact integer arithmetic first. This keeps the
        # angle inside (-2π, 0] instead of letting it grow to ~2π*N, where
        # its absolute rounding error (and the quadrant count in the
        # sin/cos range reduction) scales with N.
        var kn = (k * n) % count
        var angle = -2.0 * M_PI * Float64(kn) / N_f
        var cos_val = libdevice_cos(angle)
        var sin_val = libdevice_sin(angle)

        var x_real = Float64(signal[2 * n])
        var x_imag = Float64(signal[2 * n + 1])

        # x[n] * (cos + i*sin)
        var temp_real = x_real * cos_val - x_imag * sin_val
        var temp_imag = x_real * sin_val + x_imag * cos_val

        # Kahan summation, real part: carry the low-order bits lost by
        # the previous addition into this one.
        var real_y = temp_real - real_c
        var real_t = real_sum + real_y
        real_c = (real_t - real_sum) - real_y
        real_sum = real_t

        # Kahan summation, imaginary part.
        var imag_y = temp_imag - imag_c
        var imag_t = imag_sum + imag_y
        imag_c = (imag_t - imag_sum) - imag_y
        imag_sum = imag_t

    # Narrow to Float32 only when storing the result.
    spectrum[2 * k] = Float32(real_sum)
    spectrum[2 * k + 1] = Float32(imag_sum)
fn bit_reverse_kernel(
    input: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    N: Int32,
    log2N: Int32,
):
    """Scatter each complex element to its bit-reversed index.

    Thread `idx` (from `global_idx.x`) reverses the low `log2N` bits of its
    index and copies the (real, imag) pair at `input[2*idx], input[2*idx+1]`
    to the same pair position at the reversed index in `output`.

    Args:
        input: Source buffer of interleaved (real, imag) Float32 pairs.
        output: Destination buffer with the same layout.
        N: Number of complex elements; threads with index >= N do nothing.
        log2N: Number of low-order index bits to reverse.
    """
    var src = Int(global_idx.x)
    if src >= Int(N):
        return

    # Peel bits off the low end of `remaining` and push them onto `dst`.
    var dst = 0
    var remaining = src
    var bits_left = Int(log2N)
    while bits_left > 0:
        dst = (dst << 1) | (remaining & 1)
        remaining >>= 1
        bits_left -= 1

    output[2 * dst] = input[2 * src]
    output[2 * dst + 1] = input[2 * src + 1]
fn fft_stage_kernel(
    data: UnsafePointer[Float32],
    N: Int32,
    stage: Int32,
):
    """Apply one radix-2 butterfly stage in place; one thread per butterfly.

    For stage `s`, butterflies pair elements `2^s` apart inside groups of
    `2^(s+1)` elements. Arithmetic is done in Float64 and stored as Float32.

    Args:
        data: Interleaved (real, imag) Float32 pairs, updated in place.
        N: Number of complex elements.
        stage: Zero-based stage index.
    """
    var tid = Int(global_idx.x)
    var total = Int(N)
    var half_span = 1 << Int(stage)
    var span = half_span << 1
    var group_count = total // span

    # Only N/2 butterflies exist per stage.
    if tid >= total // 2:
        return

    var group = tid // half_span
    var offset = tid % half_span
    if group >= group_count:
        return

    var upper = group * span + offset
    var lower = group * span + offset + half_span
    if upper >= total or lower >= total:
        return

    # Twiddle factor exp(-2πi * offset / span), evaluated in Float64.
    var angle = -2.0 * M_PI * Float64(offset) / Float64(span)
    var wr = libdevice_cos(angle)
    var wi = libdevice_sin(angle)

    var upper_re = Float64(data[2 * upper])
    var upper_im = Float64(data[2 * upper + 1])
    var lower_re = Float64(data[2 * lower])
    var lower_im = Float64(data[2 * lower + 1])

    # Rotate the lower element by the twiddle factor.
    var rot_re = lower_re * wr - lower_im * wi
    var rot_im = lower_re * wi + lower_im * wr

    # Butterfly: upper gets the sum, lower gets the difference.
    data[2 * upper] = Float32(upper_re + rot_re)
    data[2 * upper + 1] = Float32(upper_im + rot_im)
    data[2 * lower] = Float32(upper_re - rot_re)
    data[2 * lower + 1] = Float32(upper_im - rot_im)
@export
def solve(signal: UnsafePointer[Float32], spectrum: UnsafePointer[Float32], N: Int32):
    """Compute the length-N complex DFT of `signal` into `spectrum`.

    Power-of-two lengths (N >= 2) use a bit-reversal pass followed by log2(N)
    butterfly stages; every other length uses the direct DFT kernel.
    Non-positive N is a no-op.

    Args:
        signal: 2*N Float32 values, interleaved (real, imag).
        spectrum: Receives 2*N Float32 values, interleaved (real, imag).
        N: Number of complex samples.
    """
    if N <= 0:
        return

    alias THREADS_PER_BLOCK = 256

    # Widen to Int BEFORE doubling: `2 * N` evaluated in Int32 overflows
    # for N >= 2^30.
    var n = Int(N)
    var num_floats = 2 * n
    var is_power_of_2 = (n & (n - 1)) == 0

    var ctx = DeviceContext()

    # Stage the input in a buffer owned by this context.
    var signal_device = ctx.enqueue_create_buffer[DType.float32](num_floats)
    ctx.enqueue_copy(signal_device.unsafe_ptr(), signal, num_floats)
    ctx.synchronize()

    var blocks = ceildiv(n, THREADS_PER_BLOCK)

    if is_power_of_2 and n >= 2:
        # log2N = floor(log2(n)); exact because n is a power of two.
        var log2N: Int32 = 0
        var temp = n
        while temp > 1:
            temp >>= 1
            log2N += 1

        # Working buffer: receives the bit-reversed input, then is
        # transformed in place by the butterfly stages.
        var temp_buffer = ctx.enqueue_create_buffer[DType.float32](num_floats)

        ctx.enqueue_function_unchecked[bit_reverse_kernel](
            signal_device,
            temp_buffer,
            N,
            log2N,
            grid_dim=(blocks,),
            block_dim=(THREADS_PER_BLOCK,),
        )
        ctx.synchronize()

        # Every stage runs N/2 butterflies, so the launch shape is
        # loop-invariant (and >= 1 block because n >= 2).
        var stage_blocks = ceildiv(n // 2, THREADS_PER_BLOCK)
        for stage in range(Int(log2N)):
            ctx.enqueue_function_unchecked[fft_stage_kernel](
                temp_buffer,
                N,
                Int32(stage),
                grid_dim=(stage_blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )
            ctx.synchronize()

        ctx.enqueue_copy(spectrum, temp_buffer.unsafe_ptr(), num_floats)
        ctx.synchronize()
    else:
        # The output buffer is only needed on the direct-DFT path.
        var spectrum_device = ctx.enqueue_create_buffer[DType.float32](num_floats)

        ctx.enqueue_function_unchecked[dft_kernel](
            signal_device,
            spectrum_device,
            N,
            grid_dim=(blocks,),
            block_dim=(THREADS_PER_BLOCK,),
        )
        ctx.synchronize()

        ctx.enqueue_copy(spectrum, spectrum_device.unsafe_ptr(), num_floats)
        ctx.synchronize()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment