Created
November 22, 2024 18:17
-
-
Save gallais/ca8441e4ea74e64f8347837d655b18f9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE TypeAbstractions #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
import Control.Concurrent (forkFinally, forkIO, MVar, newEmptyMVar, putMVar, takeMVar)
import Control.Exception (throwIO)
import Data.Kind (Type)
import Data.Proxy (Proxy(..))
import Data.Type.Equality ((:~:)(..), testEquality)
import GHC.Exts (Any)
import GHC.TypeLits (SomeSymbol(..), Symbol, KnownSymbol(..))
import Unsafe.Coerce (unsafeCoerce)
-- | A plain pair. It is declared here (rather than reusing tuples) so that
-- its constructor can be promoted and used at the type level.
data Pair a b where
  MkPair :: a -> b -> Pair a b
-- | Type-level projection onto the first component of a promoted 'Pair'.
type family Fst (p :: Pair a b) :: a where
  Fst ('MkPair l _) = l
-- | Type-level projection onto the second component of a promoted 'Pair'.
type family Snd (p :: Pair a b) :: b where
  Snd ('MkPair _ r) = r
-- | Type-level association-list lookup: the value paired with the first
-- entry whose key equals @key@.
--
-- There is no equation for the empty list, so looking up a key that is not
-- present simply does not reduce.
type family Assoc (key :: a) (abs :: [Pair a b]) :: b where
  Assoc k ('MkPair k v ': _) = v
  Assoc k (_ ': rest) = Assoc k rest
-- | One alternative in a 'Send' / 'Recv' choice: a label, paired with the
-- type of the payload carried by that message and the protocol that the
-- channel follows afterwards.
type MsgType =
  Pair
    Symbol
    (Pair Type Protocol)
-- | Session protocols. The constructors are used promoted, as type-level
-- indices of 'Channel'.
data Protocol
  = -- | The session is over; the channel can only be closed.
    Done
  | -- | Send one of the listed messages, then continue as that message says.
    Send [MsgType]
  | -- | Receive one of the listed messages, then continue as that message says.
    Recv [MsgType]
  | -- | Fixpoints: a recursive occurrence, replaced by 'Subst'.
    More
  | -- | Fixpoints: a recursive protocol, unfolded one step by 'Force'.
    Loop Protocol
-- | @Subst p q@ replaces the free occurrences of 'More' in @q@ by @p@.
--
-- 'Force' uses it to unfold a 'Loop' one step.
type family Subst (p :: Protocol) (q :: Protocol) :: Protocol where
  Subst p 'Done = 'Done
  Subst p 'More = p
  Subst p ('Send cs) = 'Send (Substs p cs)
  Subst p ('Recv cs) = 'Recv (Substs p cs)
  -- A nested 'Loop' rebinds 'More' (it is unfolded by its own 'Force'), so
  -- no occurrence below it is free and the body must be left untouched.
  -- Without this equation the family would get stuck on nested loops.
  Subst p ('Loop q) = 'Loop q
-- | Pointwise 'Subst' over a list of message alternatives: labels and
-- payload types are kept, only the continuations are rewritten.
type family Substs (p :: Protocol) (qs :: [MsgType]) :: [MsgType] where
  Substs _ '[] = '[]
  Substs p ('MkPair lbl ('MkPair ty k) ': rest) =
    'MkPair lbl ('MkPair ty (Subst p k)) ': Substs p rest
-- | Unfold a top-level 'Loop' by one step, substituting the whole loop for
-- the 'More' occurrences of its body. Any other protocol is returned as is.
type family Force (p :: Protocol) :: Protocol where
  Force ('Loop body) = Subst ('Loop body) body
  Force other = other
-- | The protocol seen from the other endpoint: 'Send' and 'Recv' are
-- swapped, everything else is preserved structurally.
type family Dual (p :: Protocol) :: Protocol where
  Dual 'Done = 'Done
  Dual 'More = 'More
  Dual ('Loop body) = 'Loop (Dual body)
  Dual ('Send cs) = 'Recv (Duals cs)
  Dual ('Recv cs) = 'Send (Duals cs)
-- | Pointwise 'Dual' over a list of message alternatives: labels and payload
-- types are kept, only the continuations are dualised.
type family Duals (qs :: [MsgType]) :: [MsgType] where
  Duals '[] = '[]
  Duals ('MkPair lbl ('MkPair ty k) ': rest) =
    'MkPair lbl ('MkPair ty (Dual k)) ': Duals rest
-- | A message labelled @key@, taken from the alternatives @cs@. Its payload
-- has the type that @cs@ associates with @key@.
data Msg (cs :: [MsgType]) (key :: Symbol) where
  -- The quantification order (label first) matters: call sites select the
  -- label with a visible type application such as @MkMsg \@"LOG"@.
  MkMsg ::
    forall key cs.
    Fst (Assoc key cs) ->
    Msg cs key
-- | What actually travels through an 'MVar': the label of the message and
-- its payload with the type erased to 'Any'.
data Letter = MkLetter SomeSymbol Any
-- | One endpoint of a session. The protocol index @ps@ is phantom: it only
-- records, at the type level, what may happen next on this endpoint.
data Channel (ps :: Protocol) where
  MkChannel ::
    { -- Slot this endpoint writes outgoing letters to.
      sending :: MVar Letter
    , -- Slot this endpoint reads incoming letters from.
      receiving :: MVar Letter
    } ->
    Channel ps
-- | Client view of the logging protocol: repeatedly send a @"LOG"@ message
-- carrying a 'String', or send @"STOP"@ (carrying @()@) to end the session.
type Logger =
  Force
    ( 'Loop
        ( 'Send
            '[ 'MkPair "LOG" ('MkPair String 'More)
             , 'MkPair "STOP" ('MkPair () 'Done)
             ]
        )
    )
-- | Send the message labelled @key@ on a channel whose protocol offers it,
-- and return the same endpoint re-indexed by the continuation that the
-- protocol associates with @key@ (with a leading 'Loop' unfolded by 'Force').
send :: KnownSymbol key
     => Channel (Send cs)
     -> Msg cs key
     -> IO (Channel (Force (Snd (Assoc key cs))))
send (MkChannel sending receiving) (MkMsg @key msg) = do
  -- The label travels as a real 'Proxy' rather than a bottom value, so the
  -- letter stays safe to inspect on the receiving side.
  --
  -- 'unsafeCoerce' erases the payload type to 'Any'. 'recv' coerces it back
  -- to the type its own alternatives associate with the received label, so
  -- this is only sound when the peer follows the dual protocol (as the two
  -- endpoints returned by 'fork' are typed to do).
  putMVar sending (MkLetter (SomeSymbol (Proxy @key)) (unsafeCoerce msg))
  pure (MkChannel sending receiving)
-- | The result of receiving on a channel offering the alternatives @cs@:
-- a message with some (existentially hidden) label @key@, evidence that the
-- label is known at run time, and the channel re-indexed by the continuation
-- that @cs@ associates with @key@.
data Reply (cs :: [MsgType])
  = forall key.
    KnownSymbol key =>
    MkReply
      (Msg cs key)
      (Channel (Force (Snd (Assoc key cs))))
-- | Create the two endpoints of a fresh session. Both share the same pair of
-- empty 'MVar's, with the roles swapped: what one endpoint sends on, the
-- other receives from. The second endpoint is typed with the dual protocol.
fork :: IO (Channel p, Channel (Dual p))
fork = endpoints <$> newEmptyMVar <*> newEmptyMVar
  where
    endpoints there back = (MkChannel there back, MkChannel back there)
-- | Block until a letter arrives, then package it as a 'Reply': the message
-- under the label it was sent with, plus the endpoint re-indexed by the
-- matching continuation.
recv :: Channel (Recv cs) -> IO (Reply cs)
recv (MkChannel outbox inbox) = do
  MkLetter (SomeSymbol @key _) payload <- takeMVar inbox
  -- The payload was erased to 'Any' by the sender; it is coerced back to the
  -- type that @cs@ associates with the received label. Nothing checks at run
  -- time that the label actually belongs to @cs@.
  pure $ MkReply (MkMsg @key (unsafeCoerce payload)) (MkChannel outbox inbox)
-- | End a finished session. This performs no work; it only accepts channels
-- whose protocol index is 'Done'.
close :: Channel Done -> IO ()
close = const (pure ())
-- | Send one @"LOG"@ message carrying the given text; the returned channel
-- is ready for the next round of the 'Logger' protocol.
logMsg :: Channel Logger -> String -> IO (Channel Logger)
logMsg ch = send ch . MkMsg @"LOG"
-- | Send the @"STOP"@ message, ending the 'Logger' protocol; the returned
-- channel can only be closed.
logStop :: Channel Logger -> IO (Channel Done)
logStop = flip send (MkMsg @"STOP" ())
-- | Server side of the 'Logger' protocol: print the payload of every
-- @"LOG"@ message with 'putStrLn' and go round again; close the channel on
-- @"STOP"@.
logger :: Channel (Dual Logger) -> IO ()
logger ch = do
  MkReply (MkMsg @key line) next <- recv ch
  -- The label is only known at run time, so it is compared against each
  -- label the protocol allows. Matching on 'Refl' is what reveals the payload
  -- type and the protocol of the continuation channel to the type checker.
  case testEquality (symbolSing @key) (symbolSing @"LOG") of
    Just Refl -> putStrLn line >> logger next
    Nothing ->
      case testEquality (symbolSing @key) (symbolSing @"STOP") of
        Just Refl -> close next
        -- Only reachable if the peer sent a label outside the protocol.
        Nothing -> error "The IMPOSSIBLE has happened"
-- | Run two processes against the two endpoints of a fresh session: @kr@ in
-- a new thread on the dual endpoint, @kl@ in the calling thread.
--
-- The result of @kl@ is returned only once @kr@ has finished as well, so
-- the work of the forked side (for instance its last output) is not cut off
-- when the caller moves on or the program exits. An exception that ended
-- @kr@ is re-thrown in the calling thread instead of being lost.
--
-- Note: if @kl@ returns without driving the protocol to its end while @kr@
-- is still waiting for a message, the final wait cannot succeed and the
-- run-time system will report the blocked threads.
run :: (Channel p -> IO a)
    -> (Channel (Dual p) -> IO ())
    -> IO a
run kl kr = do
  (chl, chr) <- fork
  -- Filled with the outcome of the forked side when it terminates,
  -- normally or with an exception.
  outcome <- newEmptyMVar
  _ <- forkFinally (kr chr) (putMVar outcome)
  result <- kl chl
  takeMVar outcome >>= either throwIO pure
  pure result
-- | Run a client of the 'Logger' protocol against the 'logger' server.
withLogger :: (Channel Logger -> IO a) -> IO a
withLogger client = client `run` logger
-- | Example 'Logger' client: logs @"Hello"@, then @"World"@, then stops the
-- logger and closes the channel.
count :: Channel Logger -> IO ()
count ch =
  logMsg ch "Hello"
    >>= (`logMsg` "World")
    >>= logStop
    >>= close
-- | Client view of the summing protocol: repeatedly send a @"SUM"@ request
-- carrying a list of 'Int's and receive a @"RESULT"@ carrying an 'Int', or
-- send @"STOP"@ (carrying @()@) to end the session.
type RecSum =
  Force
    ( 'Loop
        ( 'Send
            '[ 'MkPair "SUM" ('MkPair [Int] ('Recv '[ 'MkPair "RESULT" ('MkPair Int 'More) ]))
             , 'MkPair "STOP" ('MkPair () 'Done)
             ]
        )
    )
-- | Example 'RecSum' client. For a non-empty list it asks the server for the
-- sum of the whole list, prints the answer, and repeats with the tail; once
-- the list is empty it sends @"STOP"@ and closes the channel.
recSum :: [Int] -> Channel RecSum -> IO ()
recSum [] ch = send ch (MkMsg @"STOP" ()) >>= close
recSum ns@(_ : rest) ch = do
  awaiting <- send ch (MkMsg @"SUM" ns)
  MkReply (MkMsg @key total) next <- recv awaiting
  -- "RESULT" is the only label the protocol allows here; matching on 'Refl'
  -- reveals the payload type and the protocol of the continuation channel.
  case testEquality (symbolSing @key) (symbolSing @"RESULT") of
    Just Refl -> do
      putStrLn ("Sum of " ++ show ns ++ " is " ++ show total)
      recSum rest next
    -- Only reachable if the peer sent a label outside the protocol.
    Nothing -> error "The IMPOSSIBLE happened"
-- | Server side of the 'RecSum' protocol: answer every @"SUM"@ request with
-- a @"RESULT"@ carrying the 'sum' of the received list and go round again;
-- close the channel on @"STOP"@.
sumRec :: Channel (Dual RecSum) -> IO ()
sumRec ch = do
  MkReply (MkMsg @key payload) next <- recv ch
  -- The label is only known at run time, so it is compared against each
  -- label the protocol allows. Matching on 'Refl' is what reveals the payload
  -- type and the protocol of the continuation channel to the type checker.
  case testEquality (symbolSing @key) (symbolSing @"SUM") of
    Just Refl -> send next (MkMsg @"RESULT" (sum payload)) >>= sumRec
    Nothing ->
      case testEquality (symbolSing @key) (symbolSing @"STOP") of
        Just Refl -> close next
        -- Only reachable if the peer sent a label outside the protocol.
        Nothing -> error "The IMPOSSIBLE happened"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment