Skip to content

Instantly share code, notes, and snippets.

@msuiche
Created October 17, 2025 05:50
Show Gist options
  • Select an option

  • Save msuiche/5f566515512220af2e0e0922d9d3a102 to your computer and use it in GitHub Desktop.

Select an option

Save msuiche/5f566515512220af2e0e0922d9d3a102 to your computer and use it in GitHub Desktop.
Fast Fourier Transform (FFT) in Mojo
from gpu.host import DeviceContext
from memory import UnsafePointer
from math import ceildiv, fma
from gpu import global_idx
# Double-precision constants named after the C <math.h> M_* macros.
# NOTE(review): in the code shown only M_PI is referenced (dft_kernel and
# fft_stage_kernel); the remaining aliases appear unused.
alias M_PI: Float64 = 3.14159265358979323846 # π
alias M_PI_2: Float64 = 1.57079632679489661923 # π/2
alias M_2_PI: Float64 = 0.636619772367581343076 # 2/π
alias M_PI_4: Float64 = 0.785398163397448309616 # π/4
# π rounded to single precision (the literal is converted to Float32).
alias M_PI_F32: Float32 = 3.14159265358979323846
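# Software sin/cos intended to mirror the CUDA single-precision sequence:
# round x*2/π to an integer quadrant k, subtract k*π/2 using an FMA-based
# multi-term (Cody-Waite style) split of π/2, then evaluate a minimax
# polynomial on the reduced argument. No Payne-Hanek reduction is
# implemented, so very large |x| is not handled the way libdevice does.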
@always_inline
fn round_to_int(x: Float32) -> Int:
    """Round `x` to the nearest Int, resolving exact .5 ties to the even neighbour.

    Mirrors the rounding mode of PTX `cvt.rni.s32.f32`. `Int(x)` truncates
    toward zero, so the remainder `frac` lies in (-1, 1) with the sign of `x`.
    A remainder beyond +/-0.5 steps one unit in its direction; an exact
    +/-0.5 steps only when the truncated value is odd. Anything else returns
    the truncated value.
    """
    var whole = Int(x)
    var frac = x - Float32(whole)
    var whole_is_odd = (whole & 1) != 0
    if frac > 0.5 or (frac == 0.5 and whole_is_odd):
        return whole + 1
    if frac < -0.5 or (frac == -0.5 and whole_is_odd):
        return whole - 1
    return whole
@always_inline
fn libdevice_sinf(x: Float32) -> Float32:
    """
    Single-precision software sine modelled on CUDA's sinf PTX sequence.

    1. k = round-half-to-even(x * 2/pi), like PTX cvt.rni.
    2. r = x - k*pi/2, computed with three FMAs over a split
       pi/2 = C1 + C2 + C3.
    3. Evaluate the sine or cosine minimax polynomial on r, chosen and
       signed by the quadrant k: sin(x) is +/-sin(r) or +/-cos(r).

    All three split constants are positive and C1 < pi/2, so every FMA
    subtracts k*Ci. Using C1 = float32(pi/2) would be wrong here, because
    that value lies slightly *above* pi/2. Subtracting the correction
    terms after it mis-reduces r by about 2*k*4.4e-8. For example,
    sin(float32(pi)) would then come out with the wrong sign.

    NOTE(review): the split constants are believed to be the ones NVCC
    emits (0fBFC90FDA, 0fB3A22168, 0fA7C234C5); confirm against a PTX
    dump. There is no Payne-Hanek path, so very large |x| is not handled.
    """
    var scaled = x * Float32(0.6366197723675814)  # x * 2/pi
    var k = round_to_int(scaled)
    var kf = Float32(k)
    # Cody-Waite reduction: C1 + C2 + C3 == pi/2, each term subtracted.
    var xr = fma(kf, -1.5707962512969971, x)
    xr = fma(kf, -7.5497894158615964e-08, xr)
    xr = fma(kf, -5.3903025299577648e-15, xr)
    # sin(k*pi/2 + r): even k -> +/-sin(r), odd k -> +/-cos(r).
    var poly_bit = ((k + 1) & 1)
    # Quadrants 2 and 3 (bit 1 of k) flip the sign.
    var sign_bit = (k & 2)
    var x2 = xr * xr
    var poly: Float32
    if poly_bit == 0:
        # Cosine polynomial
        var c = fma(Float32(2.44331570e-05), x2, Float32(-1.38873163e-03))
        c = fma(c, x2, Float32(4.16666418e-02))
        c = fma(c, x2, Float32(-0.5))
        poly = fma(c, x2, Float32(1.0))
    else:
        # Sine polynomial
        var c = fma(Float32(-1.95152959e-04), x2, Float32(8.33216030e-03))
        c = fma(c, x2, Float32(-1.66666552e-01))
        poly = fma(c, x2 * xr, xr)
    return -poly if sign_bit != 0 else poly
@always_inline
fn _sin_quarter_shifted(x: Float64, quarter_turns: Int) -> Float64:
    """Return sin(x + quarter_turns * pi/2) in double precision.

    This is the shared kernel for libdevice_sin (shift 0) and
    libdevice_cos (shift 1, since cos(x) = sin(x + pi/2)).

    Range reduction: k = round(x * 2/pi), with ties going away from zero.
    Then r = x - k*pi/2, using the two-term split
    pi/2 = 1.5707963267948966 + 6.123233995736766e-17 applied with FMAs.

    With q = k + quarter_turns, the result is sin(r + q*pi/2). That is
    sin(r), cos(r), -sin(r) or -cos(r) for q mod 4 = 0, 1, 2, 3.

    The polynomials use the fdlibm __kernel_sin / __kernel_cos
    coefficients, which are double-precision minimax fits on |r| <= pi/4.
    Float32-grade coefficients would cap accuracy at ~1e-7.
    """
    var scaled = x * 0.6366197723675814  # x * 2/pi
    var k = Int(scaled + (0.5 if scaled >= 0.0 else -0.5))
    var r = fma(Float64(k), -1.5707963267948966, x)
    r = fma(Float64(k), -6.123233995736766e-17, r)
    var z = r * r
    var q = k + quarter_turns
    var value: Float64
    if (q & 1) == 0:
        # sin(r) = r + r^3 * (S1 + z*(S2 + ... + z*S6))   (fdlibm k_sin.c)
        var p = fma(1.58969099521155010221e-10, z, -2.50507602534068634195e-08)
        p = fma(p, z, 2.75573137070700676789e-06)
        p = fma(p, z, -1.98412698298579493134e-04)
        p = fma(p, z, 8.33333333332248946124e-03)
        p = fma(p, z, -1.66666666666666324348e-01)
        value = fma(p, z * r, r)
    else:
        # cos(r) = 1 - z/2 + z^2 * (C1 + z*(C2 + ... + z*C6))   (fdlibm k_cos.c)
        var p = fma(-1.13596475577881948265e-11, z, 2.08757232129817482790e-09)
        p = fma(p, z, -2.75573143513906633035e-07)
        p = fma(p, z, 2.48015872894767294178e-05)
        p = fma(p, z, -1.38888888888741095749e-03)
        p = fma(p, z, 4.16666666666666019037e-02)
        value = fma(p, z * z, fma(-0.5, z, 1.0))
    # Quadrants 2 and 3 of the shifted angle are negative.
    return -value if (q & 2) != 0 else value


@always_inline
fn libdevice_sin(x: Float64) -> Float64:
    """Double-precision software sine (FMA-based range reduction + polynomial)."""
    return _sin_quarter_shifted(x, 0)


@always_inline
fn libdevice_cos(x: Float64) -> Float64:
    """Double-precision software cosine, computed as sin(x + pi/2)."""
    return _sin_quarter_shifted(x, 1)
@always_inline
fn libdevice_cosf(x: Float32) -> Float32:
    """
    Single-precision software cosine modelled on CUDA's cosf PTX sequence.

    It uses the same reduction as libdevice_sinf:
    k = round-half-to-even(x * 2/pi) and r = x - k*pi/2, computed with three
    FMAs over a split pi/2 = C1 + C2 + C3. Cosine then selects the
    polynomial by k directly (even k -> +/-cos(r), odd k -> +/-sin(r)).
    The sign comes from k + 1, since cos(x) = sin(x + pi/2).

    All three split constants are positive and C1 < pi/2, so every FMA
    subtracts k*Ci. Using C1 = float32(pi/2), which lies slightly above
    pi/2, while also subtracting the correction terms would mis-reduce r
    by about 2*k*4.4e-8.

    NOTE(review): the split constants are believed to match NVCC's
    (0fBFC90FDA, 0fB3A22168, 0fA7C234C5); confirm against a PTX dump.
    There is no Payne-Hanek path, so very large |x| is not handled.
    """
    var scaled = x * Float32(0.6366197723675814)  # x * 2/pi
    var k = round_to_int(scaled)
    var kf = Float32(k)
    # Cody-Waite reduction: C1 + C2 + C3 == pi/2, each term subtracted.
    var xr = fma(kf, -1.5707962512969971, x)
    xr = fma(kf, -7.5497894158615964e-08, xr)
    xr = fma(kf, -5.3903025299577648e-15, xr)
    # cos uses k directly for polynomial selection (no phase shift).
    var poly_bit = k & 1
    # cos(x) = sin(x + pi/2), so the sign uses k shifted forward by one.
    var sign_bit = ((k + 1) & 2)
    var x2 = xr * xr
    var poly: Float32
    if poly_bit == 0:
        # Cosine polynomial
        var c = fma(Float32(2.44331570e-05), x2, Float32(-1.38873163e-03))
        c = fma(c, x2, Float32(4.16666418e-02))
        c = fma(c, x2, Float32(-0.5))
        poly = fma(c, x2, Float32(1.0))
    else:
        # Sine polynomial
        var c = fma(Float32(-1.95152959e-04), x2, Float32(8.33216030e-03))
        c = fma(c, x2, Float32(-1.66666552e-01))
        poly = fma(c, x2 * xr, xr)
    return -poly if sign_bit != 0 else poly
fn complex_multiply(
    ar: Float32, ai: Float32, br: Float32, bi: Float32
) -> (Float32, Float32):
    """Return (ar + i*ai) * (br + i*bi) as a (real, imaginary) pair."""
    return (ar * br - ai * bi, ar * bi + ai * br)
fn dft_kernel(
    signal: UnsafePointer[Float32],
    spectrum: UnsafePointer[Float32],
    N: Int32,
):
    """Direct O(N^2) DFT in which thread k computes output bin k.

    `signal` and `spectrum` hold N complex values as interleaved (re, im)
    Float32 pairs. Thread k evaluates

        X[k] = sum_{n=0}^{N-1} x[n] * exp(-2*pi*i*k*n / N)

    accumulating in Float64 with Kahan-compensated sums and rounding to
    Float32 only when storing. Threads with k >= N exit without writing.
    """
    var k = Int(global_idx.x)
    var n_points = Int(N)
    if k >= n_points:
        return
    var n_points_f = Float64(n_points)
    var real_sum: Float64 = 0.0
    var imag_sum: Float64 = 0.0
    var real_c: Float64 = 0.0  # Kahan compensation for the real part
    var imag_c: Float64 = 0.0  # Kahan compensation for the imaginary part
    # Invariant: phase == (k * n) mod N.
    # The twiddle exp(-2*pi*i*k*n/N) has period N in k*n, so we track the
    # residue exactly in integers. This keeps the angle in (-2*pi, 0]
    # instead of letting it grow to ~2*pi*N. Otherwise, for large N, the
    # Float64 rounding of the angle itself dominates the twiddle error.
    var phase = 0
    for n in range(n_points):
        var angle = -2.0 * M_PI * Float64(phase) / n_points_f
        var cos_val = libdevice_cos(angle)
        var sin_val = libdevice_sin(angle)
        var x_real = Float64(signal[2 * n])
        var x_imag = Float64(signal[2 * n + 1])
        # x[n] * (cos + i*sin), in Float64
        var temp_real = x_real * cos_val - x_imag * sin_val
        var temp_imag = x_real * sin_val + x_imag * cos_val
        # Kahan summation, real part: re-inject previously lost low bits.
        var real_y = temp_real - real_c
        var real_t = real_sum + real_y
        real_c = (real_t - real_sum) - real_y
        real_sum = real_t
        # Kahan summation, imaginary part.
        var imag_y = temp_imag - imag_c
        var imag_t = imag_sum + imag_y
        imag_c = (imag_t - imag_sum) - imag_y
        imag_sum = imag_t
        # Advance phase by k modulo N. Since k < N, one subtraction suffices.
        phase += k
        if phase >= n_points:
            phase -= n_points
    # Round to Float32 only at the output.
    spectrum[2 * k] = Float32(real_sum)
    spectrum[2 * k + 1] = Float32(imag_sum)
fn bit_reverse_kernel(
    input: UnsafePointer[Float32],
    output: UnsafePointer[Float32],
    N: Int32,
    log2N: Int32,
):
    """Copy complex element i of `input` to slot rev(i) of `output`.

    Both buffers hold N complex values as interleaved (re, im) Float32
    pairs. rev(i) reverses the low `log2N` bits of i; solve() passes
    log2(N) for power-of-two N. One thread handles one element, and
    threads with index >= N do nothing.
    """
    var src = Int(global_idx.x)
    if src < Int(N):
        var dst = 0
        var remaining = src
        var bits_left = Int(log2N)
        while bits_left > 0:
            # Shift the lowest remaining bit of src into the bottom of dst.
            dst = (dst << 1) | (remaining & 1)
            remaining >>= 1
            bits_left -= 1
        output[2 * dst] = input[2 * src]
        output[2 * dst + 1] = input[2 * src + 1]
fn fft_stage_kernel(
    data: UnsafePointer[Float32],
    N: Int32,
    stage: Int32,
):
    """Apply one in-place radix-2 butterfly pass to `data`.

    `data` holds N complex values as interleaved (re, im) Float32 pairs.
    For stage s, thread t handles one butterfly. Its pair (top, bottom)
    lies in group t // 2^s of size 2^(s+1), with bottom = top + 2^s, and
    it uses the twiddle w = exp(-2*pi*i*offset / 2^(s+1)). Arithmetic is
    done in Float64, and results are rounded to Float32 when stored.
    Threads whose indices fall outside the data do nothing.
    """
    var n_points = Int(N)
    var tid = Int(global_idx.x)
    var half = 1 << Int(stage)  # distance between butterfly partners
    var span = half << 1  # elements per butterfly group
    var group_count = n_points // span

    var group = tid // half
    var offset = tid % half
    var top = group * span + offset
    var bottom = top + half
    if (
        tid >= n_points // 2
        or group >= group_count
        or top >= n_points
        or bottom >= n_points
    ):
        return

    # Twiddle factor, evaluated in Float64.
    var theta = -2.0 * M_PI * Float64(offset) / Float64(span)
    var w_re = libdevice_cos(theta)
    var w_im = libdevice_sin(theta)

    var a_re = Float64(data[2 * top])
    var a_im = Float64(data[2 * top + 1])
    var b_re = Float64(data[2 * bottom])
    var b_im = Float64(data[2 * bottom + 1])

    # t = b * w
    var t_re = b_re * w_re - b_im * w_im
    var t_im = b_re * w_im + b_im * w_re

    # (a, b) <- (a + t, a - t), rounded back to Float32.
    data[2 * top] = Float32(a_re + t_re)
    data[2 * top + 1] = Float32(a_im + t_im)
    data[2 * bottom] = Float32(a_re - t_re)
    data[2 * bottom + 1] = Float32(a_im - t_im)
@export
def solve(signal: UnsafePointer[Float32], spectrum: UnsafePointer[Float32], N: Int32):
    """Compute the DFT of N interleaved complex Float32 values on the GPU.

    The algorithm depends on N:
    - Power-of-two N >= 2 uses a bit-reversal kernel followed by log2(N)
      radix-2 butterfly passes, with one kernel launch per stage.
    - Any other positive N, including N == 1, uses the direct O(N^2) DFT
      kernel.
    - N <= 0 is a no-op.

    2*N Float32 values are read from `signal` and 2*N are written to
    `spectrum`, via the DeviceContext copies below.
    """
    if N <= 0:
        return
    var n = Int(N)
    # Buffer length in Float32 elements. Computing it in Int rather than
    # Int32 avoids overflow of 2*N for very large N.
    var count = 2 * n
    alias THREADS_PER_BLOCK = 256

    var ctx = DeviceContext()
    # Create the input device buffer and copy the signal into it.
    var signal_device = ctx.enqueue_create_buffer[DType.float32](count)
    ctx.enqueue_copy(signal_device.unsafe_ptr(), signal, count)
    ctx.synchronize()

    var is_power_of_2 = (n & (n - 1)) == 0
    if is_power_of_2 and n >= 2:
        var log2N: Int32 = 0
        var remaining = n
        while remaining > 1:
            remaining >>= 1
            log2N += 1
        var temp_buffer = ctx.enqueue_create_buffer[DType.float32](count)
        ctx.enqueue_function_unchecked[bit_reverse_kernel](
            signal_device,
            temp_buffer,
            N,
            log2N,
            grid_dim=(ceildiv(n, THREADS_PER_BLOCK),),
            block_dim=(THREADS_PER_BLOCK,),
        )
        ctx.synchronize()
        # One thread per butterfly (N/2 per stage). This is >= 1 block
        # because n >= 2.
        var stage_blocks = ceildiv(n // 2, THREADS_PER_BLOCK)
        for stage in range(Int(log2N)):
            ctx.enqueue_function_unchecked[fft_stage_kernel](
                temp_buffer,
                N,
                Int32(stage),
                grid_dim=(stage_blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )
            ctx.synchronize()
        ctx.enqueue_copy(spectrum, temp_buffer.unsafe_ptr(), count)
        ctx.synchronize()
    else:
        # The separate output buffer is only needed by the direct DFT path.
        var spectrum_device = ctx.enqueue_create_buffer[DType.float32](count)
        ctx.enqueue_function_unchecked[dft_kernel](
            signal_device,
            spectrum_device,
            N,
            grid_dim=(ceildiv(n, THREADS_PER_BLOCK),),
            block_dim=(THREADS_PER_BLOCK,),
        )
        ctx.synchronize()
        ctx.enqueue_copy(spectrum, spectrum_device.unsafe_ptr(), count)
        ctx.synchronize()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment