Last active
October 5, 2021 15:57
-
-
Save maedoc/ac1d04aab9bec22cb019f74fe67ebefe to your computer and use it in GitHub Desktop.
Generating explicit DIF FFT kernels with Sympy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os, ctypes | |
import numpy as np | |
from sympy import exp, pi, I, Symbol, re, im, cse | |
from sympy.printing.c import ccode | |
# Radix-2 FFT recursion (decimation in time: the input is split into
# even- and odd-indexed samples)
def ft(x):
    """Return the discrete Fourier transform of ``x`` by radix-2 recursion.

    The input is split into its even- and odd-indexed samples, each half is
    transformed recursively, and the halves are combined with the twiddle
    factors ``exp(-2*pi*I*k/N)``.  Works on numeric values as well as on
    sympy expressions, since only ``+``, ``-`` and ``*`` are applied.

    Parameters:
        x: indexable sequence (e.g. a 1-D numpy array) whose length is a
            power of two.

    Returns:
        A 1-D numpy array of ``len(x)`` transform coefficients.

    Raises:
        ValueError: if ``len(x)`` is zero or not a power of two (these
            lengths previously recursed without terminating).
    """
    N = len(x)
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # a single sample is its own transform
        return np.array([x[0]])
    if N == 2:
        a, b = x
        # an array rather than a tuple, so .tolist() works at every length
        return np.array([a + b, a - b])
    E, O = ft(x[::2]), ft(x[1::2])
    # k/N is the right exponent here: the odd half is rotated by
    # exp(-2*pi*I*k/N) for k = 0 .. N/2-1 before the butterfly
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# generate real & imag terms for specific N
N = 16  # transform length; the kernel and file names below assume 16
M = 16  # number of independent transforms done per kernel call
y = np.array([Symbol(f'y{k}', real=True) for k in range(N)])
Y = ft(y).tolist()
Yri = [re(_) for _ in Y] + [im(_) for _ in Y]

# check it matches numerical fft
y0 = np.random.randn(N)
point = dict(zip(y, y0))
# compare in double precision with an absolute floor: a float32 cast checked
# at the default rtol=1e-7 leaves almost no margin above float32 rounding
Y0 = np.array([_.subs(point).evalf() for _ in Yri]).astype('d')
Y1 = np.fft.fft(y0)
Y1 = np.r_[Y1.real, Y1.imag]
np.testing.assert_allclose(Y0, Y1, rtol=1e-9, atol=1e-12)

# generate code via cse
# (named `exprs` so that sympy's `exp`, which ft() needs, is not rebound)
aux, exprs = cse(Yri)
with open('dif16.c', 'w') as fd:
    print('''#include <math.h>
void dif16(float * __restrict y, float * __restrict Y) {
#pragma clang loop vectorize(enable)''', file=fd)
    print(f'for (int i=0; i<{M}; i++) {{', file=fd)
    # sample k of transform i lives at y[M*k + i]; outputs use the same
    # stride, with the N real parts followed by the N imaginary parts
    for i in range(N):
        print(f'float y{i} = y[{M}*{i} + i];', file=fd)
    for l, r in aux:
        print(f'float {l} = {ccode(r)};', file=fd)
    for i, e in enumerate(exprs):
        print(f'Y[{M}*{i} + i] = {ccode(e)};', file=fd)
    print('} }', file=fd)


def _run(cmd):
    """Run a shell command and raise RuntimeError on a non-zero exit status."""
    if os.system(cmd) != 0:
        raise RuntimeError(f'command failed: {cmd}')


# compile with clang -O3 -mavx2 -mavx -ffast-math
# if available, -mavx512f as well
flags = '-O3 -mavx2 -mavx -ffast-math'
_run(f'clang {flags} -fPIC -c dif16.c')
# generates ~300 asm lines w/ zmm similar length to icc
_run('clang -shared dif16.o -o dif16.so -lm')

# test via ctypes against np.fft.fft
lib = ctypes.CDLL('./dif16.so')
f32_t = np.ctypeslib.ndpointer(dtype=np.float32, ndim=1, flags='C_CONTIGUOUS')
lib.dif16.argtypes = [f32_t, f32_t]
lib.dif16.restype = None  # the generated function is declared void
z = np.random.randn(N*M).astype('f')
Z0 = np.zeros(N*2*M, 'f')
lib.dif16(z, Z0)
Z1 = np.fft.fft(z.reshape((N, M)), axis=0)
np.testing.assert_allclose(np.r_[Z1.real.flat[:], Z1.imag.flat[:]], Z0, rtol=1e-4, atol=1e-6)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sympy import exp, pi, I, Symbol, cse, re, im | |
from sympy.printing.c import ccode | |
def ft(x):
    """Return the discrete Fourier transform of ``x`` by radix-2 recursion.

    The input is split into its even- and odd-indexed samples (decimation in
    time), each half is transformed recursively, and the halves are combined
    with the twiddle factors ``exp(-2*pi*I*k/N)``.

    Parameters:
        x: indexable sequence (e.g. a 1-D numpy array) whose length is a
            power of two; elements may be numbers or sympy expressions.

    Returns:
        A 1-D numpy array of ``len(x)`` transform coefficients.

    Raises:
        ValueError: if ``len(x)`` is zero or not a power of two (these
            lengths previously recursed without terminating).
    """
    N = len(x)
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # a single sample is its own transform
        return np.array([x[0]])
    if N == 2:
        a, b = x
        # an array rather than a tuple, so .tolist() works at every length
        return np.array([a + b, a - b])
    # local import: keeps ft independent of whatever the module-level name
    # `exp` happens to be bound to at call time
    from sympy import exp, pi, I
    E, O = ft(x[::2]), ft(x[1::2])
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# build radix-N dif fft symbolically | |
N = 256 | |
y = np.array([Symbol(f'y{k}', real=True) for k in range(N)]) | |
Y = ft(y).tolist() | |
Yri = [re(_) for _ in Y] + [im(_) for _ in Y] # separate real & im parts | |
# choose number of transforms to do in kernel | |
M = 128 | |
# generate code via cse | |
aux, exp = cse(Yri) | |
with open('dif.c', 'w') as fd: | |
print('''#include <math.h> | |
void dif(float * __restrict y, float * __restrict Y) { | |
#pragma clang loop vectorize(assume_safety)''', file=fd) | |
print(f'for (int i=0; i<{M}; i++) {{', file=fd) | |
for i in range(N): | |
print(f'float y{i} = y[{M}*{i} + i];', file=fd) | |
for l, r in aux: | |
print(f'float {l} = {ccode(r)};', file=fd) | |
for i, e in enumerate(exp): | |
print(f'Y[{M}*{i} + i] = {ccode(e)};', file=fd) | |
print('} }', file=fd) | |
# compile with clang -O3 -mavx2 -mavx -mavx512f -ffast-math | |
flags = '-O3 -mavx2 -mavx -ffast-math -fopenmp' | |
os.system(f'clang {flags} -fPIC -c dif.c') | |
# generates ~300 asm lines w/ zmm similar length to icc | |
os.system('clang -shared dif.o -o dif.so -fopenmp -lm') | |
from ctypes import CDLL | |
lib = CDLL('./dif.so') | |
f32_t = np.ctypeslib.ndpointer(dtype=np.float32, ndim=1, flags='C_CONTIGUOUS') | |
lib.dif.argtypes = [f32_t, f32_t] | |
z = np.random.randn(N*M).astype('f') | |
Z0 = np.zeros(N*2*M, 'f') | |
lib.dif(z, Z0) | |
Z1 = np.fft.fft(z.reshape((N, M)),axis=0) | |
%timeit lib.dif(z, Z0) | |
z2 = z.reshape((N,M)) | |
%timeit np.fft.fft(z,axis=0) | |
# about 6-7x faster than numpy |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from sympy import exp, pi, I, Symbol, cse, re, im, Indexed | |
from sympy.printing.c import ccode | |
import time | |
import sys | |
# the transform length is the one and only command-line argument;
# any other argument count fails the unpacking with ValueError
[N] = [int(arg) for arg in sys.argv[1:]]
tic = time.time()
def ft(x):
    """Return the discrete Fourier transform of ``x`` by radix-2 recursion.

    The input is split into its even- and odd-indexed samples (decimation in
    time), each half is transformed recursively, and the halves are combined
    with the twiddle factors ``exp(-2*pi*I*k/N)``.

    Parameters:
        x: indexable sequence (e.g. a 1-D numpy array) whose length is a
            power of two; elements may be numbers or sympy expressions.

    Returns:
        A 1-D numpy array of ``len(x)`` transform coefficients.

    Raises:
        ValueError: if ``len(x)`` is zero or not a power of two (these
            lengths previously recursed without terminating).
    """
    N = len(x)
    if N < 1 or N & (N - 1):
        raise ValueError(f'ft needs a power-of-two length, got {N}')
    if N == 1:
        # a single sample is its own transform
        return np.array([x[0]])
    if N == 2:
        a, b = x
        # an array rather than a tuple, so .tolist() works at every length
        return np.array([a + b, a - b])
    E, O = ft(x[::2]), ft(x[1::2])
    Ws = np.array([exp(-2*pi*I*k/N) for k in range(N//2)])
    return np.r_[E + Ws*O, E - Ws*O]
# Build the length-N transform symbolically.  The symbols are named 'y[k]'
# so that they print as index expressions in the emitted Futhark code.
y = np.array([Symbol(f'y[{idx}]') for idx in range(N)])
Y = ft(y).tolist()
# real parts followed by imaginary parts; the printing code in this section
# does not read this list
Yri = [re(term) for term in Y] + [im(term) for term in Y]

# Common subexpressions become `let` bindings.  Every value is written out
# as a (real, imaginary) pair, and the emitted helpers `re`/`im` project
# such a pair back to its components.
aux, exprs = cse(Y)
final = len(Y) - 1
print('let fft [n] (y:[n](f32,f32)): [n](f32,f32) =')
print(' let re (a,b) = a')
print(' let im (a,b) = b')
for name, value in aux:
    print(f' let {name} = ({re(value)}, {im(value)})')
print(' in [')
for pos, out in enumerate(exprs):
    # every array element except the last is followed by a comma
    tail = ',' if pos < final else ''
    print(f' ({re(out)}, {im(out)})', tail)
print(' ] :> [n](f32,f32)')
# # generate Futhark code via cse | |
# print(f'doing cse t={time.time()-tic}') | |
# aux, exprs = cse(Yri) | |
# print(f'generating code t={time.time()-tic}') | |
# with open(f'fft{N}.fut', 'w') as fd: | |
# print('open f32', file=fd) | |
# print('''let M_PI = pi | |
# let M_SQRT2 = sqrt(2) | |
# let M_SQRT1_2 = 1/sqrt(2)''', file=fd) | |
# print(f'let fft{N} (y:[{N}]f32): [{2*N}]f32 = ', file=fd) | |
# for i in range(N): | |
# print(f' let y{i} = y[{i}]', file=fd) | |
# for l, r in aux: | |
# print(f' let {l} = {ccode(r)}', file=fd) | |
# print(' in', file=fd) | |
# print(' [ ', ', '.join([ccode(_) for _ in exprs]), ']', file=fd) | |
# print(f'done t={time.time()-tic}') | |
# TODO we should use shorter radix and combine them after, | |
# since we will want to transform 6x2x2 not scalar |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- construct higher radix FFTs via module system | |
module type fft2_t = {
  -- real number module supplying the scalar arithmetic
  module R: real

  -- type of half of the transform data, i.e. (R.t,R.t) for N=2
  type t

  -- length of transform and number of reals involved
  val N: i64
  val Nr: i64

  -- twiddle factors for stage N used by apply at stage 2*N
  val W: (t, t)

  -- ops on stage N used by apply at stage 2*N
  val t0: t
  val +: t -> t -> t
  val -: t -> t -> t
  val *: t -> t -> t

  -- apply N point transform
  val apply: (t, t) -> (t, t)

  -- conversion between t & array of time points
  val a2t: [N](R.t,R.t) -> (t, t)
  val t2a: (t, t) -> [N](R.t,R.t)

  -- conversion between t & array of reals
  val ar2t: [Nr]R.t -> (t, t)
  val t2ar: (t, t) -> [Nr]R.t
}
-- base case: the two point transform over a real number module R
module mk_fft2 (R: real) = {
  module R = R

  -- one complex number, stored as (real, imaginary)
  type t = (R.t, R.t)

  let t0: t = (R.i32 0, R.i32 0)
  let N = 2i64
  let Nr = 4i64

  -- the pair 1 and -1; defined before (-) is rebound below, so that
  -- 1-2 is still plain integer subtraction
  let W = ((R.i64 1, R.i64 0), (R.i64 (1-2), R.i64 0))

  -- componentwise sum and difference
  let (+) (xr,xi) (yr,yi) = (xr R.+ yr, xi R.+ yi)
  let (-) (xr,xi) (yr,yi) = (xr R.- yr, xi R.- yi)

  -- complex multiplication
  let (*) (xr,xi) (yr,yi) = R.((xr * yr - xi * yi, xi * yr + xr * yi))

  -- butterfly: sum and difference of the two points
  let apply ((p, q): (t, t)): (t, t) = (p + q, p - q)

  -- conversions to and from flat arrays of reals / of complex pairs
  let ar2t (v:[Nr]R.t): (t, t) = ((v[0], v[1]), (v[2], v[3]))
  let t2ar ((a,b),(c,d)): [Nr]R.t = [a,b,c,d] :> [Nr]R.t
  let a2t (v:[N](R.t,R.t)): (t, t) = (v[0], v[1])
  let t2a (p,q): [N](R.t,R.t) = [p,q] :> [N](R.t,R.t)
}
-- compute 2*N transform given module computing N point transform
module mk_fft2n (F: fft2_t) = {
  module R = F.R
  type t = (F.t, F.t)
  type w = (R.t, R.t)
  let t0 = (F.t0, F.t0)

  -- the sizes must be computed before (*) is rebound below
  let N = 2 * F.N
  let Nr = 2 * F.Nr

  -- arithmetic is lifted componentwise onto pairs of F.t
  let (+) (a0,a1) (b0,b1) = (a0 F.+ b0, a1 F.+ b1)
  let (-) (a0,a1) (b0,b1) = (a0 F.- b0, a1 F.- b1)
  -- elementwise, not using complex formula
  let (*) (a0,a1) (b0,b1) = (a0 F.* b0, a1 F.* b1)

  -- conversions: each half of the array is handled by F
  let ar2t (x:[Nr]R.t) =
    ( F.ar2t (x[:F.Nr] :> [F.Nr]R.t)
    , F.ar2t (x[F.Nr:] :> [F.Nr]R.t) )
  let t2ar (a,b): [Nr]R.t = ((F.t2ar a) ++ (F.t2ar b)) :> [Nr]R.t
  let a2t (ab:[N](R.t,R.t)) =
    ( F.a2t (ab[:F.N] :> [F.N]w)
    , F.a2t (ab[F.N:] :> [F.N]w) )
  let t2a (a,b) = ((F.t2a a) ++ (F.t2a b)) :> [N]w

  -- cos(2*pi*k/N), -sin(2*pi*k/N)
  let ws =
    map (\k -> let x = R.((i64 2)*pi*(i64 k)/(i64 N))
               in (R.cos x, R.neg (R.sin x)))
        (iota N)
  let W = a2t ws

  let apply (a, b) =
    -- split a & b into even and odd time points
    let pts: [N]w = t2a (a, b)
    let ev: t = F.a2t (pts[0::2] :> [F.N]w)
    let od: t = F.a2t (pts[1::2] :> [F.N]w)
    -- do subtransforms on even and odd
    let E: t = F.apply ev
    let O: t = F.apply od
    -- butterfly, using the first half of the twiddles
    let tw: t = W.0 * O
    let lo: t = E + tw
    let hi: t = E - tw
    in (lo, hi)
}
-- compute 2*N transform given module computing N point transform w/ arrays
module mk_fft2na (F: fft2_t) = {
  module R = F.R
  -- the two halves are held in an array of length 2 rather than a tuple
  type t = [2]F.t
  type w = (R.t, R.t)
  let t0 = [F.t0, F.t0]

  -- the sizes must be computed before (*) is rebound below
  let N = 2 * F.N
  let Nr = 2 * F.Nr

  -- arithmetic applied to each of the two halves
  let (+) u v = [u[0] F.+ v[0], u[1] F.+ v[1]]
  let (-) u v = [u[0] F.- v[0], u[1] F.- v[1]]
  let (*) u v = [u[0] F.* v[0], u[1] F.* v[1]]

  -- cos(2*pi*k/N), -sin(2*pi*k/N)
  let ws =
    map (\k -> let x = R.((i64 2)*pi*(i64 k)/(i64 N))
               in (R.cos x, R.neg (R.sin x)))
        (iota N)

  -- turn a pair of halves into the array representation
  let split ((a,b):(F.t,F.t)): t = [a,b]

  -- only the first N/2 twiddles are used
  let W: t = split (F.a2t (ws[:N/2] :> [F.N]w))

  let apply ab: [4]F.t =
    -- flatten both inputs into time points, then take even and odd ones
    let pts: [N]w = ((F.t2a ab[0]) ++ (F.t2a ab[1])) :> [N]w
    let ev: (F.t,F.t) = F.a2t (pts[0::2] :> [F.N]w)
    let od: (F.t,F.t) = F.a2t (pts[1::2] :> [F.N]w)
    -- do subtransforms on even and odd
    let E: t = split (F.apply ev)
    let O: t = split (F.apply od)
    -- butterfly
    let tw: t = W * O
    let lo: t = E + tw
    let hi: t = E - tw
    in ((lo ++ hi) :> [4]F.t)
}
-- build the tower of transforms; every application of mk_fft2n
-- doubles the length of the transform below it
module fft2   = mk_fft2 f32
module fft4   = mk_fft2n fft2
module fft8   = mk_fft2n fft4
module fft16  = mk_fft2n fft8
module fft32  = mk_fft2n fft16
module fft64  = mk_fft2n fft32
module fft128 = mk_fft2n fft64

-- mk_fft2n fft128 causes OpenCL drivers to crash,
-- so the last doubling uses the array version instead
module fft256 = mk_fft2na fft128
-- run the 256 point transform on M generated rows
-- ==
-- input { 128i64 }
-- input { 768i64 }
entry main (M:i64) =
  -- row i, column j holds the value i + j
  let x =
    tabulate_2d M fft256.Nr (\i j -> (fft256.R.i64 i) fft256.R.+ (fft256.R.i64 j))
    :> [M][fft256.Nr]fft256.R.t
  -- each row is cut at index 256 and each part packed by fft128.ar2t
  let y =
    map (\row ->
           let lo = fft128.ar2t (row[:256] :> [fft128.Nr]fft128.R.t)
           let hi = fft128.ar2t (row[256:] :> [fft128.Nr]fft128.R.t)
           in [lo, hi])
        x
  in map fft256.apply y
-- run the 128 point transform on M generated rows.
-- The `entry:` line is needed: a test block without one applies to `main`,
-- so these inputs would never reach main128.
-- ==
-- entry: main128
-- input { 128i64 }
-- input { 768i64 }
entry main128 (M:i64) =
  -- row i, column j holds the value i + j
  let x =
    tabulate_2d M fft128.Nr (\i j -> (fft128.R.i64 i) fft128.R.+ (fft128.R.i64 j))
    :> [M][fft128.Nr]fft128.R.t
  -- pack each row of reals into the nested tuple form apply expects
  let y = map fft128.ar2t x
  in map fft128.apply y
-- about half the speed of the genfut.py version, but it's doing a full complex fft,
-- so the input (128x256) is twice as large: ~2MB vs ~1MB.
-- it would be worth doing an rfft variant with real input data.
-- worse, though, is that it appears too large to function on NVIDIA OpenCL (3MB .c file!)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- plain radix2 w/ f32 | |
-- a complex number stored as (real, imaginary)
type c32 = (f32, f32)

-- scalar complex arithmetic
let cadd ((xr,xi):c32) ((yr,yi):c32): c32 = (xr + yr, xi + yi)
let csub ((xr,xi):c32) ((yr,yi):c32): c32 = (xr - yr, xi - yi)
let cmul ((xr,xi):c32) ((yr,yi):c32): c32 = (xr * yr - xi * yi, xi * yr + xr * yi)

-- the same operations applied elementwise to two arrays
let Cadd u v = map2 cadd u v
let Csub u v = map2 csub u v
let Cmul u v = map2 cmul u v
-- count how many times n can be halved (integer division) while it exceeds 1
let log2 (n:i64) =
  let (steps, _) =
    loop (steps, rest) = (0i64, n) while 1 < rest do (steps + 1, rest / 2)
  in steps
-- can't do recursion, don't want bit reversal by hand | |
-- modules? | |
module type fftl2n = {
  -- length of the transform the module was built for
  val N: i64

  -- transform of an array of complex values, same length in and out
  val fft [n]: [n]c32 -> [n]c32
}
-- base case: two point transform (sum and difference of the two inputs)
module fftl2 = {
  let N = 2i64

  let fft [n] (x:[n]c32): [n]c32 =
    let (a, b) = (x[0], x[1])
    in [cadd a b, csub a b] :> [n]c32
}
-- build the 2*F.N point transform from a module F computing F.N points
module mk_fftl2n (F: fftl2n) = {
  -- length of the transform computed by this module
  let N = 2 * F.N

  -- twiddle factors exp(-2*pi*i*k/N) = (cos(2*pi*k/N), -sin(2*pi*k/N)).
  -- The factor 2 was missing, which made every stage rotate the odd half
  -- by half the angle it should.
  let W = iota N
          |> map (\k -> 2f32 * f32.pi * (f32.i64 k) / (f32.i64 N))
          |> map (\x -> (f32.cos x, -(f32.sin x)))

  let fft [n] (x:[n]c32): [n]c32 =
    let n2 = n / 2
    -- subtransforms of the even and odd indexed inputs
    let E = F.fft x[0::2] :> [n2]c32
    let O = F.fft x[1::2] :> [n2]c32
    -- rotate the odd half once, then form both butterfly outputs
    let WO = Cmul W[:n2] O
    let L = Cadd E WO
    let R = Csub E WO
    in (L ++ R) :> [n]c32
}
-- six doublings of the 2 point base case give the 128 point transform
module fftl4   = mk_fftl2n fftl2
module fftl8   = mk_fftl2n fftl4
module fftl16  = mk_fftl2n fftl8
module fftl32  = mk_fftl2n fftl16
module fftl64  = mk_fftl2n fftl32
module fft128  = mk_fftl2n fftl64
-- since we want 256 real valued transform, | |
-- we can reuse a 128 complex transform | |
-- (with some extra missing but cheap bits ) | |
-- it's also more obvious how to iFFT | |
-- run the 128 point transform on m generated rows
-- ==
-- input { 128i64 }
-- input { 768i64 }
entry main (m:i64): [m][128](f32,f32) =
  -- element (row, col) is the complex value row + col*i
  let batch: [m][128](f32,f32) =
    tabulate_2d m 128i64 (\row col -> (f32.i64 row, f32.i64 col))
  in map fft128.fft batch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from futhark_ffi import Futhark

import _fft128x256

# Compare the compiled Futhark transform with numpy on random float32 rows.
fft = Futhark(_fft128x256)
x = np.random.randn(128, 256).astype('f')

# the result is read as 256 columns of real parts followed by the
# columns of imaginary parts
y1 = fft.from_futhark(fft.ondata(x))
y1r = y1[:, :256]
y1i = y1[:, 256:]

# numpy reference: one transform per row
y2 = np.fft.fft(x, axis=1)
y2r, y2i = y2.real, y2.imag

for got, want in ((y1r, y2r), (y1i, y2i)):
    np.testing.assert_allclose(got, want, rtol=1e-3, atol=1e-5)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
from futhark_ffi import Futhark

# Build the Futhark library and its Python bindings.  Stop on failure, so
# that a stale or missing _modrecur module is never tested by accident.
for cmd in ('futhark c --library modrecur.fut', 'build_futhark_ffi modrecur'):
    if os.system(cmd) != 0:
        raise RuntimeError(f'command failed: {cmd}')

import _modrecur  # produced by the build step above

fft = Futhark(_modrecur)
N = 128
z = np.random.randn(N).astype(np.complex64)
Z = np.fft.fft(z)
# flatten to re0, im0, re1, im1, ...
z_ = np.c_[z.real, z.imag].flat[:]
# NOTE(review): this needs an entry point named fft128 in modrecur.fut -
# confirm that the Futhark source exposes one
Z2_ = fft.from_futhark(fft.fft128(z_))
Z2 = np.zeros_like(Z)
# the result is read back in the same interleaved layout
Z2.real[:] = Z2_[0::2]
Z2.imag[:] = Z2_[1::2]
# explicit tolerances: the default rtol=1e-7 with no absolute floor is
# tighter than a single-precision transform can meet
np.testing.assert_allclose(Z, Z2, rtol=1e-3, atol=1e-5)
At N=256, it takes 109 s to run ft(y),
and clang's autovectorizer complains (the source file is 190k!). It also complains at N=128; N=64 is OK. With vectorize(assume_safety)
the large N are OK.
Large N in a single function doesn't work with OpenMP (in clang at least).
Futhark kernels can be generated with:
# named `exprs` (as in the Futhark generator script) so that sympy's `exp`
# is not rebound and ft() can still be called afterwards
aux, exprs = cse(Yri)
print('open f32')
# names that appear in the printed C-style expressions, defined for Futhark
print('''let M_PI = pi
let M_SQRT2 = sqrt(2)
let M_SQRT1_2 = 1/sqrt(2)''')
# input and output are tuples of N and 2*N f32 values
print('type vec = (', ','.join(['f32'] * N), ')')
print('type vec2 = (', ','.join(['f32'] * (N*2)), ')')
print('let main (y:vec): vec2 = ')
print(' let (', ','.join([f'y{i}' for i in range(N)]), ') = y')
for l, r in aux:
    print(f' let {l} = {ccode(r)}')
print(' in')
print(' (', ','.join([ccode(_) for _ in exprs]), ')')
The current code is a decimation in time, not frequency.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
It would be nice to have a Stockham form as well, to expose more stride-1-compatible parallelism.