Created
May 9, 2018 22:27
-
-
Save ghorn/724f24b48dc8b9dd6105fe3bdd9b732a to your computer and use it in GitHub Desktop.
SBV interpolation wrap stand-alone example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
import Data.List ( intercalate ) | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Storable as SV | |
import qualified Data.Vector.Storable.Mutable as SVM | |
import Data.SBV | |
import Data.SBV.Tools.CodeGen | |
-- | All the data needed to call the foreign interpolation function.
--
-- 'interpolate' bakes every field into the generated C wrapper, either as a
-- hard-coded constant or as the size of a scratch array.
data CInterpolantData = CInterpolantData
  { ciNDims :: Int
    -- ^ Emitted as the C constant @ndim@ (presumably the number of input
    -- dimensions -- confirm against the foreign function's contract).
  , ciNOutputs :: Int
    -- ^ Emitted as the C constant @noutputs@.
  , ciGrid :: SV.Vector Double
    -- ^ Emitted as the C array @grid[]@.
  , ciOffset :: SV.Vector Int64
    -- ^ Emitted as the C array @offset[]@ (presumably indices into the grid
    -- array -- TODO confirm).
  , ciValues :: SV.Vector Double
    -- ^ Emitted as the C array @values[]@.
  , ciLookupModes :: SV.Vector Int64
    -- ^ Emitted as the C array @lookup_mode[]@.
  , ciNumIW :: Int
    -- ^ Length of the @int64_t iw[]@ work array declared in the wrapper.
  , ciNumW :: Int
    -- ^ Length of the @double w[]@ work array declared in the wrapper.
  }
-- | Small hand-written data set used by 'main' to exercise code generation.
sampleData :: CInterpolantData
sampleData = CInterpolantData
  { ciNDims       = 2
  , ciNOutputs    = 3
  , ciGrid        = SV.fromList [0, 1, 2, 0, 10]
  , ciOffset      = SV.fromList [0, 3, 5]
    -- the same six samples, repeated three times (18 values in total)
  , ciValues      = SV.fromList (concat (replicate 3 [1, 2, 3, 4, 5, 6]))
  , ciLookupModes = SV.fromList [0, 0]
  , ciNumIW       = 4
  , ciNumW        = 2
  }
-- | Placeholder for the real binding, which wraps a @foreign import ccall@
-- and @unsafePerformIO@.
--
-- The binding itself is not part of this example, so forcing this value
-- fails with a descriptive error instead of a bare @undefined@.
wrapForeignInterpolation :: CInterpolantData -> [Double] -> [Double]
wrapForeignInterpolation =
  error "wrapForeignInterpolation: foreign binding is not included in this example"
-- | Symbolic wrapper around the foreign interpolation routine.
--
-- The result is built with 'cgUninterpret': the function is given the name
-- @interpolate@, the C source lines produced below, and a Haskell-side
-- implementation that is currently a stub.
--
-- The C source hard-codes every field of the supplied 'CInterpolantData',
-- declares the two work arrays, and forwards everything to
-- @foreign_interpolation@.
interpolate :: CInterpolantData -> [SBV Double] -> [SBV Double]
interpolate dataToHardcode = cgUninterpret funName cWrapper hsFallback
  where
    funName :: String
    funName = "interpolate"

    -- This wraps a foreign import; there is no way to do it natively.
    -- Intended shape: map literal . wrapForeignInterpolation . map unliteral
    -- The argument is deliberately ignored (a named binder here would trip
    -- -Wall's unused-binding warning).
    hsFallback :: [SBV Double] -> [SBV Double]
    hsFallback _ =
      error "interpolate: Haskell-side evaluation is not implemented"

    -- NOTE(review): the wrapper is declared to return `double` but contains
    -- no `return` statement -- confirm the intended C signature.
    cWrapper :: [String]
    cWrapper =
      [ "#include \"foreign_interpolation.h\""
      , ""
      , "double " ++ funName ++ "(const double * const input, double * output) {"
      , " // hardcoded inputs"
      , constDecl "int64_t" "ndim" (show (ciNDims dataToHardcode))
      , constDecl "double" "grid[]" (showArray (ciGrid dataToHardcode))
      , constDecl "int64_t" "offset[]" (showArray (ciOffset dataToHardcode))
      , constDecl "double" "values[]" (showArray (ciValues dataToHardcode))
      , constDecl "int64_t" "lookup_mode[]" (showArray (ciLookupModes dataToHardcode))
      , constDecl "int64_t" "noutputs" (show (ciNOutputs dataToHardcode))
      , ""
      , " // mutable work arrays"
      , workArray "int64_t" "iw" (ciNumIW dataToHardcode)
      , workArray "double" "w" (ciNumW dataToHardcode)
      , ""
      , " // call my function"
      , " foreign_interpolation(output, ndim, grid, offset, values, input, lookup_mode, noutputs, iw, w);"
      , "}"
      ]

    -- One line of the form " const <type> <name> = <initializer>;".
    constDecl :: String -> String -> String -> String
    constDecl cType cName initializer =
      " const " ++ cType ++ " " ++ cName ++ " = " ++ initializer ++ ";"

    -- One line of the form " <type> <name>[<length>];".
    workArray :: String -> String -> Int -> String
    workArray cType cName len =
      " " ++ cType ++ " " ++ cName ++ "[" ++ show len ++ "];"

    -- Render a vector as a brace-delimited, comma-separated C initializer.
    showArray :: (Show a, SV.Storable a) => SV.Vector a -> String
    showArray xs = "{" ++ intercalate ", " (map show (SV.toList xs)) ++ "}"
-- | Code-generation driver: declares the input array @in@, runs it through
-- 'interpolate', and declares the result as the output array @out@.
--
-- The input array length is taken from 'ciNDims' rather than being
-- hard-coded, so it stays in step with the @ndim@ constant that
-- 'interpolate' writes into the C wrapper.
doCg :: CInterpolantData -> SBVCodeGen ()
doCg cinterpolantData = do
  inputs <- cgInputArr (ciNDims cinterpolantData) "in"
  cgOutputArr "out" (interpolate cinterpolantData inputs)
-- | Run the code generator for 'sampleData'.
main :: IO ()
main = compileToC outputDir functionName (doCg sampleData)
  where
    outputDir = Just "test"
    functionName = "test"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment