Skip to content

Instantly share code, notes, and snippets.

@snipsnipsnip
Last active August 24, 2019 15:24
Show Gist options
  • Save snipsnipsnip/112597 to your computer and use it in GitHub Desktop.
Save snipsnipsnip/112597 to your computer and use it in GitHub Desktop.
Yield monad
{-# LANGUAGE FlexibleInstances, MultiParamTypeClasses, UndecidableInstances, FunctionalDependencies #-}
-- | Coroutine-style \"yield\" monads.
--
-- 'Yield' is a pure computation that can suspend with a value of type @y@
-- and later be resumed; 'YieldT' is the corresponding monad transformer over
-- an arbitrary base monad. 'MonadYield' abstracts over both, and lifting
-- instances are provided for 'ReaderT', 'StateT', 'WriterT' and 'ErrorT'.
-- 'traceYieldT' runs a 'YieldT' computation to completion, printing each step.
module YieldMonad
( MonadYield (..)
, Yield (..)
, YieldT (..)
, traceYieldT
-- Re-exported so users get 'lift' / 'liftIO' without a separate import.
, module Control.Monad.Trans
) where
import Control.Applicative (Applicative (..))
import Control.Monad (ap, liftM)

import Control.Monad.Error
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Trans
import Control.Monad.Writer
-- | Monads whose computations can suspend themselves, handing a value of
-- type @y@ to whoever is running them.
--
-- The functional dependency @m -> y@ fixes a single yielded type per monad.
class (Monad m) => MonadYield y m | m -> y where
  -- | Suspend the computation, emitting the given value. When the
  -- computation is resumed it continues right after the 'yield'.
  yield
    :: y     -- ^ value handed to the runner
    -> m ()
-- | A pure computation that may suspend with values of type @y@ before
-- finishing with a result of type @a@.
--
-- 'runYield' exposes a single step: 'Left' carries the yielded value paired
-- with the remainder of the computation, 'Right' carries the final result.
newtype Yield y a = Yield { runYield :: Either (y, Yield y a) a }
  deriving (Eq, Show)
-- | Yielding suspends immediately; the remainder is a no-op, so whatever is
-- sequenced after the 'yield' (via '>>=') becomes the continuation.
instance MonadYield y (Yield y) where
  yield y = Yield (Left (y, return ()))
-- | Mapping is derived from the 'Monad' instance.
instance Functor (Yield y) where
  fmap = liftM

-- | 'Applicative' is a superclass of 'Monad' since GHC 7.10 (the
-- Applicative-Monad Proposal); without this instance the 'Monad' instance
-- below is rejected. '<*>' is derived from '>>='.
instance Applicative (Yield y) where
  pure = Yield . Right
  (<*>) = ap

-- | Sequencing: a finished computation feeds its result to the next one;
-- a suspended computation stays suspended with the same yielded value, and
-- the next computation is attached to its remainder.
instance Monad (Yield y) where
  -- Canonical definition; 'pure' is the single source of truth.
  return = pure
  Yield (Left (y, rest)) >>= f = Yield (Left (y, rest >>= f))
  Yield (Right a) >>= f = f a
-- | Transformer version of 'Yield'.
--
-- Running one step ('runYieldT') performs effects in the base monad @m@ and
-- either suspends with a yielded value plus the remainder of the
-- computation ('Left') or finishes with a result ('Right').
newtype YieldT y m a = YieldT { runYieldT :: m (Either (y, YieldT y m a) a) }
-- | Run a 'YieldT' computation to completion, printing every step to stdout.
--
-- Steps are numbered from 1. A suspended step prints the yielded value (the
-- remainder is shown as @#<YieldT>@) and is resumed immediately; the final
-- step prints the result. Yielded values are otherwise discarded.
traceYieldT :: (MonadIO m, Show y, Show a) => YieldT y m a -> m ()
traceYieldT = step (1 :: Int)
  where
    step n t = do
      r <- runYieldT t
      liftIO . putStrLn $ "traceYieldT(" ++ show n ++ "): " ++ render r
      either (step (n + 1) . snd) (const (return ())) r

    render (Left (y, _)) = "Left (" ++ show y ++ ", #<YieldT>)"
    render (Right a) = "Right " ++ show a
-- | Yielding performs no base-monad effects: the step suspends at once with
-- the given value and a no-op remainder.
instance (Monad m) => MonadYield y (YieldT y m) where
  yield y = YieldT (return (Left (y, return ())))
-- | Mapping is derived from the 'Monad' instance.
instance (Monad m) => Functor (YieldT y m) where
  fmap = liftM

-- | 'Applicative' is a superclass of 'Monad' since GHC 7.10 (the
-- Applicative-Monad Proposal); without this instance the 'Monad' instance
-- below is rejected. '<*>' is derived from '>>='.
instance (Monad m) => Applicative (YieldT y m) where
  pure = YieldT . return . Right
  (<*>) = ap

-- | A lifted base-monad action never yields: run it and finish with its
-- result.
instance MonadTrans (YieldT y) where
  lift = YieldT . liftM Right

instance (MonadIO m) => MonadIO (YieldT y m) where
  liftIO = lift . liftIO

-- | Sequencing: run one step of the first computation. If it finished, run
-- @f@ on the result; if it suspended, propagate the suspension outward and
-- attach @f@ to the remainder, so @f@ only runs once the first computation
-- has completed.
instance (Monad m) => Monad (YieldT y m) where
  -- Canonical definition; 'pure' is the single source of truth.
  return = pure
  t >>= f = YieldT $ do
    step <- runYieldT t
    case step of
      Right a -> runYieldT (f a)
      Left (y, rest) -> return (Left (y, rest >>= f))
-- Instances for other monad classes
-- | State operations act directly on the base monad's state; reading or
-- writing it never causes a suspension.
instance (MonadState s m) => MonadState s (YieldT y m) where
  get = lift get
  put s = lift (put s)
-- | Reader operations on the base monad, lifted through 'YieldT'.
--
-- 'local' applies the environment modification to the /whole/ computation,
-- including every segment that runs after the computation is resumed from a
-- 'yield', not only the segment before the first suspension.
instance (MonadReader r m) => MonadReader r (YieldT y m) where
  ask = lift ask
  local f m = YieldT $ liftM rescope (local f (runYieldT m))
    where
      -- Re-wrap the remainder so it also runs under the modified environment
      -- when resumed.
      rescope (Left (y, rest)) = Left (y, local f rest)
      rescope (Right a) = Right a
-- | Writer operations on the base monad, lifted through 'YieldT'.
--
-- 'listen' reports the output of the /whole/ computation, accumulated
-- across every segment between 'yield's.
--
-- NOTE(review): 'pass' can only apply its output transformer to the final
-- segment. Output written before a 'yield' has already been handed to the
-- base monad when the computation suspends, while the transformer is only
-- known once the computation returns. Earlier segments therefore pass
-- through unchanged; this also affects 'censor', which mtl defines via
-- 'pass'.
instance (MonadWriter w m) => MonadWriter w (YieldT y m) where
  tell = lift . tell
  listen m = YieldT $ do
    (step, w) <- listen (runYieldT m)
    return $ case step of
      -- Listen to the remainder as well and prepend this segment's output,
      -- so the final result carries the total output of @m@.
      Left (y, rest) ->
        Left (y, liftM (\(a, w') -> (a, w `mappend` w')) (listen rest))
      Right a -> Right (a, w)
  pass m = YieldT $ pass $ do
    step <- runYieldT m
    return $ case step of
      -- Suspended: this segment's output is left as-is ('id'); the
      -- transformer is applied when the remainder eventually finishes.
      Left (y, rest) -> (Left (y, pass rest), id)
      Right (a, f) -> (Right a, f)
-- | Error operations on the base monad, lifted through 'YieldT'.
--
-- 'catchError' guards the /whole/ computation: an error raised in any of its
-- segments, including those that run after the computation is resumed from
-- a 'yield', is passed to the handler.
instance (MonadError e m) => MonadError e (YieldT y m) where
  throwError = lift . throwError
  catchError m handler =
    YieldT $ catchError (liftM guardRest (runYieldT m)) (runYieldT . handler)
    where
      -- Keep guarding m's remainder with the same handler. Only m's own
      -- steps are wrapped: if the handler itself suspends, its remainder is
      -- returned as-is, so errors inside the handler are not re-caught by it.
      guardRest (Left (y, rest)) = Left (y, catchError rest handler)
      guardRest (Right a) = Right a
-- Instances for other transformers
-- | 'yield' passes straight through 'ReaderT' to the underlying yielding
-- monad.
instance (MonadYield y m) => MonadYield y (ReaderT s m) where
  yield y = lift (yield y)

-- | 'yield' passes straight through 'StateT' to the underlying yielding
-- monad.
instance (MonadYield y m) => MonadYield y (StateT s m) where
  yield y = lift (yield y)

-- | 'yield' passes straight through 'WriterT' to the underlying yielding
-- monad.
instance (MonadYield y m, Monoid w) => MonadYield y (WriterT w m) where
  yield y = lift (yield y)

-- | 'yield' passes straight through 'ErrorT' to the underlying yielding
-- monad.
instance (MonadYield y m, Error e) => MonadYield y (ErrorT e m) where
  yield y = lift (yield y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment