-
-
Save LSLeary/4484fb6bc4d96e59092d48592c162b9f to your computer and use it in GitHub Desktop.
Checked exceptions implemented by grading IO with the set of exceptions an action may throw.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds, TypeFamilies, UndecidableInstances, RoleAnnotations | |
, QuantifiedConstraints, RebindableSyntax, BlockArguments | |
, RequiredTypeArguments | |
#-} | |
module Graded ( | |
GradedAppl(..), (<*>), (<*), (*>), | |
GradedAlt(..), | |
GradedMonad(..), (>>), | |
Throws, | |
throwsToIO, ioToThrows, | |
relaxl, relaxr, | |
type (\/), idemp, | |
type (|>), type (\\), | |
throw, catch, try, | |
mask, uninterruptibleMask, | |
mask_, uninterruptibleMask_, | |
onException, finally, bracket, | |
) where | |
import Prelude qualified as Base | |
import Prelude hiding (Applicative(..), Monad(..)) | |
import Unsafe.Coerce (unsafeCoerce) | |
import Data.Kind (Type) | |
import Data.Type.Equality ((:~:)(..)) | |
import Data.Coerce (coerce) | |
import Data.Functor (($>)) | |
import Control.Exception qualified as Base | |
import Control.Exception | |
( Exception, SomeException(..), SomeAsyncException | |
, ArithException(DivideByZero), PatternMatchFail(..), NonTermination(..) | |
) | |
-- | Applicative functors whose effects are indexed by a grade of kind @k@.
--
-- Grades carry a type-level monoid structure: 'E' is the unit and 'Ap' the
-- combining operation. The three proof methods expose the monoid laws as
-- propositional equalities; each takes the functor and the grades involved
-- as required (visible) type arguments.
class (forall g. Functor (f g)) => GradedAppl (f :: k -> Type -> Type) where
  -- | Grade of an action built with 'pure'.
  type E f :: k

  -- | Grade of two actions combined with 'liftA2'.
  type Ap f (g1 :: k) (g2 :: k) :: k

  -- | 'E' is a left unit of 'Ap'.
  leftId :: forall f' g -> f' ~ f => Ap f (E f) g :~: g

  -- | 'E' is a right unit of 'Ap'.
  rightId :: forall f' g -> f' ~ f => Ap f g (E f) :~: g

  -- | 'Ap' is associative.
  assocAp
    :: forall f' g1 g2 g3 -> f' ~ f
    => Ap f (Ap f g1 g2) g3 :~: Ap f g1 (Ap f g2 g3)

  -- | Inject a value at the unit grade.
  pure :: a -> f (E f) a

  -- | Combine the results of two actions; their grades combine via 'Ap'.
  liftA2 :: (a -> b -> c) -> f g1 a -> f g2 b -> f (Ap f g1 g2) c
-- | Graded function application: apply the function produced by the first
-- action to the value produced by the second.
--
-- NOTE(review): no fixity declaration for this operator is visible in this
-- module, so it would default to @infixl 9@ rather than the @infixl 4@ of
-- the Prelude operator of the same name -- confirm this is intended.
(<*>) :: GradedAppl f => f g1 (a -> b) -> f g2 a -> f (Ap f g1 g2) b
ff <*> fx = liftA2 (\f x -> f x) ff fx
-- | Sequence two graded actions, keeping only the result of the first.
(<*) :: GradedAppl f => f g1 a -> f g2 b -> f (Ap f g1 g2) a
fa <* fb = liftA2 const fa fb
-- | Sequence two graded actions, keeping only the result of the second.
(*>) :: GradedAppl f => f g1 a -> f g2 b -> f (Ap f g1 g2) b
fa *> fb = liftA2 (\_ kept -> kept) fa fb
-- | Graded applicatives that additionally support a choice operation.
--
-- 'Alt' computes the grade of a choice between two actions and 'assocAlt'
-- witnesses that this grade computation is associative.
class GradedAppl f => GradedAlt f where
  -- | Grade of a choice between an action graded @g1@ and one graded @g2@.
  type Alt f (g1 :: k) (g2 :: k) :: k

  -- | 'Alt' is associative.
  assocAlt
    :: forall f' g1 g2 g3 -> f' ~ f
    => Alt f (Alt f g1 g2) g3 :~: Alt f g1 (Alt f g2 g3)

  -- | Choose between two actions that produce the same result type.
  (<|>) :: f g1 a -> f g2 a -> f (Alt f g1 g2) a
