MonadError in cats has the following laws:
- L1: `handle(raise(e))(f) == f(e)`
- L2: `pure(a) >>= (_ => raise(e)) == raise(e)`
- L3: `raise(e) >>= (_ => pure(a)) == raise(e)`
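To see the three laws on a concrete type, here is a minimal, dependency-free sketch for `Either[String, A]`, where `Left(e)` is the only kind of failure. The trait `SimpleMonadError`, the alias `Res`, the object `EitherInstance` and the `law1`-`law3` helpers are hypothetical names for this illustration only. They are not the cats API, which calls the two operations `raiseError` and `handleErrorWith`.

```scala
// Hypothetical, dependency-free stand-ins; not the cats API
object CatsLawsDemo {
  trait SimpleMonadError[F[_], E] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
    def raise[A](e: E): F[A]
    def handle[A](fa: F[A])(f: E => F[A]): F[A]
  }

  type Res[A] = Either[String, A]

  // The "obvious" instance: Left(e) is the only kind of failure
  object EitherInstance extends SimpleMonadError[Res, String] {
    def pure[A](a: A): Res[A] = Right(a)
    def flatMap[A, B](fa: Res[A])(f: A => Res[B]): Res[B] = fa.flatMap(f)
    def raise[A](e: String): Res[A] = Left(e)
    def handle[A](fa: Res[A])(f: String => Res[A]): Res[A] = fa match {
      case Left(e)  => f(e)
      case Right(_) => fa
    }
  }

  import EitherInstance._

  // L1: handle(raise(e))(f) == f(e)
  def law1(e: String, f: String => Res[Int]): Boolean =
    handle(raise[Int](e))(f) == f(e)

  // L2: pure(a) >>= (_ => raise(e)) == raise(e)
  def law2(a: Int, e: String): Boolean =
    flatMap(pure(a))(_ => raise[Int](e)) == raise[Int](e)

  // L3: raise(e) >>= (_ => pure(a)) == raise(e)
  def law3(a: Int, e: String): Boolean =
    flatMap(raise[Int](e))(_ => pure(a)) == raise[Int](e)
}
```

For example, `CatsLawsDemo.law1("boom", e => Right(e.length))` checks L1 for one error value and one handler. A property-testing library would check the laws over many generated values instead.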
It is defined roughly like this (cats itself names the two methods `raiseError` and `handleErrorWith`):

```scala
trait MonadError[F[_], E] {
  def raise[A](e: E): F[A]
  def handle[A](fa: F[A])(f: E => F[A]): F[A]
}
```

Here is a weird monad that satisfies all the laws of MonadError:
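```scala
import cats.data.NonEmptyList

// E is some fixed error type; MonadError is the simplified trait above
sealed abstract class Weird[A]
final case class Pure[A](a: A) extends Weird[A]
// an error carries one or more values of type E
final case class Error[A](es: NonEmptyList[E]) extends Weird[A]

implicit val weirdMonadError: MonadError[Weird, E] = new MonadError[Weird, E] {
  def raise[A](e: E): Weird[A] = Error(NonEmptyList.one(e))
  def handle[A](fa: Weird[A])(f: E => Weird[A]): Weird[A] = fa match {
    case Pure(_)                   => fa
    // only the first error reaches f; the rest are dropped
    case Error(NonEmptyList(x, _)) => f(x)
  }
}
```

I propose a slightly modified MonadError that outlaws the instance above: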
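```scala
import cats.Monad

trait MonadError[F[_], E] extends Monad[F] {
  def raise[A](e: E): F[A]
  // `catch` is a reserved word in Scala, so it needs backticks
  def `catch`[A](fa: F[A]): F[Option[E]]
  def handle[A](fa: F[A])(f: E => F[A]): F[A] =
    flatMap(`catch`(fa))(_.fold(fa)(f))
}
```

It comes with the following laws: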
- L2: `pure(a) >>= (_ => raise(e)) == raise(e)`
- L3: `raise(e) >>= (_ => pure(a)) == raise(e)`
- L4: `catch(fa)` is either `pure(None)` or `pure(Some(e))`
- L5: `catch(raise(e)) == pure(Some(e))`
- L6: if `catch(fa) == pure(Some(e))`, then `raise(e) == fa`
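Here is a minimal sketch of the proposed trait for `Either[String, A]`. It derives `handle` from `catch` and checks L4-L6 on a few sample values. The names `CatchMonadError`, `Res`, `EitherInstance` and `law4`-`law6` are hypothetical. The sketch avoids a cats dependency by declaring `pure` and `flatMap` directly instead of extending `cats.Monad`.

```scala
object CatchLawsDemo {
  // Hypothetical stand-in for the proposed trait
  trait CatchMonadError[F[_], E] {
    def pure[A](a: A): F[A]
    def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
    def raise[A](e: E): F[A]
    def `catch`[A](fa: F[A]): F[Option[E]]
    // handle is derived from catch, as in the proposal
    def handle[A](fa: F[A])(f: E => F[A]): F[A] =
      flatMap(`catch`(fa))(_.fold(fa)(f))
  }

  type Res[A] = Either[String, A]

  object EitherInstance extends CatchMonadError[Res, String] {
    def pure[A](a: A): Res[A] = Right(a)
    def flatMap[A, B](fa: Res[A])(f: A => Res[B]): Res[B] = fa.flatMap(f)
    def raise[A](e: String): Res[A] = Left(e)
    // catch never fails itself: it reports the error (if any) inside pure
    def `catch`[A](fa: Res[A]): Res[Option[String]] = fa match {
      case Left(e)  => Right(Some(e))
      case Right(_) => Right(None)
    }
  }

  import EitherInstance._

  // L4: catch(fa) is pure(None) or pure(Some(e)), i.e. always a Right here
  def law4(fa: Res[Int]): Boolean = `catch`(fa) match {
    case Right(_) => true
    case Left(_)  => false
  }

  // L5: catch(raise(e)) == pure(Some(e))
  def law5(e: String): Boolean = `catch`(raise[Int](e)) == pure(Some(e))

  // L6: if catch(fa) == pure(Some(e)) then raise(e) == fa
  def law6(fa: Res[Int]): Boolean = `catch`(fa) match {
    case Right(Some(e)) => raise[Int](e) == fa
    case _              => true
  }
}
```

For `Either`, L6 holds trivially because `Left(e)` already *is* `raise(e)`. The laws only start to bite for types with richer failure values.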
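From L4-L6 we can conclude that `raise(e)` is the only exceptional `F[A]`. By L4, every `fa` has `catch(fa)` equal to either `pure(None)` (success) or `pure(Some(e))` (failure with `e`), and in the second case L6 forces `fa == raise(e)`.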
L1 follows from L5:

```
handle(raise(e))(f)
  = catch(raise(e)).flatMap(_.fold(raise(e))(f)) // definition of handle
  = pure(Some(e)).flatMap(_.fold(raise(e))(f))   // L5
  = Some(e).fold(raise(e))(f)                     // left identity
  = f(e)
```

A new theorem, `handle(pure(a))(f) == pure(a)`, follows from:
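```
handle(pure(a))(f)
  = catch(pure(a)).flatMap(_.fold(pure(a))(f)) // definition of handle
  = pure(None).flatMap(_.fold(pure(a))(f))      // catch(pure(a)) == pure(None)
  = None.fold(pure(a))(f)                        // left identity
  = pure(a)
```

Here `catch(pure(a)) == pure(None)` because, by L4, the only alternative is `pure(Some(e))`. L6 would then give `pure(a) == raise(e)`, and with L3, `pure(b) == pure(a) >>= (_ => pure(b)) == raise(e) >>= (_ => pure(b)) == raise(e)` for every `b`, which only a degenerate monad allows.

Take `catch(Error(es)) == pure(Some(es.head))`. Then `Weird` satisfies all laws except L6: `catch(Error(NonEmptyList(e1, List(e2)))) == pure(Some(e1))`, yet `raise(e1)` is a different value. I think it is worth outlawing such instances.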
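The claim about `Weird` can be checked with a small self-contained model. This is a sketch under two assumptions. First, a `(head, tail)` pair stands in for cats' `NonEmptyList` to avoid the dependency. Second, `catch` reports only the first error. The object name `WeirdDemo` and the helper `law6` are hypothetical.

```scala
// Hypothetical, dependency-free model of Weird
object WeirdDemo {
  type E = String

  sealed abstract class Weird[A]
  final case class Pure[A](a: A) extends Weird[A]
  // (head, tail) plays the role of NonEmptyList[E]
  final case class Error[A](head: E, tail: List[E]) extends Weird[A]

  def pure[A](a: A): Weird[A] = Pure(a)
  def raise[A](e: E): Weird[A] = Error(e, Nil)

  // Errors short-circuit, keeping every error they carry
  def flatMap[A, B](fa: Weird[A])(f: A => Weird[B]): Weird[B] = fa match {
    case Pure(a)     => f(a)
    case Error(h, t) => Error(h, t)
  }

  // Assumed "natural" catch: report the first error only
  def `catch`[A](fa: Weird[A]): Weird[Option[E]] = fa match {
    case Pure(_)     => Pure(None)
    case Error(h, _) => Pure(Some(h))
  }

  // handle derived from catch, as in the proposal
  def handle[A](fa: Weird[A])(f: E => Weird[A]): Weird[A] =
    flatMap(`catch`(fa))(_.fold(fa)(f))

  // L6: if catch(fa) == pure(Some(e)) then raise(e) == fa
  def law6[A](fa: Weird[A]): Boolean = `catch`(fa) match {
    case Pure(Some(e)) => raise[A](e) == fa
    case _             => true
  }
}
```

L1, L5 and `handle(pure(a))(f) == pure(a)` all hold in this model. `law6(Error("e1", List("e2")))` is `false`, because `catch` sees `"e1"` but `raise("e1")` builds `Error("e1", Nil)`.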
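L2-L6 seem to be enough to prove that if `F[_]: MonadError`, then `F[A]` can always be factored into `H[A] = P[Either[E, G[A]]]` (`Xor` in older versions of cats). Here `G[A]` is the part of `F[A]` whose values satisfy `catch(ga) == pure(None)`, and `P[A]` is the part built with `pure(a)`: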
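```scala
type G[A] = F[A] // values ga such that catch(ga) == pure(None)
type P[A] = F[A] // values pa such that pa == pure(a) for some a
type H[A] = P[Either[E, G[A]]]

// catch(fa) already supplies the outer P layer, so this is map, not flatMap
def from[A](fa: F[A]): H[A] = `catch`(fa).map(_.fold[Either[E, G[A]]](Right(fa))(e => Left(e)))
def to[A](ha: H[A]): F[A] = ha.flatMap(_.fold(e => raise[A](e), ga => ga))
```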
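For a concrete feel of the factoring, here is `from`/`to` specialised to a hypothetical choice of `F = Either[String, *]` and `E = String`. Scala's types cannot express "built with `pure`" or "`catch` gives `pure(None)`", so `P` and `G` collapse to plain `F` here. The object name `FactorDemo` is illustrative only.

```scala
// Hypothetical concrete choice: F = Either[String, *], E = String
object FactorDemo {
  type E = String
  type F[A] = Either[E, A]

  def pure[A](a: A): F[A] = Right(a)
  def raise[A](e: E): F[A] = Left(e)

  // catch reports the error, if any, inside pure
  def `catch`[A](fa: F[A]): F[Option[E]] = fa match {
    case Left(e)  => pure(Some(e))
    case Right(_) => pure(None)
  }

  // P and G are both just F; their refinements are not tracked by the types
  type H[A] = F[Either[E, F[A]]]

  def from[A](fa: F[A]): H[A] =
    `catch`(fa).map(_.fold[Either[E, F[A]]](Right(fa))(e => Left(e)))

  def to[A](ha: H[A]): F[A] =
    ha.flatMap(_.fold[F[A]](e => raise[A](e), ga => ga))
}
```

A success `Right(1)` factors as `pure(Right(Right(1)))`, i.e. `Right(Right(Right(1)))`, and a failure `Left("boom")` factors as `pure(Left("boom"))`. `to` flattens both back.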
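```
to(from(fa))
  = catch(fa).map(_.fold(Right(fa))(e => Left(e))).flatMap(_.fold(raise, identity))
  = catch(fa).flatMap(_.fold(Right(fa))(e => Left(e)).fold(raise, identity)) // fuse map and flatMap
  = catch(fa).flatMap(_.fold(fa)(e => raise(e)))
  // L4 and L6: catch(fa) is pure(None), or pure(Some(e)) with raise(e) == fa
  = fa

from(to(pure(Left(e))))
  = ...
  = pure(Left(e))

from(to(pure(Right(ga)))) // ga: G[A], i.e. catch(ga) == pure(None)
  = ...
  = pure(Right(ga))
```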
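The step in `to(from(fa))` that replaces `raise(e)` by `fa` relies on L6, and that is exactly where `Weird` breaks the factoring. The self-contained sketch below uses the same hypothetical encoding of `Weird` as a `(head, tail)` pair instead of `NonEmptyList`. It shows `to(from(fa)) == fa` holding for pure values and single errors but failing for an error that carries two values. `from(to(pure(Left(e))))` still round-trips, since it only needs L5. `WeirdFactorDemo` is an illustrative name.

```scala
// Hypothetical, dependency-free model of Weird plus the factoring maps
object WeirdFactorDemo {
  type E = String

  sealed abstract class Weird[A]
  final case class Pure[A](a: A) extends Weird[A]
  final case class Error[A](head: E, tail: List[E]) extends Weird[A]

  def pure[A](a: A): Weird[A] = Pure(a)
  def raise[A](e: E): Weird[A] = Error(e, Nil)

  def map[A, B](fa: Weird[A])(f: A => B): Weird[B] = fa match {
    case Pure(a)     => Pure(f(a))
    case Error(h, t) => Error(h, t)
  }

  def flatMap[A, B](fa: Weird[A])(f: A => Weird[B]): Weird[B] = fa match {
    case Pure(a)     => f(a)
    case Error(h, t) => Error(h, t)
  }

  // Assumed catch: report the first error only
  def `catch`[A](fa: Weird[A]): Weird[Option[E]] = fa match {
    case Pure(_)     => Pure(None)
    case Error(h, _) => Pure(Some(h))
  }

  type H[A] = Weird[Either[E, Weird[A]]]

  def from[A](fa: Weird[A]): H[A] =
    map(`catch`(fa))(_.fold[Either[E, Weird[A]]](Right(fa))(e => Left(e)))

  def to[A](ha: H[A]): Weird[A] =
    flatMap(ha)(_.fold[Weird[A]](e => raise[A](e), ga => ga))
}
```

Round-tripping `Error("e1", List("e2"))` yields `Error("e1", Nil)`. The second error is lost because `to` can only rebuild failures through `raise`.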