Created
March 23, 2022 08:58
-
-
Save el-hult/dc1d364ca120593dc7a633c32702c7fe to your computer and use it in GitHub Desktop.
A Haskell module that stores hmatrix sized vectors as C arrays of doubles, suitable for passing to Fortran and other foreign code via the FFI.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
module Pointers where | |
import Control.Exception (bracket)
import Data.Proxy (Proxy (..))
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (newArray, peekArray)
import Foreign.Ptr (Ptr)
import GHC.TypeNats (KnownNat, Nat, natVal)

import Numeric.LinearAlgebra.Data (toList)
import Numeric.LinearAlgebra.Static (R, Sized (extract, unwrap), vector)
-- | A raw pointer to a C array of 'Double's, tagged at the type level with the
-- number of elements @n@ it is meant to hold.
--
-- The tag is a promise made by whoever constructs the value; it is not checked
-- against the actual allocation. The wrapped value is a plain 'Ptr' (no
-- finalizer), so an 'RPtr' never frees anything by itself.
--
-- * Get the raw pointer, e.g. to pass to a foreign function: @unRPtr rPtr@
-- * Wrap an existing pointer: @RPtr ptr :: RPtr 7@
newtype RPtr (n :: Nat) = RPtr {unRPtr :: Ptr Double}
  deriving (Eq, Show)
-- | Copy a hmatrix sized vector into a freshly allocated C array of doubles,
-- suitable for passing to foreign (e.g. Fortran) code.
--
-- The array holds exactly @n@ elements in vector order. It is allocated with
-- 'newArray' (C-heap memory) and is /not/ garbage collected: the caller owns
-- it and must release it with 'freeRPtr' once foreign code is done with it.
--
-- Example:
--
-- > rPtr <- store (vector [1,2,3] :: R 3)
-- > -- ... pass (unRPtr rPtr) to foreign code ...
-- > freeRPtr rPtr
store :: KnownNat n => R n -> IO (RPtr n)
-- Use 'extract', not 'unwrap': hmatrix may keep a constant vector (e.g. one
-- built with 'konst') as a single internal element, and 'unwrap' exposes that
-- representation, which would allocate fewer than @n@ doubles. 'extract'
-- always yields the full length-@n@ vector.
store v = RPtr <$> newArray (toList (extract v))

-- | Release the C array behind an 'RPtr' obtained from 'store'.
--
-- Call it exactly once per 'store' and do not use the pointer afterwards.
-- Do not call it on pointers wrapped by hand around memory owned elsewhere.
freeRPtr :: RPtr n -> IO ()
freeRPtr = free . unRPtr
-- | Read @n@ doubles starting at the wrapped pointer and return them as a
-- hmatrix sized vector.
--
-- The element count comes from the type-level tag @n@, so the pointer must
-- address at least @n@ readable doubles; this is not checked at runtime. The
-- memory is only read, never freed.
--
-- Example:
--
-- > load (RPtr ptr :: RPtr 7)
load :: forall n. KnownNat n => RPtr n -> IO (R n)
load (RPtr ptr) = vector <$> peekArray len ptr
  where
    -- Length of the array, recovered from the phantom type parameter.
    len = fromIntegral (natVal (Proxy :: Proxy n))
-- | Demonstrates a store/load round trip and conversion between an 'RPtr'
-- and a raw 'Ptr'.
main :: IO ()
main = do
  -- Round trip: copy a sized vector into a C array and read it back.
  let a = vector [1, 2, 3] :: R 3
  print a
  -- 'store' allocates with 'newArray', which is not garbage collected;
  -- 'bracket' frees the array even if a later step throws.
  bracket (store a) (free . unRPtr) $ \rPtr -> do
    print rPtr
    b <- load rPtr
    print b
    -- Unwrap to a raw pointer and rewrap it with a different length tag.
    -- The tag 2 reads only the first two of the three stored elements; a tag
    -- larger than 3 would read past the end of the allocation.
    let ptr = unRPtr rPtr
    print ptr
    c <- load (RPtr ptr :: RPtr 2)
    print c
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment