Created
September 8, 2024 10:52
-
-
Save expipiplus1/cfd5c4fb4a5a40338ccf8642fb3d0f1e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Rock.Memo where | |
import Control.Concurrent.Lifted | |
import Control.Exception.Lifted | |
import Control.Monad | |
import Data.Dependent.HashMap (DHashMap) | |
import Data.Dependent.HashMap qualified as DHashMap | |
import Data.Foldable | |
import Data.GADT.Compare (GEq) | |
import Data.GADT.Show (GShow) | |
import Data.HashMap.Lazy (HashMap) | |
import Data.HashMap.Lazy qualified as HashMap | |
import Data.Hashable | |
import Data.IORef.Lifted | |
import Data.Kind (Type) | |
import Data.Maybe | |
import Data.Some | |
import Data.Typeable | |
import Effectful (Eff, IOE, raise, (:>)) | |
import Rock | |
-- * Implicit memoisation
-- | Evidence that every key of @f@ runs in an effect stack containing 'IOE'.
--
-- 'withIOE' takes a key and a computation that needs @IOE :> es@, and runs
-- that computation in the key's own effect stack @es@.
class HasIOE f where
  withIOE :: forall es a. f es a -> ((IOE :> es) => Eff es a) -> Eff es a
-- | Memoise @rules@ so that each @f@ query is run at most once per cache.
--
-- The first fetch of a key registers an empty 'MVar' for it in the 'DHashMap'
-- held by @startedVar@, runs the underlying rule and fills the 'MVar' with
-- the result. Every other fetch of the same key blocks on that 'MVar' and
-- reuses the value.
--
-- If the underlying rule throws, the key is removed from the map again before
-- the exception propagates, so that later fetches re-run the rule instead of
-- blocking forever on an 'MVar' that will never be filled. Threads that were
-- already waiting on that 'MVar' are not woken up by this.
--
-- 'withIOE' supplies the @IOE@ evidence needed to use the 'IORef' and 'MVar'.
--
-- The 'DHashMap' should typically not be reused if there has been some change
-- that might make a query return a different result.
memoise
  :: forall f
   . (forall es. GEq (f es), forall es a. Hashable (f es a), HasIOE f)
  => IORef (DHashMap (HideEffects f) MVar)
  -> Rules f
  -> Rules f
memoise startedVar rules (key :: f es a) = withIOE key $ do
  maybeValueVar <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
  case maybeValueVar of
    Nothing -> do
      valueVar <- newEmptyMVar
      -- Another thread may have registered the key since the lookup above, so
      -- decide atomically whether we are the one that runs the rule.
      join $ atomicModifyIORef startedVar $ \started ->
        case DHashMap.alterLookup (Just . fromMaybe valueVar) (HideEffects key) started of
          (Nothing, started') ->
            ( started'
            , do
                value <-
                  rules key `onException` do
                    -- Un-register the key so a failed query does not leave a
                    -- permanently empty 'MVar' behind in the cache.
                    atomicModifyIORef startedVar $ \started'' ->
                      (DHashMap.delete (HideEffects key) started'', ())
                putMVar valueVar value
                return value
            )
          (Just valueVar', _started') ->
            -- Somebody else got there first: leave the map alone and wait.
            (started, readMVar valueVar')
    Just valueVar ->
      readMVar valueVar
-- * Explicit memoisation | |
-- | Wrapper marking a query of @f@ as one that goes through a memoising
-- handler. The wrapped query runs in the original effect stack with 'IOE'
-- added on top.
data MemoQuery f es a where
  MemoQuery :: forall f es a. f es a -> MemoQuery f (IOE : es) a
-- | Handle 'MemoQuery' keys by running the underlying rule every time; nothing
-- is cached.
withoutMemoisation :: Rules f -> Rules (MemoQuery f)
withoutMemoisation rules (MemoQuery key) = raise (rules key)
-- | Memoise the queries of @f@ that are fetched through 'MemoQuery'.
--
-- The first fetch of a key registers an empty 'MVar' for it in the 'DHashMap'
-- held by @startedVar@, runs the underlying rule and fills the 'MVar' with
-- the result. Every other fetch of the same key blocks on that 'MVar' and
-- reuses the value.
--
-- If the underlying rule throws, the key is removed from the map again before
-- the exception propagates, so that later fetches re-run the rule instead of
-- blocking forever on an 'MVar' that will never be filled. Threads that were
-- already waiting on that 'MVar' are not woken up by this.
--
-- The 'DHashMap' should typically not be reused if there has been some change
-- that might make a query return a different result.
memoiseExplicit
  :: forall f
   . (forall es. GEq (f es), forall es a. Hashable (f es a))
  => IORef (DHashMap (HideEffects f) MVar)
  -> Rules f
  -> Rules (MemoQuery f)
memoiseExplicit startedVar rules (MemoQuery (key :: f es a)) = do
  maybeValueVar <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
  case maybeValueVar of
    Nothing -> do
      valueVar <- newEmptyMVar
      -- Another thread may have registered the key since the lookup above, so
      -- decide atomically whether we are the one that runs the rule.
      join $ atomicModifyIORef startedVar $ \started ->
        case DHashMap.alterLookup (Just . fromMaybe valueVar) (HideEffects key) started of
          (Nothing, started') ->
            ( started'
            , do
                value <-
                  raise (rules key) `onException` do
                    -- Un-register the key so a failed query does not leave a
                    -- permanently empty 'MVar' behind in the cache.
                    atomicModifyIORef startedVar $ \started'' ->
                      (DHashMap.delete (HideEffects key) started'', ())
                putMVar valueVar value
                return value
            )
          (Just valueVar', _started') ->
            -- Somebody else got there first: leave the map alone and wait.
            (started, readMVar valueVar')
    Just valueVar ->
      readMVar valueVar
-- | Exception carrying the query on which a dependency cycle was detected.
newtype Cyclic f = Cyclic (Some f)
  deriving (Show)

-- | Lets @'Cyclic' f@ be thrown and caught as an exception.
instance (GShow f, Typeable f) => Exception (Cyclic (f :: Type -> Type))
-- | State of one query in the cache used by 'memoiseWithCycleDetection'.
data MemoEntry a
  = -- | The query is being computed. Fields: the thread running it, the
    -- 'MVar' that will receive its outcome ('Nothing' tells waiters to retry),
    -- and the threads currently waiting on it ('Nothing' once the owner has
    -- finished and no more waiters are recorded).
    Started !ThreadId !(MVar (Maybe a)) !(MVar (Maybe [ThreadId]))
  | -- | The query has finished with this value.
    Done !a
-- | Like 'memoiseExplicit', but throw @'Cyclic'@ if a query depends on itself,
-- directly or indirectly.
--
-- @startedVar@ holds one 'MemoEntry' per query. @depsVar@ maps a waiting
-- thread to the thread it is waiting on; a cycle in that map means a query is
-- (transitively) waiting for itself.
--
-- If computing a query fails with /any/ exception, its entry is removed from
-- @startedVar@ and waiters are signalled with 'Nothing', which makes them retry
-- the query themselves. Previously this only happened for 'Cyclic'; any other
-- exception left a 'Started' entry whose 'MVar' was never filled, so waiters
-- blocked forever.
--
-- The 'HashMap' represents dependencies between threads and should not be
-- reused between invocations.
memoiseWithCycleDetection
  :: forall f
   . ( Typeable f
     , forall es a. Show (f es a)
     , forall es. GEq (f es)
     , forall es a. Hashable (f es a)
     )
  => IORef (DHashMap (HideEffects f) MemoEntry)
  -> IORef (HashMap ThreadId ThreadId)
  -> Rules f
  -> Rules (MemoQuery f)
memoiseWithCycleDetection startedVar depsVar rules = rules'
  where
    rules' (MemoQuery (key :: f es a)) = do
      maybeEntry <- DHashMap.lookup (HideEffects key) <$> readIORef startedVar
      case maybeEntry of
        Nothing -> do
          threadId <- myThreadId
          valueVar <- newEmptyMVar
          waitVar <- newMVar $ Just []
          -- Another thread may have registered the key since the lookup above,
          -- so decide atomically whether we are the one that runs the rule.
          join $ atomicModifyIORef startedVar $ \started ->
            case DHashMap.alterLookup (Just . fromMaybe (Started threadId valueVar waitVar)) (HideEffects key) started of
              (Nothing, started') ->
                ( started'
                , ( do
                      value <- raise $ rules key
                      -- Close the waiting list and drop the dependency edges
                      -- of every thread that was waiting on us.
                      join $ modifyMVar waitVar $ \maybeWaitingThreads -> do
                        case maybeWaitingThreads of
                          Nothing ->
                            error "Rock.Memo.memoiseWithCycleDetection: impossible, waiting list was already closed"
                          Just waitingThreads ->
                            return
                              ( Nothing
                              , atomicModifyIORef depsVar $ \deps ->
                                  ( foldl' (flip HashMap.delete) deps waitingThreads
                                  , ()
                                  )
                              )
                      atomicModifyIORef startedVar $ \started'' ->
                        (DHashMap.insert (HideEffects key) (Done value) started'', ())
                      putMVar valueVar $ Just value
                      return value
                  )
                    `onException` do
                      -- Forget the failed attempt and tell waiters to retry.
                      atomicModifyIORef startedVar $ \started'' ->
                        (DHashMap.delete (HideEffects key) started'', ())
                      -- 'tryPutMVar' so the cleanup can never block, even if
                      -- the 'MVar' has already been filled.
                      tryPutMVar valueVar Nothing
                )
              (Just entry, _started') ->
                (started, waitFor entry)
        Just entry -> waitFor entry
      where
        waitFor entry =
          case entry of
            Started onThread valueVar waitVar -> do
              threadId <- myThreadId
              modifyMVar_ waitVar $ \maybeWaitingThreads -> do
                case maybeWaitingThreads of
                  -- The owner has already finished; nothing to record.
                  Nothing ->
                    return maybeWaitingThreads
                  Just waitingThreads -> do
                    -- Record "this thread waits on the owner" unless that
                    -- edge would close a cycle, in which case throw instead.
                    join $ atomicModifyIORef depsVar $ \deps -> do
                      let deps' = HashMap.insert threadId onThread deps
                      if detectCycle threadId deps'
                        then
                          ( deps
                          , throwIO $ Cyclic $ Some (HideEffects key)
                          )
                        else
                          ( deps'
                          , return ()
                          )
                    return $ Just $ threadId : waitingThreads
              maybeValue <- readMVar valueVar
              -- 'Nothing' means the owner failed: retry the query ourselves.
              maybe (rules' (MemoQuery key)) return maybeValue
            Done value ->
              return value

    -- Follow the "waits on" edges starting at @threadId@; report a cycle if
    -- the walk comes back to @threadId@.
    detectCycle threadId deps =
      go threadId
      where
        go tid =
          case HashMap.lookup tid deps of
            Nothing -> False
            Just dep
              | dep == threadId -> True
              | otherwise -> go dep
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE UndecidableInstances #-} | |
{-# OPTIONS_GHC -Wno-orphans #-} | |
module Rock where | |
import Data.Dependent.HashMap (DHashMap) | |
import Data.Dependent.HashMap qualified as DHashMap | |
import Data.GADT.Compare (GEq, geq) | |
import Data.GADT.Show (GShow (gshowsPrec)) | |
import Data.Hashable | |
import Data.IORef.Lifted | |
import Data.Kind (Type) | |
import Data.Some | |
import Data.Typeable | |
import Effectful (Dispatch (Static), DispatchOf, Eff, Effect, IOE, Subset, inject, raise, (:>)) | |
import Effectful.Dispatch.Static (SideEffects (NoSideEffects), StaticRep, evalStaticRep, getStaticRep) | |
import Effectful.Timeout (Timeout, timeout) | |
import Unsafe.Coerce (unsafeCoerce) | |
-- * Types | |
-- | A handler for the queries of @f@: for any key, a computation in that key's
-- own effect stack producing that key's result.
type Rules f = forall a es. f es a -> Eff es a
-- | The effect of being able to fetch queries of @f@.
data Rock (f :: [Effect] -> Type -> Type) :: Effect

-- | 'Rock' is statically dispatched and declared free of side effects.
type instance DispatchOf (Rock f) = Static NoSideEffects

-- | The representation of 'Rock' is simply the rules used to answer queries.
newtype instance StaticRep (Rock f) = Rock (Rules f)
-- | Discharge the 'Rock' effect by answering every fetched query with @rules@.
runRock :: Rules f -> Eff (Rock f : es) a -> Eff es a
runRock rules action = evalStaticRep (Rock rules) action
-- | Answer a query using the rules installed for @Rock f@. The rule runs in
-- the key's effect stack @xs@ and is injected into the caller's stack @es@.
fetch :: (Subset xs es, Rock f :> es) => f xs a -> Eff es a
fetch key =
  getStaticRep >>= \(Rock rules) ->
    inject (rules key)
-- * Running tasks | |
-- | Wrapper marking a query of @f@ as one that runs under a time limit. It
-- adds 'Timeout' to the effect stack and wraps the result in 'Maybe'.
data TimeoutQuery f es a where
  TimeoutQuery :: forall f es a. f es a -> TimeoutQuery f (Timeout : es) (Maybe a)
-- | Handle 'TimeoutQuery' keys by running the underlying rule with a time
-- limit of one second (1000000 microseconds); the result is whatever
-- 'timeout' returns.
--
-- See 'timeoutRulesAfter' to choose a different limit.
timeoutRules :: Rules f -> Rules (TimeoutQuery f)
timeoutRules rules query = timeoutRulesAfter 1000000 rules query

-- | Like 'timeoutRules', but with the limit (the first argument, passed
-- straight to 'timeout') chosen by the caller.
timeoutRulesAfter :: Int -> Rules f -> Rules (TimeoutQuery f)
timeoutRulesAfter limit rules (TimeoutQuery key) =
  -- The rule runs in the key's own stack; 'raise' lifts it under 'Timeout'.
  timeout limit (raise (rules key))
------------------------------------------------------------------------------- | |
-- * Task combinators | |
-- | |
-- | Wrapper for a query of @f@ whose handler additionally needs 'IOE'; the
-- wrapped query runs in the original effect stack with 'IOE' added on top.
data IOQuery f es a where
  IOQuery :: forall f es a. f es a -> IOQuery f (IOE : es) a
-- | Run a computation and return, next to its result, a 'DHashMap' built from
-- the recorded queries. Each recorded query and its value are turned into a
-- map entry by the given pure function.
--
-- This is 'trackM' specialised to a pure summarising function.
track
  :: forall f es k g a
   . (GEq k, Hashable (Some k), IOE :> es, Rock f :> es)
  => (forall es' a'. f es' a' -> a' -> (k a', g a'))
  -> Eff es a
  -> Eff es (a, DHashMap k g)
track summarise task =
  trackM (\key value -> pure (summarise key value)) task
-- | Run a computation with a recording handler for @'IOQuery' f@ installed on
-- top of the existing @Rock f@ handler, and return the computation's result
-- together with the 'DHashMap' of everything the recording handler saw.
--
-- For every 'IOQuery' it answers, the recording handler fetches the value with
-- the underlying rules, turns key and value into a map entry with the given
-- function, and inserts that entry into the map.
--
-- NOTE(review): @task@ is only lifted with 'raise', so fetches it performs
-- through @Rock f@ do not appear to go through the recording handler —
-- confirm that dependencies are actually recorded for such a task.
trackM
  :: forall f es k g a
   . (GEq k, Hashable (Some k), IOE :> es, Rock f :> es)
  => (forall es' a'. f es' a' -> a' -> Eff es' (k a', g a'))
  -> Eff es a
  -> Eff es (a, DHashMap k g)
trackM summarise task = do
  depsVar <- newIORef mempty
  let
    -- Wrap the underlying rules so that each answered query is also stored.
    record'
      :: ( (forall a' es'. f es' a' -> Eff es' a')
           -> (forall a' es'. (IOQuery f) es' a' -> Eff es' a')
         )
    record' fetch' (IOQuery key) = do
      value <- raise (fetch' key)
      (k, g) <- raise (summarise key value)
      atomicModifyIORef depsVar $ \deps -> (DHashMap.insert k g deps, ())
      pure value
  result <- transRock record' (raise task)
  (,) result <$> readIORef depsVar
-- | Discharge a @Rock g@ effect by deriving its rules from the @Rock f@ rules
-- that are already installed in the surrounding effect stack.
transRock
  :: forall f g es a
   . (Rock f :> es)
  => ( (forall a' es'. f es' a' -> Eff es' a')
       -> (forall a' es'. g es' a' -> Eff es' a')
     )
  -> Eff (Rock g : es) a
  -> Eff es a
transRock transform action = do
  Rock rules <- getStaticRep @(Rock f)
  evalStaticRep (Rock (transform rules)) action
-- * Utils | |
-- | A GADT that forgets the effects required by each key.
--
-- The 'GEq' and 'Eq' instances @unsafeCoerce@ away the information about the
-- effects; please don't rely on it.
--
-- This exists so that query keys can be used as map keys.
type HideEffects :: ([Effect] -> Type -> Type) -> Type -> Type
data HideEffects f a where
  HideEffects :: forall f es a. f es a -> HideEffects f a
-- | Show a hidden key as the constructor name applied to the wrapped key.
instance (forall es a. Show (f es a)) => GShow (HideEffects f) where
  gshowsPrec prec (HideEffects key) =
    showParen (prec >= 11) $
      showString "HideEffects " . showsPrec 11 key
-- | Compare two hidden keys with the 'GEq' instance of the underlying key
-- type.
--
-- The second key is @unsafeCoerce@d to the first key's effect list so that
-- both sides have the same type; the effect lists themselves are not compared.
instance (forall es. GEq (f es)) => GEq (HideEffects f) where
  geq (HideEffects (x :: f es a)) (HideEffects (y :: f fs b)) =
    geq x (unsafeCoerce y :: f es b)
-- | Equality of the wrapped keys. The right-hand key is @unsafeCoerce@d to the
-- left-hand key's type, so the effect lists are not compared.
instance (forall es. Hashable (f es a)) => Eq (HideEffects f a) where
  (==) (HideEffects x) (HideEffects y) = x == unsafeCoerce y
-- | A hidden key hashes exactly like the key it wraps.
instance (forall es. Hashable (f es a)) => Hashable (HideEffects f a) where
  hashWithSalt salt (HideEffects key) = hashWithSalt salt key
-- | Orphan instance: an existentially wrapped key hashes like the key inside.
instance (forall a. Hashable (f a), GEq f) => Hashable (Some f) where
  hashWithSalt salt wrapped = withSome wrapped (hashWithSalt salt)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskell #-} | |
{-# HLINT ignore "Use id" #-} | |
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | |
module Rock.Test where | |
import Data.GADT.Compare.TH | |
import Data.GADT.Show.TH | |
import Data.Hashable | |
import Data.IORef.Lifted | |
import Data.Typeable | |
import Effectful (Eff, IOE, (:>)) | |
import Effectful.Timeout (runTimeout) | |
import Generics.Kind.Derive.Hashable | |
import Generics.Kind.TH (deriveGenericK) | |
import Rock | |
import Rock.Memo | |
-- | Example queries for the explicit-memoisation tests. Each constructor fixes
-- the effect stack its rule runs in and the type of its result.
data Query es a where
  -- | Needs to fetch memoised 'Query's and 'Query2's, and to do IO.
  QueryInt :: Query '[Rock (MemoQuery Query), IOE, Rock Query2] Int
  -- | Only needs IO.
  QueryString :: Query '[IOE] String
-- | A second, independent family of example queries.
data Query2 es a where
  -- | Needs no effects at all.
  Query2Bool :: Bool -> Query2 '[] Bool
-- Stock instances for 'Query'.
deriving instance Eq (Query es a)
deriving instance Typeable (Query es a)
deriving instance Show (Query es a)

-- Template Haskell derived instances; the order of these splices is kept.
deriveGenericK ''Query
deriveGEq ''Query
deriveGCompare ''Query
deriveGShow ''Query

-- | Hash a 'Query' with the generic implementation 'ghashWithSalt''.
instance Hashable (Query es a) where
  hashWithSalt = ghashWithSalt'
-- | Example queries for the implicit-memoisation tests; like 'Query', except
-- that 'QueryInt'' fetches plain @Query'@ keys rather than 'MemoQuery' ones.
data Query' es a where
  -- | Needs to fetch 'Query''s and 'Query2's, and to do IO.
  QueryInt' :: Query' '[Rock Query', IOE, Rock Query2] Int
  -- | Only needs IO.
  QueryString' :: Query' '[IOE] String
-- Stock instances for @Query'@.
deriving instance Eq (Query' es a)
deriving instance Typeable (Query' es a)
deriving instance Show (Query' es a)

-- Template Haskell derived instances; the order of these splices is kept.
deriveGenericK ''Query'
deriveGEq ''Query'
deriveGCompare ''Query'
deriveGShow ''Query'

-- | Hash a @Query'@ with the generic implementation 'ghashWithSalt''.
instance Hashable (Query' es a) where
  hashWithSalt = ghashWithSalt'
-- | Every @Query'@ constructor lists 'IOE' in its effect stack, so matching on
-- the key is enough to run the IO-requiring computation unchanged.
instance HasIOE Query' where
  withIOE QueryInt' action = action
  withIOE QueryString' action = action
-- | Rules for 'Query' that go through 'MemoQuery' explicitly.
--
-- 'QueryInt' fetches 'QueryString' twice and returns the combined length of
-- the two results; 'QueryString' logs a message and returns @"hello"@.
testExplicitMemo :: Rules Query
testExplicitMemo = \case
  QueryInt -> do
    first <- fetch (MemoQuery QueryString)
    second <- fetch (MemoQuery QueryString)
    pure (length first + length second)
  QueryString -> do
    sayErr "Querying String"
    pure "hello"
-- | Rules for @Query'@, fetching its dependencies directly.
--
-- 'QueryInt'' fetches 'QueryString'' twice and @'Query2Bool' False@ once, and
-- returns the length of the first string if the boolean is 'True', otherwise
-- of the second; 'QueryString'' logs a message and returns @"hello"@.
testRules' :: Rules Query'
testRules' = \case
  QueryInt' -> do
    first <- fetch QueryString'
    second <- fetch QueryString'
    useFirst <- fetch (Query2Bool False)
    let chosen = if useFirst then first else second
    pure (length chosen)
  QueryString' -> do
    sayErr "Querying String"
    pure "hello"
-- | Rules for 'Query2': a 'Query2Bool' is answered with the negated flag.
test2Rules :: Rules Query2
test2Rules (Query2Bool flag) = pure (not flag)
-- | Fetch 'QueryInt'' with 'testRules'' and 'test2Rules' installed and no
-- memoisation.
test :: (IOE :> es) => Eff es Int
test =
  runRock testRules' (runRock test2Rules (fetch QueryInt'))
-- test :: (IOE :> es) => Eff es Int | |
-- test = do | |
-- memMap <- newIORef mempty | |
-- memThreadMap <- newIORef mempty | |
-- runRock testRules | |
-- . runRock (memoiseWithCycleDetection memMap memThreadMap testRules) | |
-- . runRock test2Rules | |
-- $ fetch (MemoQuery QueryInt) | |
-- | |
-- | Fetch 'QueryInt'' with 'testRules'' wrapped in 'memoise', using a fresh,
-- empty cache.
testImplicitMemo :: (IOE :> es) => Eff es Int
testImplicitMemo = do
  cache <- newIORef mempty
  runRock
    (memoise cache testRules')
    (runRock test2Rules (fetch QueryInt'))
-- | Fetch 'QueryInt' through 'TimeoutQuery', with the plain, non-memoising,
-- timeout and 'Query2' handlers all installed, and the 'Timeout' effect
-- discharged outermost.
test2 :: (IOE :> es) => Eff es (Maybe Int)
test2 =
  runTimeout
    . runRock testExplicitMemo
    . runRock (withoutMemoisation testExplicitMemo)
    . runRock (timeoutRules testExplicitMemo)
    . runRock test2Rules
    $ fetch (TimeoutQuery QueryInt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment