Skip to content

Instantly share code, notes, and snippets.

@gallais
Created November 22, 2024 18:17
Show Gist options
  • Save gallais/ca8441e4ea74e64f8347837d655b18f9 to your computer and use it in GitHub Desktop.
Save gallais/ca8441e4ea74e64f8347837d655b18f9 to your computer and use it in GitHub Desktop.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
import Control.Concurrent (forkIO, MVar, newEmptyMVar, putMVar, takeMVar)
import Data.Kind (Type)
import Data.Proxy (Proxy(..))
import Data.Type.Equality ((:~:)(..), testEquality)
import GHC.Exts (Any)
import GHC.TypeLits (SomeSymbol(..), Symbol, KnownSymbol(..))
import Unsafe.Coerce (unsafeCoerce)
-- | A plain two-component product.  In this file it is used in its
-- promoted form ('MkPair) to build type-level association lists.
data Pair a b where
  MkPair :: a -> b -> Pair a b
-- | First component of a promoted 'Pair'.
type family Fst (pair :: Pair k1 k2) :: k1 where
  Fst ('MkPair x _) = x
-- | Second component of a promoted 'Pair'.
type family Snd (pair :: Pair k1 k2) :: k2 where
  Snd ('MkPair _ y) = y
-- | @Assoc key table@ is the value paired with the first occurrence of
-- @key@ in the type-level association list @table@.
--
-- There is no equation for @'[]@, so a missing key leaves the family
-- stuck.  Being a closed family, it also stays stuck while @key@ is not
-- yet known to be equal to (or apart from) the keys it is compared with.
type family Assoc (key :: k) (table :: [Pair k v]) :: v where
  Assoc key ('MkPair key val ': _)    = val
  Assoc key (_               ': rest) = Assoc key rest
-- | One labelled branch of a 'Send' / 'Recv' choice: the label, the type
-- of the payload carried under that label, and the protocol that follows
-- once that branch has been taken.
type MsgType = Pair Symbol (Pair Type Protocol)
-- | Session types describing what an endpoint may do next.
--
-- 'Send' and 'Recv' list labelled branches ('MsgType'): the sending side
-- picks one label, the receiving side must be ready for any of them.
-- 'Loop' marks a recursion point and 'More' refers back to it; 'Force'
-- unfolds a 'Loop' by substituting the loop for 'More' in its body.
data Protocol
  = Done
  | Send [MsgType]
  | Recv [MsgType]
  -- Fixpoints
  | More
  | Loop Protocol
-- | @Subst p q@ replaces every free 'More' in @q@ (i.e. one that is not
-- under a nested 'Loop') by @p@.  'Force' uses it to unfold @'Loop' q@
-- into @Subst ('Loop' q) q@.
type family Subst (p :: Protocol) (q :: Protocol) :: Protocol where
  Subst p 'Done      = 'Done
  Subst p 'More      = p
  Subst p ('Send cs) = 'Send (Substs p cs)
  Subst p ('Recv cs) = 'Recv (Substs p cs)
  -- A nested 'Loop' rebinds 'More', so its body cannot mention the outer
  -- recursion point and is left untouched.  Without this equation the
  -- family is stuck on nested loops and such protocols never reduce.
  Subst p ('Loop q)  = 'Loop q
-- | Apply @'Subst' p@ to the continuation protocol of every branch,
-- leaving labels and payload types unchanged.
type family Substs (p :: Protocol) (qs :: [MsgType]) :: [MsgType] where
  Substs p '[] = '[]
  Substs p ('MkPair label ('MkPair payload next) ': rest)
    = 'MkPair label ('MkPair payload (Subst p next)) ': Substs p rest
-- | Unfold a top-level 'Loop' once by substituting the loop itself for
-- 'More' in its body.  Any other protocol constructor is returned as is.
type family Force (p :: Protocol) :: Protocol where
  Force ('Loop body) = Subst ('Loop body) body
  Force other        = other
-- | The protocol as seen from the other endpoint.  'Send' and 'Recv' are
-- swapped (recursively in every branch, via 'Duals'); 'Done', 'More' and
-- 'Loop' are kept.
type family Dual (p :: Protocol) :: Protocol where
  Dual 'Done      = 'Done
  Dual ('Send cs) = 'Recv (Duals cs)
  Dual ('Recv cs) = 'Send (Duals cs)
  Dual ('Loop p)  = 'Loop (Dual p)
  Dual 'More      = 'More
-- | Apply 'Dual' to the continuation protocol of every branch, leaving
-- labels and payload types unchanged.
--
-- Annotated with the same @[MsgType]@ kinds as 'Substs' so that its kind
-- is stated rather than inferred (and possibly over-generalised).
type family Duals (qs :: [MsgType]) :: [MsgType] where
  Duals '[] = '[]
  Duals ('MkPair s ('MkPair ty q) ': qs)
    = 'MkPair s ('MkPair ty (Dual q)) ': Duals qs
-- | A message with label @key@ for the choice @cs@.  Its payload has the
-- type that @cs@ associates with @key@.
--
-- The explicit quantifier keeps @key@ as the first type argument, which
-- the @MkMsg \@key@ type applications in this file rely on.
data Msg (cs :: [MsgType]) (key :: Symbol) where
  MkMsg :: forall key cs. Fst (Assoc key cs) -> Msg cs key
-- | What actually travels through an 'MVar': the message label as a
-- runtime 'SomeSymbol', and the payload with its type erased to 'Any'.
data Letter = MkLetter SomeSymbol Any
-- | One endpoint of a bidirectional channel.
--
-- The protocol index @ps@ is phantom: nothing of it is stored.  It only
-- determines which of 'send', 'recv' and 'close' accept the endpoint.
data Channel (ps :: Protocol) = MkChannel
  { sending   :: MVar Letter  -- ^ letters written by this endpoint
  , receiving :: MVar Letter  -- ^ letters read by this endpoint
  }
-- | Client side of the logging protocol, unfolded once: send any number
-- of @"LOG"@ messages carrying a 'String' (each one loops back here),
-- then a @"STOP"@ message carrying @()@ ends the session.
type Logger = Force ('Loop ('Send
  '[ 'MkPair "LOG"  ('MkPair String 'More)
   , 'MkPair "STOP" ('MkPair ()     'Done)
   ]))
-- | Send the message labelled @key@ on an endpoint whose protocol is
-- currently 'Send', and return the same endpoint re-indexed at the
-- protocol that follows that label (with a leading 'Loop' unfolded once
-- by 'Force').
--
-- 'putMVar' blocks while the previous letter has not been taken yet.
--
-- The payload is erased to 'Any' next to a runtime copy of its label;
-- the receiving side casts it back based on that label.  This is only
-- correct because the two endpoints are typed at dual protocols (as
-- 'fork' arranges), so the receiver expects @Fst (Assoc key cs)@ here.
send :: KnownSymbol key
     => Channel (Send cs)
     -> Msg cs key
     -> IO (Channel (Force (Snd (Assoc key cs))))
send (MkChannel sending receiving) (MkMsg @key msg)
  = do putMVar sending (MkLetter (SomeSymbol (Proxy @key)) (unsafeCoerce msg))
       -- Rebuild the record only to change its phantom protocol index.
       pure (MkChannel sending receiving)
-- | What 'recv' yields for the choice @cs@: the message that arrived, whose
-- label @key@ is hidden as an existential, and the endpoint advanced past
-- that branch.  Matching on 'MkReply' brings a 'KnownSymbol' dictionary
-- for @key@ into scope, so the label can be inspected at runtime.
data Reply (cs :: [MsgType]) where
  MkReply
    :: forall cs key. KnownSymbol key
    => Msg cs key
    -> Channel (Force (Snd (Assoc key cs)))
    -> Reply cs
-- | Create two connected endpoints over two fresh, empty 'MVar's: what one
-- endpoint sends, the other receives.  The second endpoint is typed at
-- the dual of the first one's protocol.
fork :: IO (Channel p, Channel (Dual p))
fork = do
  leftToRight <- newEmptyMVar
  rightToLeft <- newEmptyMVar
  let left  = MkChannel { sending = leftToRight, receiving = rightToLeft }
      right = MkChannel { sending = rightToLeft, receiving = leftToRight }
  pure (left, right)
-- | Take the next letter from this endpoint's receiving 'MVar' (blocking
-- until one is available) and turn it back into a typed 'Reply'.
--
-- The label carried by the letter becomes the existential @key@ of the
-- 'Msg'.  The payload is cast from 'Any' to the type that @cs@ associates
-- with @key@.  That cast is only correct if the peer sent a payload of
-- that type, which typing both endpoints at dual protocols (see 'fork')
-- is meant to guarantee.
recv :: Channel (Recv cs) -> IO (Reply cs)
recv (MkChannel outbox inbox) = do
  MkLetter (SomeSymbol @key _) payload <- takeMVar inbox
  pure (MkReply (MkMsg @key (unsafeCoerce payload)) (MkChannel outbox inbox))
-- | End a finished session.  This performs no runtime action; it only
-- requires the endpoint's protocol to have reached 'Done'.
close :: Channel Done -> IO ()
close = const (pure ())
-- | Send one @"LOG"@ line.  The protocol loops back to 'Logger'.
logMsg :: Channel Logger -> String -> IO (Channel Logger)
logMsg ch = send ch . MkMsg @"LOG"
-- | Send @"STOP"@, after which the endpoint's protocol is 'Done'.
logStop :: Channel Logger -> IO (Channel Done)
logStop = flip send (MkMsg @"STOP" ())
-- | Server side of the 'Logger' protocol: print every @"LOG"@ payload to
-- stdout and keep listening, until a @"STOP"@ message closes the channel.
logger :: Channel (Dual Logger) -> IO ()
logger ch = do
  MkReply (MkMsg @key msg) next <- recv ch
  -- The label is existential: compare it with each expected label at
  -- runtime to learn the payload type and the continuation protocol.
  case testEquality (symbolSing @key) (symbolSing @"LOG") of
    Just Refl -> do
      putStrLn msg
      logger next
    Nothing ->
      case testEquality (symbolSing @key) (symbolSing @"STOP") of
        Just Refl -> close next
        -- Only reachable if the peer strays from 'Logger', which offers
        -- just the two labels handled above.
        Nothing -> error "The IMPOSSIBLE has happened"
-- | Run a session: connect two endpoints with 'fork', run @kr@ on the dual
-- endpoint in a new thread, and run @kl@ on the current thread.
--
-- The result of @kl@ is returned only after @kr@ has finished as well.  A
-- GHC program exits as soon as @main@ returns, without waiting for
-- threads started with 'forkIO'.  Without this wait, the final actions of
-- @kr@ (e.g. the logger printing its last line) could be cut off.
--
-- NOTE(review): this assumes @kl@ drives the protocol to completion.  If
-- it returns early, @kr@ stays blocked waiting for input and so does this
-- call.
run :: (Channel p -> IO a)
    -> (Channel (Dual p) -> IO ())
    -> IO a
run kl kr = do
  (chl, chr) <- fork
  finished <- newEmptyMVar
  -- If kr throws, the signal is never sent.  The wait below then blocks
  -- on an MVar nobody will fill, which GHC's runtime normally reports
  -- (BlockedIndefinitelyOnMVar) instead of the failure being ignored.
  _ <- forkIO (kr chr >> putMVar finished ())
  result <- kl chl
  takeMVar finished
  pure result
-- | Run a client of the 'Logger' protocol against the 'logger' server.
withLogger :: (Channel Logger -> IO a) -> IO a
withLogger = (`run` logger)
-- | Example 'Logger' client: log "Hello" and "World", then stop and close
-- the endpoint.
count :: Channel Logger -> IO ()
count ch =
  logMsg ch "Hello" >>= flip logMsg "World" >>= logStop >>= close
-- | Client side of the summing protocol, unfolded once: either send a
-- @"SUM"@ request with an @[Int]@ and receive its @"RESULT"@ as an 'Int'
-- before looping back here, or send @"STOP"@ to end the session.
type RecSum = Force ('Loop ('Send
  '[ 'MkPair "SUM"  ('MkPair [Int] ('Recv '[ 'MkPair "RESULT" ('MkPair Int 'More) ]))
   , 'MkPair "STOP" ('MkPair ()    'Done)
   ]))
-- | Example 'RecSum' client: for every non-empty suffix of the list, longest
-- first, ask the server for its sum and print it.  Then send @"STOP"@ and
-- close the endpoint.
recSum :: [Int] -> Channel RecSum -> IO ()
recSum ns ch = case ns of
  [] ->
    send ch (MkMsg @"STOP" ()) >>= close
  _ : rest -> do
    awaiting <- send ch (MkMsg @"SUM" ns)
    MkReply (MkMsg @key total) next <- recv awaiting
    -- Learn at runtime that the reply carries the only expected label.
    case testEquality (symbolSing @key) (symbolSing @"RESULT") of
      Just Refl -> do
        putStrLn ("Sum of " ++ show ns ++ " is " ++ show total)
        recSum rest next
      Nothing -> error "The IMPOSSIBLE happened"
-- | Server side of 'RecSum': answer every @"SUM"@ request with the sum of
-- its list, until a @"STOP"@ message closes the channel.
sumRec :: Channel (Dual RecSum) -> IO ()
sumRec ch = do
  MkReply (MkMsg @key msg) next <- recv ch
  case testEquality (symbolSing @key) (symbolSing @"SUM") of
    Just Refl ->
      send next (MkMsg @"RESULT" (sum msg)) >>= sumRec
    Nothing ->
      case testEquality (symbolSing @key) (symbolSing @"STOP") of
        Just Refl -> close next
        Nothing -> error "The IMPOSSIBLE happened"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment