Skip to content

Instantly share code, notes, and snippets.

@lbialy
Created April 11, 2022 15:57
Show Gist options
  • Save lbialy/af4864ec57ab7f60a3f2e31f85b9ff81 to your computer and use it in GitHub Desktop.
Save lbialy/af4864ec57ab7f60a3f2e31f85b9ff81 to your computer and use it in GitHub Desktop.
Monadic laws in Scala 3
//> using scala "3.1.1"
// let's start with basic functor because all monads in scala have to be functors to work
// within a for-comprehension (because it desugars to flatMap + map chain)
// F[_] represents any concrete type parameterised with a single type (a type constructor), ie List, Option
/** Type class for type constructors whose contents can be transformed.
  *
  * @tparam F a type constructor taking a single type argument, e.g. `List` or `Option`
  */
trait Functor[F[_]]:

  /** Applies `f` to the value(s) inside `fa`, keeping the surrounding `F` context.
    *
    * @param fa the wrapped input
    * @param f  the transformation applied to each wrapped value
    * @return the transformed values wrapped in `F`
    */
  def map[A, B](fa: F[A])(f: A => B): F[B]
// define a correct interface for a Monad that extends that Functor, Monad is just
// two functions: pure (return in hs) and flatMap (bind in hs)
/** Type class for monads: a [[Functor]] that can additionally lift plain values
  * into `F` and sequence computations that themselves return an `F`.
  *
  * @tparam F a type constructor taking a single type argument
  */
trait Monad[F[_]] extends Functor[F]:

  /** Lifts a plain value into `F` (called `return` in Haskell).
    *
    * @param a the value to wrap
    */
  def pure[A](a: A): F[A]

  /** Feeds the value(s) of `fa` into `f` and returns the `F` that `f` produces
    * (called `bind` in Haskell).
    *
    * @param fa the wrapped input
    * @param f  a function producing the next wrapped value
    */
  def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B]
// we also need a summoner object to use this stuff easily
/** Summoner for [[Monad]] instances, enabling the `Monad[F]` call syntax. */
object Monad:

  /** Returns the `Monad[F]` instance found in the contextual scope.
    *
    * @tparam F the type constructor whose instance is requested
    */
  def apply[F[_]](using instance: Monad[F]): Monad[F] = instance
// to make the tests easier to write, let's also define a Comonad
// Comonad is basically something that you can get stuff out of
/** Minimal "extract only" type class: anything a value can be taken out of.
  *
  * @tparam F a type constructor taking a single type argument
  */
trait Comonad[F[_]]:

  /** Takes the wrapped value out of `fa`.
    *
    * @param fa the wrapped value
    * @return the value held by `fa`
    */
  def get[A](fa: F[A]): A
// another summoner object
/** Summoner for [[Comonad]] instances, enabling the `Comonad[F]` call syntax. */
object Comonad:

  /** Returns the `Comonad[F]` instance found in the contextual scope.
    *
    * @tparam F the type constructor whose instance is requested
    */
  def apply[F[_]](using instance: Comonad[F]): Comonad[F] = instance
// now some extension methods to be able to use our abstractions directly
extension [A](a: A)

  /** Lifts this value into `F` using the `Monad[F]` instance in scope.
    *
    * @tparam F the target type constructor; a `Monad[F]` must be available
    */
  def pure[F[_]: Monad]: F[A] = summon[Monad[F]].pure(a)
// Method-call syntax for values of any `F[A]`. Because the context bounds sit on the
// extension itself, all three methods require both a `Monad[F]` and a `Comonad[F]`.
extension [F[_]: Monad: Comonad, A](fa: F[A])

  /** Sequences `f` after this value; delegates to `Monad[F].flatMap`. */
  def flatMap[B](f: A => F[B]): F[B] = summon[Monad[F]].flatMap(fa)(f)

  /** Transforms the wrapped value; delegates to `Monad[F].map`. */
  def map[B](f: A => B): F[B] = summon[Monad[F]].map(fa)(f)

  /** Takes the wrapped value out; delegates to `Comonad[F].get`. */
  def get: A = summon[Comonad[F]].get(fa)
// now let's define the completely abstract Laws class to test monadic laws
// we are going to use interfaces only to reuse it for both eager and lazy monads
/** Abstract check of the three monad laws for any `F` that has both a `Monad` and a
  * `Comonad` instance.
  *
  * Only the type-class interfaces are used, so the same checks apply to eager and lazy
  * implementations alike. Results are compared after taking them out of `F` with `get`.
  *
  * @tparam F the type constructor under test
  */
class Laws[F[_]: Monad: Comonad]:

  /** Runs all three law checks on fixed sample data.
    *
    * The first violated law fails its `assert`, with a message naming the law.
    */
  def run: Unit =
    val f: Int => F[Int] = int => pure(int * 2)
    // A second, different function: binding `f` twice would let an implementation that
    // applies the two steps in the wrong order pass the associativity check unnoticed.
    val g: Int => F[Int] = int => pure(int + 3)
    val a: Int = 2
    val m: F[Int] = pure(a)

    // left identity: pure(a).flatMap(f) === f(a)
    assert(
      get(pure(a).flatMap(f)) == get(f(a)),
      "left identity violated: pure(a).flatMap(f) != f(a)"
    )

    // right identity: m.flatMap(pure) === m
    assert(
      get(m.flatMap(pure)) == get(m),
      "right identity violated: m.flatMap(pure) != m"
    )

    // associativity: m.flatMap(f).flatMap(g) === m.flatMap(x => f(x).flatMap(g))
    assert(
      get(m.flatMap(f).flatMap(g)) == get(m.flatMap(x => f(x).flatMap(g))),
      "associativity violated: m.flatMap(f).flatMap(g) != m.flatMap(x => f(x).flatMap(g))"
    )
// ok, so we can now test monadic laws using abstract tests, let's define concrete monadic types to test
/** A suspended computation: wraps a zero-argument function and evaluates it only when
  * [[unsafeRun]] is called. Results are not cached, so every `unsafeRun` re-invokes the
  * wrapped function.
  *
  * @param a the suspended computation producing the value
  */
case class LazyMonad[A](private val a: () => A):

  /** Sequences `f` after this computation without running anything yet.
    *
    * Both the wrapped function and `f` are invoked only when the returned value is run.
    * Calling `f(a())` directly would force this computation at `flatMap` time, which
    * would make the type eager and inconsistent with [[map]].
    *
    * @param f produces the next suspended computation from this one's result
    */
  def flatMap[B](f: A => LazyMonad[B]): LazyMonad[B] =
    LazyMonad(() => f(a()).unsafeRun)

  /** Transforms the eventual result; evaluation stays deferred.
    *
    * @param f the transformation applied to the result once it is computed
    */
  def map[B](f: A => B): LazyMonad[B] = LazyMonad(() => f(a()))

  /** Runs the suspended computation and returns its result. */
  def unsafeRun: A = a()
/** An eager wrapper around a value that has already been computed.
  *
  * @param a the wrapped value
  */
case class StrictMonad[A](private val a: A):

  /** Applies `f` to the wrapped value immediately and returns what `f` produces. */
  def flatMap[B](f: A => StrictMonad[B]): StrictMonad[B] = f(a)

  /** Applies `f` to the wrapped value immediately and wraps the result again. */
  def map[B](f: A => B): StrictMonad[B] =
    flatMap(value => StrictMonad(f(value)))

  /** Returns the wrapped value. */
  def run: A = a
// ok, now we need instances for our concrete monadic types to make them Monad/Comonad typeclass members
/** `Monad` and `Comonad` evidence for [[LazyMonad]]; each operation delegates to the
  * corresponding method on `LazyMonad` itself.
  */
implicit object LazyMonadInstance extends Monad[LazyMonad] with Comonad[LazyMonad]:

  /** Wraps an already computed value in a function that simply returns it. */
  override def pure[A](a: A): LazyMonad[A] = LazyMonad(() => a)

  /** Delegates to `LazyMonad.map`. */
  override def map[A, B](fa: LazyMonad[A])(f: A => B): LazyMonad[B] = fa.map(f)

  /** Delegates to `LazyMonad.flatMap`. */
  override def flatMap[A, B](fa: LazyMonad[A])(f: A => LazyMonad[B]): LazyMonad[B] =
    fa.flatMap(f)

  /** Runs the suspended computation via `unsafeRun` to obtain its value. */
  override def get[A](fa: LazyMonad[A]): A = fa.unsafeRun
/** `Monad` and `Comonad` evidence for [[StrictMonad]]; each operation delegates to the
  * corresponding method on `StrictMonad` itself.
  */
implicit object StrictMonadInstance extends Monad[StrictMonad] with Comonad[StrictMonad]:

  /** Wraps the value directly. */
  override def pure[A](a: A): StrictMonad[A] = StrictMonad(a)

  /** Delegates to `StrictMonad.map`. */
  override def map[A, B](fa: StrictMonad[A])(f: A => B): StrictMonad[B] = fa.map(f)

  /** Delegates to `StrictMonad.flatMap`. */
  override def flatMap[A, B](fa: StrictMonad[A])(f: A => StrictMonad[B]): StrictMonad[B] =
    fa.flatMap(f)

  /** Returns the wrapped value via `run`. */
  override def get[A](fa: StrictMonad[A]): A = fa.run
// ok, let's test stuff
/** Entry point: runs the abstract law checks against `LazyMonad`, `StrictMonad`,
  * `Option` and `Future`, then prints a confirmation line.
  */
@main def main(): Unit =
  new Laws[LazyMonad].run
  new Laws[StrictMonad].run

  // Types from the standard library can be checked the same way.
  implicit object OptionInstance extends Monad[Option] with Comonad[Option]:
    def pure[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f)
    def map[A, B](fa: Option[A])(f: A => B): Option[B] = fa.map(f)
    // Very unsafe: do not define a Comonad for Option in normal code, since a value
    // cannot be taken out of every Option.
    def get[A](fa: Option[A]): A = fa.get

  new Laws[Option].run

  // Going further: Future.
  import scala.concurrent.*
  import scala.concurrent.duration.*

  implicit object FutureInstance extends Monad[Future] with Comonad[Future]:
    import ExecutionContext.Implicits.global
    def pure[A](a: A): Future[A] = Future.successful(a)
    def flatMap[A, B](fa: Future[A])(f: A => Future[B]): Future[B] = fa.flatMap(f)
    def map[A, B](fa: Future[A])(f: A => B): Future[B] = fa.map(f)
    // Very unsafe: do not define a Comonad for Future, as this blocks the current thread.
    def get[A](fa: Future[A]): A = Await.result(fa, Duration.Inf)

  // The check passes, but that is not the whole truth: Future is a pseudomonad.
  // See this answer: https://stackoverflow.com/a/27467037
  new Laws[Future].run

  println("all the stuff works just fine")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment