Skip to content

Instantly share code, notes, and snippets.

@ghorn
Created May 9, 2018 22:27
Show Gist options
  • Save ghorn/724f24b48dc8b9dd6105fe3bdd9b732a to your computer and use it in GitHub Desktop.
Save ghorn/724f24b48dc8b9dd6105fe3bdd9b732a to your computer and use it in GitHub Desktop.
SBV interpolation wrapper: stand-alone example
{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE ScopedTypeVariables #-}
import Data.List ( intercalate )
import qualified Data.Vector as V
import qualified Data.Vector.Storable as SV
import qualified Data.Vector.Storable.Mutable as SVM
import Data.SBV
import Data.SBV.Tools.CodeGen
-- | All the data needed to call the foreign interpolation routine
-- @foreign_interpolation@. 'interpolate' hardcodes every field into the
-- generated C, either as a local constant or as a work-array size.
--
-- Fields are strict: each one is consumed when the C source is rendered, so
-- deferring their evaluation buys nothing.
data CInterpolantData = CInterpolantData
  { ciNDims :: !Int
  -- ^ Number of input dimensions (emitted as @ndim@).
  , ciNOutputs :: !Int
  -- ^ Number of outputs (emitted as @noutputs@).
  , ciGrid :: !(SV.Vector Double)
  -- ^ Grid breakpoints (emitted as @grid@). Presumably the per-dimension
  -- grids concatenated and sliced by 'ciOffset' — TODO confirm against the
  -- foreign routine.
  , ciOffset :: !(SV.Vector Int64)
  -- ^ Offsets into 'ciGrid' (emitted as @offset@).
  , ciValues :: !(SV.Vector Double)
  -- ^ Tabulated values (emitted as @values@); the expected layout is defined
  -- by the foreign routine.
  , ciLookupModes :: !(SV.Vector Int64)
  -- ^ Lookup-mode codes (emitted as @lookup_mode@); presumably one per
  -- dimension — verify against the foreign routine.
  , ciNumIW :: !Int
  -- ^ Length of the integer work array @iw@.
  , ciNumW :: !Int
  -- ^ Length of the floating-point work array @w@.
  }
-- | A small data set with two input dimensions and three outputs, used to
-- drive code generation in 'main'.
sampleData :: CInterpolantData
sampleData =
  CInterpolantData
    { ciNDims = 2
    , ciNOutputs = 3
    , ciGrid = SV.fromList (gridSeg0 ++ gridSeg1)
    , ciOffset = SV.fromList [0, 3, 5]
    , ciValues = SV.fromList (concat (replicate 3 [1 .. 6]))
    , ciLookupModes = SV.fromList [0, 0]
    , ciNumIW = 4
    , ciNumW = 2
    }
  where
    -- The two segments of the grid delimited by the offsets (0..3 and 3..5).
    gridSeg0, gridSeg1 :: [Double]
    gridSeg0 = [0, 1, 2]
    gridSeg1 = [0, 10]
-- | Evaluate the foreign interpolant on concrete inputs.
--
-- In the real program this wraps a @foreign import ccall@ of the interpolation
-- routine behind @unsafePerformIO@. This stand-alone example does not include
-- that binding, so it is a stub that fails with an explanatory message rather
-- than a bare 'undefined'.
wrapForeignInterpolation :: CInterpolantData -> [Double] -> [Double]
wrapForeignInterpolation _ _ =
  error
    "wrapForeignInterpolation: stub; the foreign interpolation binding is not part of this stand-alone example"
-- | Expose the foreign interpolant to SBV via 'cgUninterpret'.
--
-- The function is given a hand-written C implementation (@c@) that hardcodes
-- every field of the supplied 'CInterpolantData' as local constants and
-- forwards to @foreign_interpolation@. The Haskell model (@h@) is supplied
-- alongside it for use outside of emitted C.
interpolate :: CInterpolantData -> [SBV Double] -> [SBV Double]
interpolate dataToHardcode = cgUninterpret funName c h
  where
    funName :: String
    funName = "interpolate"

    -- Haskell-side model. The foreign routine cannot be expressed natively in
    -- SBV, so this only works on concrete (literal) inputs: unwrap them, call
    -- the foreign wrapper, and lift the results back to literals.
    h :: [SBV Double] -> [SBV Double]
    h = map literal . wrapForeignInterpolation dataToHardcode . map concretize

    -- Extract the concrete value of an input, failing loudly on symbolic input.
    concretize :: SBV Double -> Double
    concretize x = case unliteral x of
      Just v -> v
      Nothing ->
        error "interpolate: the foreign interpolant can only be evaluated on concrete inputs"

    -- C implementation. Results are written through @output@, so the function
    -- returns void; a non-void function that falls off its end without a
    -- return yields undefined behaviour if a caller uses the result.
    c :: [String]
    c =
      [ "#include \"foreign_interpolation.h\""
      , ""
      , "void " ++ funName ++ "(const double * const input, double * output) {"
      , " // hardcoded inputs"
      , " const int64_t ndim = " ++ show (ciNDims dataToHardcode) ++ ";"
      , " const double grid[] = " ++ showArray (ciGrid dataToHardcode) ++ ";"
      , " const int64_t offset[] = " ++ showArray (ciOffset dataToHardcode) ++ ";"
      , " const double values[] = " ++ showArray (ciValues dataToHardcode) ++ ";"
      , " const int64_t lookup_mode[] = " ++ showArray (ciLookupModes dataToHardcode) ++ ";"
      , " const int64_t noutputs = " ++ show (ciNOutputs dataToHardcode) ++ ";"
      , ""
      , " // mutable work arrays"
      , " int64_t iw[" ++ show (ciNumIW dataToHardcode) ++ "];"
      , " double w[" ++ show (ciNumW dataToHardcode) ++ "];"
      , ""
      , " // call my function"
      , " foreign_interpolation(output, ndim, grid, offset, values, input, lookup_mode, noutputs, iw, w);"
      , "}"
      ]

    -- Render a vector as a C brace initializer, e.g. @{0.0, 1.0}@.
    showArray :: (Show a, SV.Storable a) => SV.Vector a -> String
    showArray x = "{" ++ intercalate ", " (map show (SV.toList x)) ++ "}"
-- | Code-generation driver: read the inputs from the C array @in@, evaluate
-- 'interpolate' on them, and write the results to the C array @out@.
--
-- The input array is sized by 'ciNDims' so that it always matches the @ndim@
-- constant that 'interpolate' hardcodes into the generated C.
doCg :: CInterpolantData -> SBVCodeGen ()
doCg cinterpolantData = do
  inputs <- cgInputArr (ciNDims cinterpolantData) "in"
  let y :: [SBV Double]
      y = interpolate cinterpolantData inputs
  cgOutputArr "out" y
-- | Generate C code for 'sampleData' into the @test@ directory, naming the
-- generated entry point @test@.
main :: IO ()
main = compileToC outputDir "test" generator
  where
    outputDir = Just "test"
    generator = doCg sampleData
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment