Skip to content

Instantly share code, notes, and snippets.

@maedoc
Last active October 5, 2021 15:57
Show Gist options
  • Save maedoc/ac1d04aab9bec22cb019f74fe67ebefe to your computer and use it in GitHub Desktop.
Save maedoc/ac1d04aab9bec22cb019f74fe67ebefe to your computer and use it in GitHub Desktop.
Generating explicit DIF FFT kernels with Sympy
import os, ctypes
import numpy as np
from sympy import exp, pi, I, Symbol, re, im, cse
from sympy.printing.c import ccode
# radix-2 FT recursion (splits even/odd samples, i.e. decimation in time, despite the DIF naming)
def ft(x):
    """Build the ``len(x)``-point discrete Fourier transform of ``x`` recursively.

    Radix-2 split on even/odd samples (decimation in time): the transforms
    of ``x[::2]`` and ``x[1::2]`` are combined with the twiddle factors
    ``exp(-2*pi*I*k/N)``.  ``exp``, ``pi`` and ``I`` are the SymPy names
    imported at module level.

    Parameters
    ----------
    x : sequence
        Input samples (numbers or SymPy expressions).  Must support slicing
        and have a power-of-two length.

    Returns
    -------
    numpy.ndarray
        The ``len(x)`` transform coefficients.

    Raises
    ------
    ValueError
        If ``len(x)`` is not a power of two.
    """
    N = len(x)
    # Any other length never reaches a base case and would recurse forever.
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # A one-point transform is the sample itself.
        return np.asarray(x)
    if N == 2:
        a, b = x
        # Return an array (not a tuple) so every length yields the same type.
        return np.array([a + b, a - b])
    E, O = ft(x[::2]), ft(x[1::2])
    # k/N stays exact: the SymPy product on the left is divided by the int N.
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# generate real & imag terms for specific N
import subprocess

N = 16  # transform length
M = 16  # number of transforms handled per kernel call (one per loop iteration)
y = np.array([Symbol(f'y{k}', real=True) for k in range(N)])
Y = ft(y).tolist()
Yri = [re(_) for _ in Y] + [im(_) for _ in Y]

# check it matches numerical fft (in double precision; the float32 kernel
# itself is checked further down with looser tolerances)
y0 = np.random.randn(N)
values = dict(zip(y, y0))
Y0 = np.array([float(_.subs(values).evalf()) for _ in Yri])
Y1 = np.fft.fft(y0)
Y1 = np.r_[Y1.real, Y1.imag]
np.testing.assert_allclose(Y0, Y1, rtol=1e-7, atol=1e-12)

# generate code via cse; bound to `exprs` so sympy's `exp` (needed by ft)
# is not overwritten
aux, exprs = cse(Yri)
with open('dif16.c', 'w') as fd:
    print(f'''#include <math.h>
void dif16(float * __restrict y, float * __restrict Y) {{
#pragma clang loop vectorize(enable)
for (int i=0; i<{M}; i++) {{''', file=fd)
    # sample k of transform i lives at y[M*k + i]
    for i in range(N):
        print(f'float y{i} = y[{M}*{i} + i];', file=fd)
    for l, r in aux:
        print(f'float {l} = {ccode(r)};', file=fd)
    # rows 0..N-1 of Y hold real parts, rows N..2N-1 imaginary parts
    for i, e in enumerate(exprs):
        print(f'Y[{M}*{i} + i] = {ccode(e)};', file=fd)
    print('} }', file=fd)

# compile with clang -O3 -mavx2 -mavx -ffast-math
# if available, -mavx512f as well
flags = ['-O3', '-mavx2', '-mavx', '-ffast-math']
# check=True: stop on a failed build instead of loading a stale/missing .so
subprocess.run(['clang', *flags, '-fPIC', '-c', 'dif16.c'], check=True)
# generates ~300 asm lines w/ zmm similar length to icc
subprocess.run(['clang', '-shared', 'dif16.o', '-o', 'dif16.so', '-lm'],
               check=True)

# test via ctypes against np.fft.fft
lib = ctypes.CDLL('./dif16.so')
f32_t = np.ctypeslib.ndpointer(dtype=np.float32, ndim=1, flags='C_CONTIGUOUS')
lib.dif16.argtypes = [f32_t, f32_t]
lib.dif16.restype = None  # dif16 is declared void
z = np.random.randn(N*M).astype('f')
Z0 = np.zeros(2*N*M, 'f')
lib.dif16(z, Z0)
Z1 = np.fft.fft(z.reshape((N, M)), axis=0)
np.testing.assert_allclose(np.r_[Z1.real.flat[:], Z1.imag.flat[:]], Z0, rtol=1e-4, atol=1e-6)
import numpy as np
from sympy import exp, pi, I, Symbol, cse, re, im
from sympy.printing.c import ccode
def ft(x):
    """Build the ``len(x)``-point discrete Fourier transform of ``x`` recursively.

    Radix-2 split on even/odd samples (decimation in time): the transforms
    of ``x[::2]`` and ``x[1::2]`` are combined with the twiddle factors
    ``exp(-2*pi*I*k/N)``.

    Parameters
    ----------
    x : sequence
        Input samples (numbers or SymPy expressions).  Must support slicing
        and have a power-of-two length.

    Returns
    -------
    numpy.ndarray
        The ``len(x)`` transform coefficients.

    Raises
    ------
    ValueError
        If ``len(x)`` is not a power of two.
    """
    N = len(x)
    # Any other length never reaches a base case and would recurse forever.
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # A one-point transform is the sample itself.
        return np.asarray(x)
    if N == 2:
        a, b = x
        # Return an array (not a tuple) so every length yields the same type.
        return np.array([a + b, a - b])
    # Local import: keeps ft independent of module-level names such as
    # `exp`, which calling scripts may rebind.
    from sympy import exp, pi, I
    E, O = ft(x[::2]), ft(x[1::2])
    # k/N stays exact: the SymPy product on the left is divided by the int N.
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# build radix-N dif fft symbolically
N = 256
y = np.array([Symbol(f'y{k}', real=True) for k in range(N)])
Y = ft(y).tolist()
Yri = [re(_) for _ in Y] + [im(_) for _ in Y] # separate real & im parts
# choose number of transforms to do in kernel
M = 128
# generate code via cse
aux, exp = cse(Yri)
with open('dif.c', 'w') as fd:
print('''#include <math.h>
void dif(float * __restrict y, float * __restrict Y) {
#pragma clang loop vectorize(assume_safety)''', file=fd)
print(f'for (int i=0; i<{M}; i++) {{', file=fd)
for i in range(N):
print(f'float y{i} = y[{M}*{i} + i];', file=fd)
for l, r in aux:
print(f'float {l} = {ccode(r)};', file=fd)
for i, e in enumerate(exp):
print(f'Y[{M}*{i} + i] = {ccode(e)};', file=fd)
print('} }', file=fd)
# compile with clang -O3 -mavx2 -mavx -mavx512f -ffast-math
flags = '-O3 -mavx2 -mavx -ffast-math -fopenmp'
os.system(f'clang {flags} -fPIC -c dif.c')
# generates ~300 asm lines w/ zmm similar length to icc
os.system('clang -shared dif.o -o dif.so -fopenmp -lm')
from ctypes import CDLL
lib = CDLL('./dif.so')
f32_t = np.ctypeslib.ndpointer(dtype=np.float32, ndim=1, flags='C_CONTIGUOUS')
lib.dif.argtypes = [f32_t, f32_t]
z = np.random.randn(N*M).astype('f')
Z0 = np.zeros(N*2*M, 'f')
lib.dif(z, Z0)
Z1 = np.fft.fft(z.reshape((N, M)),axis=0)
%timeit lib.dif(z, Z0)
z2 = z.reshape((N,M))
%timeit np.fft.fft(z,axis=0)
# about 6-7x faster than numpy
import numpy as np
from sympy import exp, pi, I, Symbol, cse, re, im, Indexed
from sympy.printing.c import ccode
import time
import sys
# The transform length is the single command-line argument.
if len(sys.argv) != 2:
    # Exactly one argument is required, as before, but say so explicitly
    # instead of failing with an unpacking error.
    raise SystemExit(f'usage: {sys.argv[0]} N   (N: transform length, a power of two)')
N = int(sys.argv[1])
# Start time, referenced by the optional progress prints.
tic = time.time()
def ft(x):
    """Build the ``len(x)``-point discrete Fourier transform of ``x`` recursively.

    Radix-2 split on even/odd samples (decimation in time): the transforms
    of ``x[::2]`` and ``x[1::2]`` are combined with the twiddle factors
    ``exp(-2*pi*I*k/N)``.  ``exp``, ``pi`` and ``I`` are the SymPy names
    imported at module level.

    Parameters
    ----------
    x : sequence
        Input samples (numbers or SymPy expressions).  Must support slicing
        and have a power-of-two length.

    Returns
    -------
    numpy.ndarray
        The ``len(x)`` transform coefficients.

    Raises
    ------
    ValueError
        If ``len(x)`` is not a power of two.
    """
    N = len(x)
    # Any other length never reaches a base case and would recurse forever.
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # A one-point transform is the sample itself.
        return np.asarray(x)
    if N == 2:
        a, b = x
        # Return an array (not a tuple) so every length yields the same type.
        return np.array([a + b, a - b])
    E, O = ft(x[::2]), ft(x[1::2])
    # k/N stays exact: the SymPy product on the left is divided by the int N.
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# Build the N-point transform symbolically and print it as a Futhark function.
# print(f'building radix-{N} fft t={time.time()-tic}')
# Symbols are named like Futhark array accesses so the printed expressions
# index the function argument `y` directly.
y = np.array([Symbol(f'y[{k}]') for k in range(N)])
# y = np.array([Indexed('y', 'k')[k] for k in range(N)])
Y = ft(y).tolist()
# separate real & im parts (all real parts first, then all imaginary parts)
Yri = [part(term) for part in (re, im) for term in Y]
# for i,yi in enumerate(Y):
# print(i, yi)
# 1/0
aux, exprs = cse(Y)

# Function header plus the re/im projections the printed expressions refer to.
header = (
    'let fft [n] (y:[n](f32,f32)): [n](f32,f32) =',
    ' let re (a,b) = a',
    ' let im (a,b) = b',
)
print('\n'.join(header))
# One binding per common subexpression, as a (real, imaginary) pair.
for name, value in aux:
    print(' let {} = ({}, {})'.format(name, re(value), im(value)))
print(' in [')
last = len(Y) - 1
for idx, term in enumerate(exprs):
    # every element but the last is followed by a comma
    sep = ',' if idx < last else ''
    print(f' ({re(term)}, {im(term)})', sep)
print(' ] :> [n](f32,f32)')
# # generate Futhark code via cse
# print(f'doing cse t={time.time()-tic}')
# aux, exprs = cse(Yri)
# print(f'generating code t={time.time()-tic}')
# with open(f'fft{N}.fut', 'w') as fd:
# print('open f32', file=fd)
# print('''let M_PI = pi
# let M_SQRT2 = sqrt(2)
# let M_SQRT1_2 = 1/sqrt(2)''', file=fd)
# print(f'let fft{N} (y:[{N}]f32): [{2*N}]f32 = ', file=fd)
# for i in range(N):
# print(f' let y{i} = y[{i}]', file=fd)
# for l, r in aux:
# print(f' let {l} = {ccode(r)}', file=fd)
# print(' in', file=fd)
# print(' [ ', ', '.join([ccode(_) for _ in exprs]), ']', file=fd)
# print(f'done t={time.time()-tic}')
# TODO we should use shorter radix and combine them after,
# since we will want to transform 6x2x2 not scalar
-- construct higher radix FFTs via module system
-- Signature shared by the stages of the module-built FFT. An N point
-- transform is carried as a pair (t, t), each t holding half of the points.
module type fft2_t = {
-- module of real numbers supplying the scalar type R.t
module R: real
-- type of half of transform i.e. (R.t,R.t) for N=2
type t
-- length of transform and number of reals involved
val N: i64
val Nr: i64
-- twiddle factors for stage N used by apply at stage 2*N
val W: (t, t)
-- ops on stage N used by apply at stage 2*N
-- (t0 is the value built from zeros in the implementations)
val t0: t
val +: t -> t -> t
val -: t -> t -> t
val *: t -> t -> t
-- apply N point transform
val apply: (t, t) -> (t, t)
-- conversion between t & array of time points
-- (each point is a pair of reals)
val a2t: [N](R.t,R.t) -> (t, t)
val t2a: (t, t) -> [N](R.t,R.t)
-- conversion between t & array of reals
val ar2t: [Nr]R.t -> (t, t)
val t2ar: (t, t) -> [Nr]R.t
}
-- base case with two point transform
-- Two point transform over the real module R: t is one complex value
-- stored as a (real, imaginary) pair.
module mk_fft2 (R: real) = {
module R = R
type t = (R.t, R.t)
-- complex zero
let t0: t = (R.i32 0, R.i32 0)
-- two points, i.e. four reals
let N = 2i64
let Nr = 4i64
-- the pair (1 + 0i, -1 + 0i); (1-2) evaluates to -1
let W = ((R.i64 1,R.i64 0), (R.i64 (1-2),R.i64 0))
-- componentwise complex add and subtract; from here on + and - inside
-- this module refer to these definitions
let (+) (ar,ai) (br,bi) = (ar R.+ br, ai R.+ bi)
let (-) (ar,ai) (br,bi) = (ar R.- br, ai R.- bi)
-- complex multiplication
let (*) (a,b) (c,d) = R.((a * c - b * d, b * c + a * d))
-- the two point transform itself: sum and difference of the two points
let apply ((a, b): (t, t)): (t, t) = (a + b, a - b)
-- flat reals [re0, im0, re1, im1] <-> pair of complex points
let ar2t (ab:[Nr]R.t): (t, t) = ((ab[0], ab[1]),(ab[2], ab[3]))
let t2ar ((a,b),(c,d)): [Nr]R.t = [a,b,c,d] :> [Nr]R.t
-- array of two complex points <-> pair of complex points
let a2t (ab:[N](R.t,R.t)): (t, t) = (ab[0], ab[1])
let t2a (a,b): [N](R.t,R.t) = [a,b] :> [N](R.t,R.t)
}
-- compute 2*N transform given module computing N point transform
-- Doubles the transform length of F: one half of this transform (type t)
-- is a whole F transform, i.e. a pair (F.t, F.t).
module mk_fft2n (F: fft2_t) = {
module R = F.R
type t = (F.t, F.t)
-- one complex point as a pair of reals
type w = (R.t, R.t)
let t0 = (F.t0, F.t0)
let N = 2 * F.N
let Nr = 2 * F.Nr
-- the pair components below are the two F.t parts of a t (not real and
-- imaginary parts); each operator applies F's operator to both parts
let (+) (ar,ai) (br,bi) = (ar F.+ br, ai F.+ bi)
let (-) (ar,ai) (br,bi) = (ar F.- br, ai F.- bi)
-- elementwise, not using complex formula
let (*) (ar,ai) (br,bi) = (ar F.* br, ai F.* bi)
-- conversions split or join at the midpoint and delegate each half to F
let ar2t (x:[Nr]R.t) = (F.ar2t (x[:F.Nr] :> [F.Nr]R.t), F.ar2t (x[F.Nr:] :> [F.Nr]R.t))
let t2ar (a,b): [Nr]R.t = ((F.t2ar a) ++ (F.t2ar b)) :> [Nr]R.t
let a2t (ab:[N](R.t,R.t)) = (F.a2t (ab[:F.N] :> [F.N]w), F.a2t (ab[F.N:] :> [F.N]w))
let t2a (a,b) = ((F.t2a a) ++ (F.t2a b)) :> [N]w
-- cos(2*pi*k/N), -sin(2*pi*k/N)
-- computed for all k < N; apply below only uses the first half (W.0)
let ws = iota N
|> map (\k -> R.((i64 2)*pi*(i64 k)/(i64 N)))
|> map (\x -> (R.cos x, R.neg (R.sin x)))
let W = a2t ws
let apply (a, b) =
-- split a & b into even and odd time points
let ab: [N]w = t2a (a, b)
let e: t = F.a2t (ab[0::2]:>[F.N]w)
let o: t = F.a2t (ab[1::2]:>[F.N]w)
-- do subtransforms on even and odd
let E: t = F.apply e
let O: t = F.apply o
-- combine: first half of the output is E + W*O, second half is E - W*O
let L: t = E + W.0 * O
let R: t = E - W.0 * O
in (L, R)
}
-- compute 2*N transform given module computing N point transform w/ arrays
-- Same doubling step as mk_fft2n, but one half of the transform is held
-- in a two-element array of F.t instead of a pair.
module mk_fft2na (F: fft2_t) = {
module R = F.R
type t = [2]F.t
-- one complex point as a pair of reals
type w = (R.t, R.t)
let t0 = [F.t0, F.t0]
let N = 2 * F.N
let Nr = 2 * F.Nr
-- apply F's operator to each of the two elements
let (+) as bs = [as[0] F.+ bs[0], as[1] F.+ bs[1]]
let (-) as bs = [as[0] F.- bs[0], as[1] F.- bs[1]]
let (*) as bs = [as[0] F.* bs[0], as[1] F.* bs[1]]
-- cos(2*pi*k/N), -sin(2*pi*k/N)
let ws = iota N
|> map (\k -> R.((i64 2)*pi*(i64 k)/(i64 N)))
|> map (\x -> (R.cos x, R.neg (R.sin x)))
-- turn a pair of F.t into the array form t
let split ((a,b):(F.t,F.t)): t = [a,b]
-- only the first N/2 twiddles are needed
let W: t = split(F.a2t (ws[:N/2]:>[F.N]w))
-- input: ab[0] and ab[1] are each converted with F.t2a, so each is a
-- whole F transform; output: L followed by R as four F.t
let apply ab: [4]F.t =
-- split a & b into even and odd time points
let ab: [N]w = ((F.t2a ab[0]) ++ (F.t2a ab[1])) :> [N]w
let e: (F.t,F.t) = F.a2t (ab[0::2]:>[F.N]w)
let o: (F.t,F.t) = F.a2t (ab[1::2]:>[F.N]w)
-- do subtransforms on even and odd
let E: t = split (F.apply e)
let O: t = split (F.apply o)
-- combine: first half of the output is E + W*O, second half is E - W*O
let L: t = E + W * O
let R: t = E - W * O
in ((L ++ R) :> [4]F.t)
}
-- fft2 is the two point base case over f32; every mk_fft2n application
-- doubles the transform length, giving 4 up to 128 points
module fft2 = mk_fft2 f32
module fft4 = mk_fft2n fft2
module fft8 = mk_fft2n fft4
module fft16 = mk_fft2n fft8
module fft32 = mk_fft2n fft16
module fft64 = mk_fft2n fft32
module fft128 = mk_fft2n fft64
-- mkfft2n fft128 causes OpenCL drivers to crash
-- so we use an array version instead
module fft256 = mk_fft2na fft128
-- ==
-- input { 128i64 }
-- input { 768i64 }
-- Runs the 256 point transform on M synthetic rows.
entry main (M:i64) =
  -- deterministic input: element j of row i is i + j
  let fill i j = fft256.R.((i64 i) + (i64 j))
  let x = tabulate_2d M fft256.Nr fill :> [M][fft256.Nr]fft256.R.t
  -- cut a row of reals into the two fft128-sized halves fft256.apply takes
  let halves (xi: [fft256.Nr]fft256.R.t) =
    [ fft128.ar2t (xi[:fft128.Nr] :> [fft128.Nr]fft128.R.t)
    , fft128.ar2t (xi[fft128.Nr:] :> [fft128.Nr]fft128.R.t) ]
  let y = map halves x
  in map fft256.apply y
-- Runs the 128 point transform on M synthetic rows.
-- The "entry:" line is required: without it the test block below is
-- attributed to the default entry point main, not to main128.
-- ==
-- entry: main128
-- input { 128i64 }
-- input { 768i64 }
entry main128 (M:i64) =
  -- deterministic input: element j of row i is i + j
  let x = tabulate_2d M fft128.Nr (\i j -> (fft128.R.i64 i) fft128.R.+ (fft128.R.i64 j)) :> [M][fft128.Nr]fft128.R.t
  -- convert each row of reals to the nested pair form and transform it
  let y = map fft128.ar2t x
  in map fft128.apply y
-- about half the speed of the genfut.py version, but it's doing a full complex fft,
-- so the input (128x256) is twice as large: ~2MB vs ~1MB.
-- it would be worth doing an rfft variant with real input data.
-- worse, though, is that it appears too large to function on NVIDIA OpenCL (3MB .c file!)
-- plain radix2 w/ f32
-- complex number as a (real, imaginary) pair of f32
type c32 = (f32, f32)
-- scalar complex sum, difference and product
let cadd (x:c32) (y:c32): c32 = (x.0 + y.0, x.1 + y.1)
let csub (x:c32) (y:c32): c32 = (x.0 - y.0, x.1 - y.1)
let cmul (x:c32) (y:c32): c32 =
  let (a, b) = x
  let (c, d) = y
  in (a * c - b * d, b * c + a * d)
-- the same operations applied elementwise to two arrays
let Cadd xs ys = map2 cadd xs ys
let Csub xs ys = map2 csub xs ys
let Cmul xs ys = map2 cmul xs ys
-- number of times n can be halved (integer division) before reaching 1;
-- 0 when n <= 1
let log2 (n:i64) =
  let (steps, _) =
    loop (steps, rest) = (0i64, n) while rest > 1 do (steps + 1, rest / 2)
  in steps
-- can't do recursion, don't want bit reversal by hand
-- modules?
-- Signature of one stage of the array-based radix-2 FFT.
module type fftl2n = {
-- transform length this stage is built for
val N: i64
-- the transform, from an array of complex values to one of the same size
val fft [n]: [n]c32 -> [n]c32
}
-- base stage: two point transform (sum and difference of the two inputs)
module fftl2 = {
  let N = 2i64
  let fft [n] (x:[n]c32): [n]c32 =
    let (a, b) = (x[0], x[1])
    in [cadd a b, csub a b] :> [n]c32
}
-- Doubles the transform length of stage F by combining the transforms of
-- the even- and odd-indexed inputs.
module mk_fftl2n (F: fftl2n) = {
  let N = 2 * F.N
  -- twiddle factors exp(-2*pi*i*k/N) stored as (cos, -sin); N is the full
  -- length of this stage, so the angle needs the factor 2 (it was pi*k/N)
  let W = iota N
          |> map (\k -> 2*f32.pi*(f32.i64 k)/(f32.i64 N))
          |> map (\x -> (f32.cos x, -(f32.sin x)))
  let fft [n] (x:[n]c32): [n]c32 =
    let n2 = n / 2
    -- subtransforms of the even and odd samples
    let E = F.fft x[0::2] :> [n2]c32
    let O = F.fft x[1::2] :> [n2]c32
    -- twiddled odd part, shared by both output halves
    let WO = Cmul W[:n2] O
    let L = Cadd E WO
    let R = Csub E WO
    in (L ++ R) :> [n]c32
}
-- six doublings of the two point base stage: N = 2 * 2^6 = 128
module fft128 = mk_fftl2n (mk_fftl2n (mk_fftl2n (mk_fftl2n (mk_fftl2n (mk_fftl2n fftl2)))))
-- since we want 256 real valued transform,
-- we can reuse a 128 complex transform
-- (with some extra missing but cheap bits )
-- it's also more obvious how to iFFT
-- ==
-- input { 128i64 }
-- input { 768i64 }
-- Runs the 128 point complex transform on m synthetic rows.
entry main (m:i64): [m][128](f32,f32) =
  -- deterministic input: element j of row i is the complex value (i, j)
  let sample i j = (f32.i64 i, f32.i64 j)
  let x:[m][128](f32,f32) = tabulate_2d m 128i64 sample
  in map fft128.fft x
from futhark_ffi import Futhark
import _fft128x256
fft = Futhark(_fft128x256)
import numpy as np

# Random single-precision batch: 128 rows of 256 samples.
x = np.random.randn(128, 256).astype('f')
y1 = fft.from_futhark(fft.ondata(x))
# Reference: numpy FFT along each row.
y2 = np.fft.fft(x, axis=1)
# The check treats columns 0..255 of the Futhark result as real parts and
# the remaining columns as imaginary parts.
y1r = y1[:, :256]
y1i = y1[:, 256:]
y2r = y2.real
y2i = y2.imag
for got, want in ((y1r, y2r), (y1i, y2i)):
    np.testing.assert_allclose(got, want, rtol=1e-3, atol=1e-5)
import numpy as np
from futhark_ffi import Futhark
import os
import subprocess

# Build the Futhark library and its Python FFI wrapper. check=True aborts on
# a failed build instead of importing a stale or missing _modrecur.
subprocess.run(['futhark', 'c', '--library', 'modrecur.fut'], check=True)
subprocess.run(['build_futhark_ffi', 'modrecur'], check=True)
import _modrecur
fft = Futhark(_modrecur)
N = 128  # transform length of the fft128 entry point
z = np.random.randn(N).astype(np.complex64)
Z = np.fft.fft(z)
# pass the input interleaved as re0, im0, re1, im1, ...
z_ = np.c_[z.real, z.imag].flat[:]
Z2_ = fft.from_futhark(fft.fft128(z_))
# de-interleave the result the same way
Z2 = np.zeros_like(Z)
Z2.real[:] = Z2_[0::2]
Z2.imag[:] = Z2_[1::2]
# The input is single precision, so the default rtol=1e-7 is too tight.
# NOTE(review): tolerances are a float32-level guess; tighten after measuring.
np.testing.assert_allclose(Z, Z2, rtol=1e-3, atol=1e-4)
@maedoc
Copy link
Author

maedoc commented Sep 28, 2021

It would be nice to have a Stockham form as well, to expose more stride-1-compatible parallelism.

@maedoc
Copy link
Author

maedoc commented Sep 28, 2021

At N=256 it takes 109 s to run ft(y), and clang's autovectorizer complains (the source file is 190k!). The same happens at N=128; N=64 is OK. With vectorize(assume_safety) the large N are OK.

@maedoc
Copy link
Author

maedoc commented Sep 28, 2021

A large N in a single function doesn't work with OpenMP (in clang at least).

@maedoc
Copy link
Author

maedoc commented Sep 28, 2021

can generate Futhark kernels with

# Print a Futhark kernel for the real/imaginary expressions in Yri.
# Bound to `exprs` rather than `exp`, so sympy's `exp` is not overwritten.
aux, exprs = cse(Yri)
print('open f32')
# C math constant names that ccode() may emit, defined for Futhark
print('''let M_PI = pi
let M_SQRT2 = sqrt(2)
let M_SQRT1_2 = 1/sqrt(2)''')
# input is a tuple of N reals, output a tuple of 2*N reals
print('type vec = (', ','.join(['f32'] * N), ')')
print('type vec2 = (', ','.join(['f32'] * (N*2)), ')')
print('let main (y:vec): vec2 = ')
print('  let (', ','.join([f'y{i}' for i in range(N)]), ') = y')
# one binding per common subexpression
for l, r in aux:
    print(f'  let {l} = {ccode(r)}')
print('  in')
print('  (', ','.join([ccode(_) for _ in exprs]), ')')

@maedoc
Copy link
Author

maedoc commented Oct 2, 2021

The current code is a decimation in time, not frequency.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment