Created
September 28, 2021 17:06
-
-
Save maedoc/a423bf5b8aee3076744ca3b81e4773ab to your computer and use it in GitHub Desktop.
Futhark kernel for triangular matrix generation with recursion down columns [wip]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- We want to verify that the irregular (triangular) structure of the SHT
-- can be expressed when it is mapped on the GPU.
-- If it cannot, we waste (lmax+1)*lmax/2 elements of space, as follows.
-- | Placeholder per-order transform that pads its result to length `lmax`.
--
-- Returns an array of `lmax` values that starts out all zero. A sequential
-- loop then fills the entries at indices `m .. lmax-1`; entries below `m`
-- stay zero. The entry at index `m + i` is `(x + sum fm) * x`, where `x` is
-- the loop carry (0, 1, 2, ... as f32).
--
-- If `m >= lmax` the loop does not run and the result is all zeros.
let lt1 [nlat] (lmax:i64) (m:i64) (fm:[nlat]f32): [lmax]f32 =
  -- `reduce` is only well defined when its second argument is a neutral
  -- element of the operator, so the sum uses 0 and the carry `x` is added
  -- explicitly below.  The sum does not depend on the loop, so it is
  -- computed once up front.
  let total = reduce (+) 0f32 fm
  let q = replicate lmax 0f32
  let (q, _) =
    loop (q, x) = (q, 0f32) for i < (lmax - m) do
      let m' = m + i
      let q[m'] = (x + total) * x
      in (q, x + 1)
  in q
-- | Do a single SHT.
--
-- Takes the first `lmax` rows of the transposed field (i.e. the first `lmax`
-- columns of `fm`) and applies `lt1` to each, pairing row `m` with order `m`.
-- Slicing fails at run time if `lmax` exceeds `nlon`.
let sht1 [nlat] [nlon] (lmax:i64) (fm:[nlat][nlon]f32): [][]f32 =
  let columns = take lmax (transpose fm)
  in map2 (\m column -> lt1 lmax m column) (iota lmax) columns
-- | Do 3 SHTs.
--
-- Builds three all-zero fields of shape `[nlat][nlon]` and runs `sht1` with
-- `lmax = 16` on each of them.
entry main1: [][][]f32 =
  let lmax = 16i64
  -- max lmax for nlat
  let nlat = 2 * (lmax / 2 + 1)
  let nlon = 2 * nlat
  -- field with DFT applied along nlon axis (all zeros in this placeholder)
  let fields = replicate 3 (replicate nlat (replicate nlon 0f32))
  in map (\field -> sht1 lmax field) fields
----------------------------------------------------------------------- | |
-- need to read https://futhark-book.readthedocs.io/en/latest/irregular-flattening.html | |
-- | Placeholder per-order transform that returns exactly `M` values.
--
-- Entry `i` of the result is `(x + sum fm) * x`, where `x` is the loop carry
-- (0, 1, 2, ... as f32).  The result is filled by a sequential loop over
-- `i < M`.
let lt2' [nlat] (M:i64) (fm:[nlat]f32): [M]f32 =
  -- `reduce` is only well defined when its second argument is a neutral
  -- element of the operator, so the sum uses 0 and the carry `x` is added
  -- explicitly below.  The sum does not depend on the loop, so it is
  -- computed once up front.
  let total = reduce (+) 0f32 fm
  let q = replicate M 0f32
  let (q, _) =
    loop (q, x) = (q, 0f32) for i < M do
      let q[i] = (x + total) * x
      in (q, x + 1)
  in q
-- | Writes the `lmax - m` values produced by `lt2'` into the front of `out`.
--
-- `out` is consumed and returned with its first `lmax - m` entries
-- overwritten; the remaining entries keep whatever the caller passed in.
let lt2 [nlat] (lmax:i64) (m:i64) (fm:[nlat]f32) (out:*[lmax]f32): *[lmax]f32 =
  let count = lmax - m
  in out with [:count] = lt2' count fm
-- | Do a single SHT using `lt2`.
--
-- Takes the first `lmax` rows of the transposed field (i.e. the first `lmax`
-- columns of `fm`) and applies `lt2` to each, pairing row `m` with order `m`.
-- Slicing fails at run time if `lmax` exceeds `nlon`.
let sht2 [nlat] [nlon] (lmax:i64) (fm:[nlat][nlon]f32): [][]f32 =
  -- `lt2` expects `lmax` first and consumes an output buffer, so both are
  -- supplied explicitly: each order gets its own fresh zero-filled buffer.
  map2 (\m column -> lt2 lmax m column (replicate lmax 0f32))
       (iota lmax)
       (take lmax (transpose fm))
-- -- do 3 SHTs | |
-- entry main2: [][][]f32 = | |
-- let lmax = 16i64 | |
-- let nlat = (lmax/2+1)*2 -- max lmax for nlat | |
-- let nlon = nlat * 2 | |
-- -- field with DFT applied along nlon axis | |
-- let fm = tabulate_3d 3 nlat nlon (\i j k -> 0f32) | |
-- in map (sht1 lmax) fm |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment