Last active
May 31, 2025 19:23
-
-
Save mistivia/3cd81cc29fc611c104edee3599ba674c to your computer and use it in GitHub Desktop.
Borrow like Rust in Linear Haskell
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE ForeignFunctionInterface #-} | |
| {-# LANGUAGE LinearTypes #-} | |
| {-# LANGUAGE QualifiedDo #-} | |
| import Data.Type.Bool | |
| import Foreign.Ptr (Ptr) | |
| import Foreign.C.Types (CDouble) | |
| import Foreign.Marshal.Alloc (free) | |
| import Data.Function ((&)) | |
| import Prelude (IO, (>>), (>>=), fmap, return) | |
| import Prelude.Linear | |
| import qualified System.IO.Linear as Linear | |
| import qualified Control.Functor.Linear as Linear | |
| ------------------------------------ | |
-- | Allocate a matrix buffer on the C side. The C implementation is:
--
-- > double* new_matrix() {
-- >     double* mat = (double*)malloc(N * N * sizeof(double));
-- >     return mat;
-- > }
--
-- NOTE(review): @malloc@ leaves the buffer uninitialised and its result is
-- returned without a @NULL@ check; @N@ is a constant defined in the C source
-- (not shown here).
foreign import ccall safe "new_matrix" cNewMatrix :: IO (Ptr CDouble)
-- | Overwrite every entry of an @N*N@ buffer, in place, with a pseudo-random
-- value in @[0, 1]@. The C implementation is:
--
-- > void fill_matrix(double* mat) {
-- >     for (int i = 0; i < N; i++) {
-- >         for (int j = 0; j < N; j++) {
-- >             mat[i*N + j] = (double)rand() / RAND_MAX;
-- >         }
-- >     }
-- > }
foreign import ccall safe "fill_matrix" cFillMatrix :: Ptr CDouble -> IO ()
-- | @cMatMul c a b@: dense @N x N@ matrix multiply through CBLAS. The output
-- buffer comes first, matching the C signature:
--
-- > void mat_mul(double* c, double* a, double* b) {
-- >     cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, a, N, b, N, 1.0, c, N);
-- > }
--
-- @cblas_dgemm@ computes @C := alpha*A*B + beta*C@. Here both @alpha@ and
-- @beta@ are @1.0@, so the product is /added/ to whatever @c@ already holds.
-- @c@ must therefore be initialised (e.g. zeroed) before the call.
foreign import ccall safe "mat_mul"
  cMatMul :: Ptr CDouble -> Ptr CDouble -> Ptr CDouble -> IO ()
| ------------------------------------ | |
-- | Owning handle to a matrix buffer produced by 'newMatrix'.
--
-- GADT syntax is deliberate. With @LinearTypes@ enabled, a plain @->@ in a
-- GADT constructor signature declares an /unrestricted/ field. Pattern
-- matching a linearly-held 'Mat' therefore yields a pointer that may be used
-- more than once. 'borrow' relies on this to rebuild the 'Mat' after handing
-- out a 'MatRef'.
--
-- Rewriting this as @data Mat = Mat (Ptr CDouble)@ would make the field linear
-- and break that.
data Mat where
  Mat :: Ptr CDouble -> Mat
-- | Non-owning view of a matrix buffer, handed to the body of 'borrow'.
--
-- Its field is unrestricted (GADT @->@ under @LinearTypes@), and the value can
-- be used any number of times.
--
-- NOTE(review): nothing stops a 'MatRef' from escaping a 'borrow' body (for
-- example inside the body's result) and being used after 'deleteMat'. Ruling
-- that out would need a rank-2 phantom parameter in the style of @ST@.
data MatRef where
  MatRef :: Ptr CDouble -> MatRef
-- | Allocate a matrix via the C @new_matrix@ routine and wrap the pointer as
-- an owned 'Mat'. The 'Mat' is intended to be released with 'deleteMat'.
--
-- Per the C code shown with 'cNewMatrix', the buffer comes from @malloc@, so
-- its contents are uninitialised.
newMatrix :: Linear.IO Mat
newMatrix = Linear.fromSystemIO allocate
  where
    allocate = fmap Mat cNewMatrix
-- | Release the buffer owned by a 'Mat' with C @free@, the counterpart of the
-- @malloc@ in @new_matrix@.
--
-- The argument is consumed linearly, so the 'Mat' cannot be used afterwards.
deleteMat :: Mat %1 -> Linear.IO ()
deleteMat (Mat owned) = Linear.fromSystemIO (free owned)
-- | Fill the matrix behind a 'MatRef' using C @fill_matrix@.
--
-- This is an ordinary 'IO' action, usable inside a 'borrow' body.
fillMat :: MatRef -> IO ()
fillMat ref = case ref of
  MatRef p -> cFillMatrix p
-- | @matMul dst lhs rhs@ calls C @mat_mul(dst, lhs, rhs)@. The /first/
-- argument is the destination, following the C routine's argument order.
--
-- NOTE(review): per the C code shown with 'cMatMul', @mat_mul@ passes
-- @beta = 1.0@ to @cblas_dgemm@. It therefore computes
-- @dst := lhs * rhs + dst@, so @dst@ must already hold initialised values.
matMul :: MatRef -> MatRef -> MatRef -> IO ()
matMul (MatRef dst) (MatRef lhs) (MatRef rhs) = cMatMul dst lhs rhs
| ------------------------------------ | |
-- | Temporarily lend an owned 'Mat' to a body as a 'MatRef'. The 'Mat' is then
-- given back, paired with the body's result. This is a rough analogue of a
-- Rust borrow.
--
-- @io@ is the monad the body runs in. The instances in this file cover
-- 'Linear.IO', which lets borrows nest, and plain 'IO'.
--
-- The body is taken linearly (@%1 ->@), so it may capture other
-- linearly-owned values such as further 'Mat's.
class Borrow io b where
  borrow :: Mat %1 -> (MatRef -> io b) %1 -> Linear.IO (Mat, b)
-- | Borrow with a body that itself runs in 'Linear.IO'. This is what lets
-- 'borrow' calls nest.
--
-- The returned 'Mat' wraps the same pointer as the one passed in.
instance Borrow Linear.IO a where
  borrow :: Mat %1 -> (MatRef -> Linear.IO a) %1 -> Linear.IO (Mat, a)
  borrow (Mat ptr) body = Linear.do
    result <- body (MatRef ptr)
    Linear.return (Mat ptr, result)
-- | Borrow with a body in ordinary 'IO', lifted with 'Linear.fromSystemIO'.
--
-- The instance head matches any result type @a@, and the context then forces
-- @a ~ ()@. GHC selects instances by head only, so this instance is chosen
-- from the body's monad alone; the result type is fixed to @()@ afterwards.
instance (a ~ ()) => Borrow IO a where
  borrow :: Mat %1 -> (MatRef -> IO ()) %1 -> Linear.IO (Mat, ())
  borrow (Mat ptr) body = Linear.do
    unit <- Linear.fromSystemIO (body (MatRef ptr))
    Linear.return (Mat ptr, unit)
| ------------------------------------ | |
-- | Demo program:
--
--   * allocate three matrices;
--   * borrow all of them at once to fill @a@ and @b@ and multiply them into
--     @c@;
--   * free all three.
--
-- Each 'borrow' consumes the owned 'Mat' and hands it back in the result
-- tuple. Re-binding @a@, @b@, @c@ afterwards mimics Rust's move-and-return
-- style, as does shadowing them with the 'MatRef's inside the lambdas.
main :: IO ()
main = Linear.withLinearIO $ Linear.do
  a <- newMatrix
  b <- newMatrix
  c <- newMatrix
  (a, (b, (c, ()))) <-
    borrow a $ \a ->
      borrow b $ \b ->
        borrow c $ \c -> do
          fillMat a
          fillMat b
          -- NOTE(review): c is never filled, yet mat_mul accumulates into it
          -- (beta = 1.0 in the C cblas_dgemm call). Its malloc'd contents are
          -- therefore read uninitialised. Zero c first, or pass beta = 0.0 in
          -- the C code.
          matMul c a b
  deleteMat a
  deleteMat b
  deleteMat c
  Linear.return (Ur ())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment