@aryairani
Last active October 25, 2016 03:38
Abstracting over Monads

We're going to do some computation.

def compute(x: Int, y: Int): Int = x + y

But we want to compute over Option:

def computeOptionPairs(optX: Option[Int], optY: Option[Int]): Option[Int] =
  for {
    x <- optX // flatMap
    y <- optY // map
  } yield compute(x,y)

scala> computeOptionPairs(Some(2), Some(5))
res0: Option[Int] = Some(7)

scala> computeOptionPairs(None, Some(4))
res1: Option[Int] = None
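A Scala for-comprehension is syntactic sugar. The compiler rewrites the one above into nested flatMap and map calls, which is what the `// flatMap` and `// map` comments point at. Here is a hand-desugared equivalent, under the hypothetical name computeOptionPairsDesugared (wrapped in an object so the snippet runs on its own):

```scala
object OptionDesugared {
  def compute(x: Int, y: Int): Int = x + y

  // Same behaviour as the for-comprehension: every generator except the last
  // becomes flatMap, and the last generator plus `yield` becomes map.
  def computeOptionPairsDesugared(optX: Option[Int], optY: Option[Int]): Option[Int] =
    optX.flatMap(x => optY.map(y => compute(x, y)))
}
```

If either argument is None, flatMap or map short-circuits and the result is None.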

I just remembered that we want to compute over List too:

def computeListPairs(listX: List[Int], listY: List[Int]): List[Int] =
  for {
    x <- listX
    y <- listY
  } yield compute(x,y)

scala> computeListPairs(List(10,20,30), List(5,6))
res2: List[Int] = List(15, 16, 25, 26, 35, 36)

scala> computeListPairs(List.empty, List(5,6))
res3: List[Int] = List()
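The six-element result comes from List's flatMap: it runs the inner map once per element of listX and concatenates the resulting lists in order, so every x is paired with every y. A sketch that makes the intermediate lists visible (the object ListDesugared and its members are hypothetical names):

```scala
object ListDesugared {
  def compute(x: Int, y: Int): Int = x + y

  // Desugared form of computeListPairs.
  def computeListPairsDesugared(listX: List[Int], listY: List[Int]): List[Int] =
    listX.flatMap(x => listY.map(y => compute(x, y)))

  // One inner list per x; flatMap is this map followed by flattening.
  val perX: List[List[Int]] =
    List(10, 20, 30).map(x => List(5, 6).map(y => compute(x, y)))
}
```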

From a "don't repeat yourself" perspective, it's a shame that we have two functions, computeOptionPairs and computeListPairs, that are virtually identical but must be maintained separately.

Making matters worse, someone comes up with a new data type for you to support:

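// Defers evaluating `_a` until `force` is first called, then caches the result.
class Lazy[A](_a: => A) {
  lazy val force: A = _a
  def map[B](f: A => B): Lazy[B] = Lazy(f(force))
  def flatMap[B](f: A => Lazy[B]): Lazy[B] = Lazy(f(force).force)
}; object Lazy { // `};` keeps class and companion in one REPL input
  def apply[A](a: => A): Lazy[A] = new Lazy(a)
}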

def computeLazyPairs(lazyX: Lazy[Int], lazyY: Lazy[Int]): Lazy[Int] =
  for {
    x <- lazyX
    y <- lazyY
  } yield compute(x,y)

val lazyResult =
  computeLazyPairs(
    Lazy { println("x");  6 },
    Lazy { println("y");  7 }
  )

scala> lazyResult.force
x
y
res4: Int = 13
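Nothing is printed when lazyResult is built. The x and y lines appear only when force is called, and because force is a lazy val, each wrapped expression runs at most once. Here is a small check of both properties, assuming a hypothetical object LazyMemoDemo with an evaluations counter added only for illustration:

```scala
object LazyMemoDemo {
  // Same Lazy as above, with explicit return types.
  class Lazy[A](_a: => A) {
    lazy val force: A = _a
    def map[B](f: A => B): Lazy[B] = Lazy(f(force))
    def flatMap[B](f: A => Lazy[B]): Lazy[B] = Lazy(f(force).force)
  }
  object Lazy {
    def apply[A](a: => A): Lazy[A] = new Lazy(a)
  }

  // Hypothetical counter, only here to observe when the wrapped block runs.
  var evaluations = 0

  // Building the value and mapping over it does not run the block yet.
  val answer: Lazy[Int] = Lazy { evaluations += 1; 21 }.map(_ * 2)
}
```

The first `LazyMemoDemo.answer.force` runs the block and returns 42. A second call returns the cached 42 without incrementing the counter again.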

Now we have three nearly-identical functions to maintain. 😩

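It would be nice to write one generic function that handles all three cases, but this won't compile, because the compiler has no way of knowing that an arbitrary Fancy[Int] has map and flatMap methods: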

// Does not compile.
def computePairs[Fancy[_]](fancyX: Fancy[Int], fancyY: Fancy[Int]): Fancy[Int] =
  for {
    x <- fancyX
    y <- fancyY
  } yield compute(x,y)

A Monad abstraction lets us do exactly this: it abstracts over the map and flatMap methods, which are all that computePairs needs. We define the interface:

import language.higherKinds

trait SimpleMonad[Fancy[_]] {
  def     map[A,B](fa: Fancy[A], f: A => B): Fancy[B]
  def flatMap[A,B](fa: Fancy[A], f: A => Fancy[B]): Fancy[B]
}

and instances for our three data types:

implicit val optionMonad: SimpleMonad[Option] = new SimpleMonad[Option] {
  def     map[A, B](myOption: Option[A], f: A => B)         = myOption.map(f)
  def flatMap[A, B](myOption: Option[A], f: A => Option[B]) = myOption.flatMap(f)
}

implicit val listMonad: SimpleMonad[List] = new SimpleMonad[List] {
  def     map[A, B](myList: List[A], f: A => B)       = myList.map(f)
  def flatMap[A, B](myList: List[A], f: A => List[B]) = myList.flatMap(f)
}

implicit val lazyMonad: SimpleMonad[Lazy] = new SimpleMonad[Lazy] {
  def     map[A, B](myLazy: Lazy[A], f: A => B)       = myLazy.map(f)
  def flatMap[A, B](myLazy: Lazy[A], f: A => Lazy[B]) = myLazy.flatMap(f)
}

Now the boilerplate per supported type is a fixed cost (two short methods each), no matter how many generic functions like computePairs we write. In practice, this boilerplate is usually buried away in a library.
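Even without for-comprehension syntax, a generic function can already use these instances by asking for one as an implicit parameter and calling its methods directly. Here is a minimal sketch, assuming a hypothetical helper named computePairsExplicit. The object ExplicitInstances repeats the trait and two instances so the snippet runs on its own:

```scala
object ExplicitInstances {
  trait SimpleMonad[Fancy[_]] {
    def map[A, B](fa: Fancy[A], f: A => B): Fancy[B]
    def flatMap[A, B](fa: Fancy[A], f: A => Fancy[B]): Fancy[B]
  }

  implicit val optionMonad: SimpleMonad[Option] = new SimpleMonad[Option] {
    def map[A, B](fa: Option[A], f: A => B): Option[B] = fa.map(f)
    def flatMap[A, B](fa: Option[A], f: A => Option[B]): Option[B] = fa.flatMap(f)
  }

  implicit val listMonad: SimpleMonad[List] = new SimpleMonad[List] {
    def map[A, B](fa: List[A], f: A => B): List[B] = fa.map(f)
    def flatMap[A, B](fa: List[A], f: A => List[B]): List[B] = fa.flatMap(f)
  }

  def compute(x: Int, y: Int): Int = x + y

  // Hypothetical helper: the same nesting as the desugared for-comprehension,
  // but every call goes through the SimpleMonad instance `m`.
  def computePairsExplicit[F[_]](fancyX: F[Int], fancyY: F[Int])(implicit m: SimpleMonad[F]): F[Int] =
    m.flatMap(fancyX, (x: Int) => m.map(fancyY, (y: Int) => compute(x, y)))
}
```

This works, but calling `m.flatMap(...)` by hand is clumsy compared to a for-comprehension, which is what the implicit syntax class below addresses.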

Using an implicit class, we can add the map and flatMap methods that for-comprehensions need to any type that has a SimpleMonad instance:

implicit class SimpleMonadSyntax[M[_],A](m: M[A])(implicit monadImpl: SimpleMonad[M]) {
  def     map[B](f: A => B)   : M[B] = monadImpl.map(m, f)
  def flatMap[B](f: A => M[B]): M[B] = monadImpl.flatMap(m, f)
}
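When a for-comprehension calls flatMap or map on a value of abstract type F[Int], the compiler finds no such method and looks for an implicit conversion. SimpleMonadSyntax qualifies whenever a SimpleMonad[F] is available. Writing those conversions out by hand gives roughly the following. computePairsByHand is a hypothetical name, and the object repeats the definitions it needs so it runs on its own:

```scala
object SyntaxByHand {
  trait SimpleMonad[Fancy[_]] {
    def map[A, B](fa: Fancy[A], f: A => B): Fancy[B]
    def flatMap[A, B](fa: Fancy[A], f: A => Fancy[B]): Fancy[B]
  }

  implicit val listMonad: SimpleMonad[List] = new SimpleMonad[List] {
    def map[A, B](fa: List[A], f: A => B): List[B] = fa.map(f)
    def flatMap[A, B](fa: List[A], f: A => List[B]): List[B] = fa.flatMap(f)
  }

  implicit class SimpleMonadSyntax[M[_], A](m: M[A])(implicit monadImpl: SimpleMonad[M]) {
    def map[B](f: A => B): M[B] = monadImpl.map(m, f)
    def flatMap[B](f: A => M[B]): M[B] = monadImpl.flatMap(m, f)
  }

  def compute(x: Int, y: Int): Int = x + y

  // Roughly what `for { x <- fancyX; y <- fancyY } yield compute(x, y)` becomes
  // when F is abstract: each receiver is wrapped in SimpleMonadSyntax, whose
  // implicit parameter is filled by the SimpleMonad[F] context bound.
  def computePairsByHand[F[_]: SimpleMonad](fancyX: F[Int], fancyY: F[Int]): F[Int] =
    new SimpleMonadSyntax(fancyX).flatMap(x =>
      new SimpleMonadSyntax(fancyY).map(y => compute(x, y)))
}
```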

Now we can actually write that idealized code and use it for all three container types:

def computePairs[F[_]: SimpleMonad](fancyX: F[Int], fancyY: F[Int]): F[Int] =
  for {
    x <- fancyX
    y <- fancyY
  } yield compute(x,y)

scala> computePairs[Option](Some(2),  Some(5))
res5: Option[Int] = Some(7)

scala> computePairs[Option](None,     Some(4))
res6: Option[Int] = None

scala> computePairs[List](List(10,20,30), List(5,6))
res7: List[Int] = List(15, 16, 25, 26, 35, 36)

scala> computePairs[List](List(),         List(5,6))
res8: List[Int] = List()

scala> val lazyResult =
     |   computePairs(
     |     Lazy { println("x");  6 },
     |     Lazy { println("y");  7 }
     |   )
lazyResult: Lazy[Int] = Lazy@2a2da088

scala> lazyResult.force
x
y
res9: Int = 13
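Because computePairs depends only on SimpleMonad, supporting a fourth type means writing one more instance and nothing else. Here is a sketch using the standard library's Vector, assuming a hypothetical vectorMonad instance. The other definitions are repeated so the snippet stands alone:

```scala
object VectorInstance {
  trait SimpleMonad[Fancy[_]] {
    def map[A, B](fa: Fancy[A], f: A => B): Fancy[B]
    def flatMap[A, B](fa: Fancy[A], f: A => Fancy[B]): Fancy[B]
  }

  implicit class SimpleMonadSyntax[M[_], A](m: M[A])(implicit monadImpl: SimpleMonad[M]) {
    def map[B](f: A => B): M[B] = monadImpl.map(m, f)
    def flatMap[B](f: A => M[B]): M[B] = monadImpl.flatMap(m, f)
  }

  def compute(x: Int, y: Int): Int = x + y

  // Unchanged generic function.
  def computePairs[F[_]: SimpleMonad](fancyX: F[Int], fancyY: F[Int]): F[Int] =
    for {
      x <- fancyX
      y <- fancyY
    } yield compute(x, y)

  // Hypothetical new instance: this is all that supporting Vector requires.
  implicit val vectorMonad: SimpleMonad[Vector] = new SimpleMonad[Vector] {
    def map[A, B](myVector: Vector[A], f: A => B): Vector[B] = myVector.map(f)
    def flatMap[A, B](myVector: Vector[A], f: A => Vector[B]): Vector[B] = myVector.flatMap(f)
  }
}
```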

The End

P.S. A full Monad has one more operation that wasn't needed for our example: a way to construct an M[A] from a plain A (often called pure, point, or unit).

trait Monad[M[_]] {
  def map[A,B](fa: M[A], f: A => B): M[B]
  def flatMap[A,B](fa: M[A], f: A => M[B]): M[B]
  def construct[A](a: => A): M[A]
}

implicit val optionMonad: Monad[Option] = new Monad[Option] {
  def map[A, B](myOption: Option[A], f: A => B)             = myOption.map(f)
  def flatMap[A, B](myOption: Option[A], f: A => Option[B]) = myOption.flatMap(f)
  def construct[A](a: => A)                                 = Some(a) // not Option(a), which turns null into None
}

implicit val listMonad: Monad[List] = new Monad[List] {
  def map[A, B](myList: List[A], f: A => B)           = myList.map(f)
  def flatMap[A, B](myList: List[A], f: A => List[B]) = myList.flatMap(f)
  def construct[A](a: => A)                           = List(a)
}

implicit val lazyMonad: Monad[Lazy] = new Monad[Lazy] {
  def map[A, B](myLazy: Lazy[A], f: A => B)           = myLazy.map(f)
  def flatMap[A, B](myLazy: Lazy[A], f: A => Lazy[B]) = myLazy.flatMap(f)
  def construct[A](a: => A)                           = Lazy(a)
}
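With construct available, map no longer has to be a primitive: it can be written in terms of flatMap and construct. Here is a sketch of that idea. The object DerivedMap is a hypothetical name, and this is not necessarily how any particular library arranges its Monad:

```scala
object DerivedMap {
  trait Monad[M[_]] {
    def flatMap[A, B](fa: M[A], f: A => M[B]): M[B]
    def construct[A](a: => A): M[A]

    // map via the other two operations: apply f, wrap the result with
    // construct, and let flatMap do the rest.
    def map[A, B](fa: M[A], f: A => B): M[B] =
      flatMap(fa, (a: A) => construct(f(a)))
  }

  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def flatMap[A, B](fa: Option[A], f: A => Option[B]): Option[B] = fa.flatMap(f)
    def construct[A](a: => A): Option[A] = Some(a)
  }

  implicit val listMonad: Monad[List] = new Monad[List] {
    def flatMap[A, B](fa: List[A], f: A => List[B]): List[B] = fa.flatMap(f)
    def construct[A](a: => A): List[A] = List(a)
  }
}
```

Each instance now provides only flatMap and construct, and map comes for free from the trait.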
@lancegatlin

Nice! Seems like the start for a great blog entry =]