-- | Graded monads: sequencing in which the second action may depend on the
-- result of the first. Grades combine with the same 'Ap' used by 'liftA2'.
class GradedAppl m => GradedMonad (m :: k -> Type -> Type) where
  -- | Graded bind.
  (>>=) :: m g1 a -> (a -> m g2 b) -> m (Ap m g1 g2) b
-- | Graded sequencing that discards the result of the first action.
-- Defined through '>>='.
(>>) :: GradedMonad m => m g1 a -> m g2 b -> m (Ap m g1 g2) b
first >> second = first >>= const second
-- The grade parameter is phantom in the representation, so without this
-- annotation its role would be inferred as phantom. Making it nominal means
-- 'coerce' cannot change the exception list where the 'UnsafeThrows'
-- constructor is not in scope.
type role Throws nominal representational

-- | An 'IO' action tagged with a type-level list @es@ of exception types.
--
-- The constructor is named @Unsafe@ because wrapping an arbitrary 'IO'
-- action lets the caller pick any @es@, whether or not it is accurate.
newtype Throws (es :: [Type]) a = UnsafeThrows (IO a)
  deriving (Functor)
-- | Forget the exception list and recover the underlying 'IO' action.
throwsToIO :: Throws es a -> IO a
throwsToIO (UnsafeThrows io) = io
-- | Embed an arbitrary 'IO' action, graded with 'SomeException' as its only
-- listed exception type.
ioToThrows :: IO a -> Throws '[SomeException] a
ioToThrows = UnsafeThrows
-- | Weaken the grade of an action by joining the list @es1@, given as a
-- required type argument, on the left. The underlying 'IO' action is
-- unchanged.
relaxl :: forall es1 -> Throws es2 a -> Throws (es1 \/ es2) a
relaxl _ (UnsafeThrows io) = UnsafeThrows io
-- | Weaken the grade of an action by joining the list @es2@, given as a
-- required type argument, on the right. The underlying 'IO' action is
-- unchanged.
relaxr :: Throws es1 a -> forall es2 -> Throws (es1 \/ es2) a
relaxr (UnsafeThrows io) _ = UnsafeThrows io
-- | 'Throws' is graded by lists of exception types: the unit grade is the
-- empty list and grades combine with the list union '\/'.
instance GradedAppl Throws where
  type E  Throws         = '[]
  type Ap Throws es1 es2 = es1 \/ es2

  -- '\/' is defined by recursion on its right argument, so @'[] \/ g@ does
  -- not reduce for an abstract @g@; the equality is asserted, not derived.
  -- NOTE(review): by the equations of '\/' and '|>', @'[] \/ '[a, a]@
  -- reduces to @'[a]@, so this axiom presumably relies on grades being
  -- duplicate-free -- confirm.
  leftId _ _ = unsafeCoerce Refl

  -- Holds by the first equation of '\/', so no coercion is needed.
  rightId _ _ = Refl

  -- Asserted rather than derived: '\/' does not reduce on abstract lists.
  assocAp _ _ _ _ = unsafeCoerce Refl

  pure = UnsafeThrows . Base.pure

  liftA2 f (UnsafeThrows ia) (UnsafeThrows ib) =
    UnsafeThrows (Base.liftA2 f ia ib)
-- | Choice for 'Throws': run the first action and, if it throws, fall back
-- to the second. Because failures of the first action are handled, the
-- grade of the choice is the grade of the second action alone.
instance GradedAlt Throws where
  type Alt Throws es1 es2 = es2

  -- Both sides reduce to @g3@, so the proof is immediate.
  assocAlt _ _ _ _ = Refl

  ta1 <|> ta2 = UnsafeThrows do
    throwsToIO ta1 `Base.catch` \e -> case Base.fromException e of
      -- Exceptions wrapped in 'SomeAsyncException' (thread kills, timeouts,
      -- user interrupts, ...) are rethrown rather than treated as a failure
      -- of @ta1@: swallowing them would make the thread ignore the request
      -- and carry on with @ta2@.
      Just (Base.SomeAsyncException _) -> Base.throwIO e
      -- Any other exception selects the alternative.
      Nothing                          -> throwsToIO ta2
-- | Bind for 'Throws' is bind in 'IO'; the grades are joined by 'Ap'.
instance GradedMonad Throws where
  UnsafeThrows io >>= k =
    UnsafeThrows (io Base.>>= \a -> throwsToIO (k a))
-- | Union of two type-level lists: the elements of the right list are
-- inserted into the left list one at a time, front to back, using '|>'.
type family (xs :: [k]) \/ (ys :: [k]) :: [k] where
  acc \/ '[]       = acc
  acc \/ (y ': ys) = (acc |> y) \/ ys

infixr 3 \/
-- | Idempotence of '\/': joining a list with itself gives the same list.
--
-- The list is passed as a required type argument. The equality is asserted
-- with 'unsafeCoerce' rather than derived, because '\/' does not reduce on
-- an abstract list.
idemp :: forall es -> (es \/ es) :~: es
idemp _ = unsafeCoerce Refl
-- | Insert an element into a type-level list unless it is already present.
-- A new element is appended at the end, so the existing order is kept.
type family (es :: [k]) |> (e :: k) :: [k] where
  (x ': xs) |> x = x ': xs           -- already present: leave the list as is
  (x ': xs) |> y = x ': (xs |> y)    -- keep looking in the tail
  '[]       |> y = '[y]              -- not found: append

infixl 4 |>
-- | Remove the first occurrence of an element from a type-level list.
-- A list that does not contain the element is returned unchanged.
type family (es :: [k]) \\ (e :: k) :: [k] where
  (x ': xs) \\ x = xs                -- found: drop it and stop
  (x ': xs) \\ y = x ': (xs \\ y)    -- keep looking in the tail
  '[]       \\ y = '[]               -- nothing to remove

infixl 4 \\
-- | Throw an exception; the grade records exactly its type.
-- Implemented with 'Base.throwIO'.
throw :: Exception e => e -> Throws '[e] a
throw e = UnsafeThrows (Base.throwIO e)
-- | Handle exceptions of type @e@ thrown by an action.
--
-- The resulting grade is the grade of the action with @e@ removed, joined
-- with the grade of the handler.
catch
  :: Exception e
  => Throws es1 a
  -> (e -> Throws es2 a)
  -> Throws (es1 \\ e \/ es2) a
catch action handler =
  UnsafeThrows (Base.catch (throwsToIO action) (throwsToIO . handler))
-- | Run an action, returning an exception of type @e@ as a 'Left' value
-- instead of letting it propagate. @e@ is removed from the grade.
try
  :: Exception e
  => Throws es a
  -> Throws (es \\ e) (Either e a)
try action = UnsafeThrows (Base.try (throwsToIO action))
-- | Graded wrapper around 'Base.mask'.
--
-- The callback receives a grade-preserving version of the restore function
-- supplied by 'Base.mask'; the grade of the whole computation is the grade
-- of the callback's result.
mask
  :: ((forall a es1. Throws es1 a -> Throws es1 a) -> Throws es2 b)
  -> Throws es2 b
mask f =
  UnsafeThrows
    ( Base.mask \unmaskIO ->
        throwsToIO (f (\t -> UnsafeThrows (unmaskIO (throwsToIO t))))
    )
-- | Graded wrapper around 'Base.uninterruptibleMask'.
--
-- The grades in the signature are related by removing 'SomeAsyncException':
-- the callback must produce an action graded @es2 \\ SomeAsyncException@,
-- and the restore function it receives maps an action graded
-- @es1 \\ SomeAsyncException@ to one graded @es1@.
uninterruptibleMask
  :: ( (forall a es1. Throws (es1 \\ SomeAsyncException) a -> Throws es1 a)
     -> Throws (es2 \\ SomeAsyncException) b
     )
  -> Throws es2 b
uninterruptibleMask f =
  UnsafeThrows
    ( Base.uninterruptibleMask \unmaskIO ->
        throwsToIO (f (\t -> UnsafeThrows (unmaskIO (throwsToIO t))))
    )
-- | Like 'mask', for an action that does not need the restore function.
mask_ :: Throws es b -> Throws es b
mask_ action = mask (\_restore -> action)
-- | Like 'uninterruptibleMask', for an action that does not need the
-- restore function.
uninterruptibleMask_
  :: Throws (es2 \\ SomeAsyncException) b
  -> Throws es2 b
uninterruptibleMask_ action = uninterruptibleMask (\_restore -> action)
-- | Run the first action; if it throws any exception, run the second action
-- and then rethrow the original exception.
--
-- The result of the second action is discarded. If the second action itself
-- throws, that exception propagates instead of the original one. The grade
-- is the union of both grades.
--
-- Delegates to 'Base.onException' from "Control.Exception" instead of
-- unwrapping the caught 'SomeException' by hand, so the exception is
-- rethrown exactly as it was caught.
onException :: Throws es1 a -> Throws es2 b -> Throws (es1 \/ es2) a
onException ta tb =
  UnsafeThrows (throwsToIO ta `Base.onException` throwsToIO tb)
-- | Acquire a resource, use it, and release it afterwards.
--
-- Arguments, in order: the acquiring action, the releasing function and the
-- function that uses the resource. Everything runs under 'mask'; only the
-- use step is wrapped in @restore@. The release function runs once after a
-- successful use, or via 'onException' if the use step throws.
bracket
  :: forall es1 es2 es3 a b c
   . Throws es1 a
  -> (a -> Throws es3 c)
  -> (a -> Throws es2 b)
  -> Throws (es1 \/ es2 \/ es3) b
bracket acquire release use =
  -- The do block below is graded @es1 \/ ((es2 \/ es3) \/ es3)@.
  -- Reassociating and then collapsing @es3 \/ es3@ brings that to the
  -- grade promised by the signature.
  case assocAp Throws es2 es3 es3 of
    Refl -> case idemp es3 of
      Refl -> mask \restore -> do
        resource <- acquire
        outcome  <- restore (use resource) `onException` release resource
        outcome <$ release resource
-- | Run an action and then a finaliser, whether or not the action throws.
--
-- Implemented as a 'bracket' whose acquire step is @pure ()@. The 'leftId'
-- proof removes the empty grade that this step contributes.
finally
  :: forall es1 es2 a b
   . Throws es1 a -> Throws es2 b -> Throws (es1 \/ es2) a
finally act finalise =
  case leftId Throws (es1 \/ es2) of
    Refl -> bracket (pure ()) (\_ -> finalise) (\_ -> act)
-- | Type-level check: 'pure' has the empty grade.
_test0 :: Throws '[] ()
_test0 = pure ()
-- | Type-level check: throwing an 'ArithException' is graded with exactly
-- that type.
_test1 :: Throws '[ArithException] a
_test1 = throw (DivideByZero)
-- | Type-level check: throwing 'NonTermination' is graded with exactly that
-- type.
_test2 :: Throws '[NonTermination] a
_test2 = throw (NonTermination)
-- | Type-level check: throwing a 'PatternMatchFail' is graded with exactly
-- that type.
_test3 :: Throws '[PatternMatchFail] a
_test3 = throw $ PatternMatchFail ""
-- | Type-level check: sequencing three throwing actions accumulates their
-- grades in order. Written with explicit graded operators, nested the same
-- way a @do@ block would be.
_test4 :: Throws [ArithException, NonTermination, PatternMatchFail] (a, b)
_test4 =
  _test1 >>
    ( _test2 >>= \a ->
        ( _test3 >>= \b ->
            pure (a, b)
        )
    )
-- | Type-level check: catching 'NonTermination' removes it from the grade,
-- and the handler's grade is merged in without duplicating
-- 'ArithException'.
_test5 :: Throws [ArithException, PatternMatchFail] (a, b)
_test5 = catch _test4 (\NonTermination -> _test1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment