Skip to content

Instantly share code, notes, and snippets.

@rjsvaljean
Last active March 25, 2017 17:06
Show Gist options
  • Select an option

  • Save rjsvaljean/f6845832b6451e29a2e06d0e832da973 to your computer and use it in GitHub Desktop.

Select an option

Save rjsvaljean/f6845832b6451e29a2e06d0e832da973 to your computer and use it in GitHub Desktop.
Recursion Schemes from Scratch
// sbt build settings for the recursion-schemes talk sources that follow.
name := "typelevel-recursion-schemes-talk"
scalaVersion := "2.11.8"
// -language:higherKinds enables type-constructor parameters such as F[_] without feature warnings under -feature.
scalacOptions ++= Seq("-feature", "-language:higherKinds")
package rxvl
/**
 * Minimal functor type class: lifts a plain function `A => B` to `F[A] => F[B]`.
 */
trait Functor[F[_]] {
  // NOTE: lazy f is necessary — the schemes pass their own recursive call (e.g. `cata(alg)`)
  // as `f`; a strict parameter would evaluate it immediately and recurse forever while merely
  // constructing the scheme.
  def map[A, B](f: => A => B): F[A] => F[B]
}
/** Fixed point of a functor `F`: one layer `F[...]` whose recursive positions are again `Fix[F]`. */
case class Fix[F[_]](unFix: F[Fix[F]])
/**
 * Free-standing `map`: lifts `f` into `F` using the implicit [[Functor]] instance.
 * `f` is forwarded by name, so it is not evaluated here.
 */
object fmap {
  def apply[F[_], A, B](f: => A => B)(implicit F: Functor[F]): F[A] => F[B] = {
    val functor: Functor[F] = F
    functor map f
  }
}
/** Peels one layer off a `Fix[F]`; usable as a function value via `unFix[F]`. */
object unFix {
  def apply[F[_]]: Fix[F] => F[Fix[F]] = (fixed: Fix[F]) => fixed.unFix
}
/** Pattern functor for a list of `Int`s; `B` marks the recursive (tail) position. */
sealed trait ListF[+B] extends Product with Serializable
object ListF {
  implicit val functor: Functor[ListF] = new Functor[ListF] {
    // Only the tail slot is transformed; the Int head is carried over unchanged.
    def map[A, B](f: => A => B): ListF[A] => ListF[B] = {
      case ConsF(head, tail) => ConsF(head, f(tail))
      case NilF => NilF
    }
  }
}
/** The empty-list layer. */
case object NilF extends ListF[Nothing]
/** A cons layer: an `Int` head `a` and the recursive slot `r`. */
case class ConsF[B](a: Int, r: B) extends ListF[B]
/** Constructors and conversions for `Int` lists encoded as `Fix[ListF]`. */
object MyList {
  type T = Fix[ListF]

  /** The empty list. */
  def nil: T = {
    val layer: ListF[T] = NilF
    Fix[ListF](layer)
  }

  /** Prepends `h` to the list `t`. */
  def cons(h: Int, t: T): T = {
    val layer: ListF[T] = ConsF(h, t)
    Fix[ListF](layer)
  }

  /** Builds a list holding `as` in the given order. */
  def apply(as: Int*): MyList.T = as.foldRight(nil)((head, tail) => cons(head, tail))

  /** Converts to a standard `List[Int]` by folding with a catamorphism. */
  def toList: MyList.T => List[Int] = schemes.cata[ListF, List[Int]] {
    case ConsF(h, t) => h :: t
    case NilF => Nil
  }
}
/**
 * Pattern functor for a binary tree with `Int` leaves; `R` marks the recursive positions.
 * `NodeF` declares its fields as `(r, l)`; every use in this file treats them positionally
 * (first child, second child), and `map` preserves that order.
 */
sealed trait TreeF[R]
case class LeafF[R](a: Int) extends TreeF[R]
case class NodeF[R](r: R, l: R) extends TreeF[R]
object TreeF {
  implicit val functor: Functor[TreeF] = new Functor[TreeF] {
    def map[A, B](f: => A => B): TreeF[A] => TreeF[B] = {
      case NodeF(first, second) =>
        // Apply f to the first child before the second, matching constructor order.
        val mappedFirst = f(first)
        val mappedSecond = f(second)
        NodeF(mappedFirst, mappedSecond)
      case LeafF(value) => LeafF(value)
    }
  }
}
/**
 * Recursion schemes over `Fix[F]`, for any `F` with a [[Functor]] instance.
 *
 * Each scheme is a point-free pipeline built with `andThen`. The recursive call is always
 * handed to `fmap`, whose function argument is by-name, so constructing a scheme does not
 * recurse; recursion happens only while the returned function walks a structure. Nothing is
 * trampolined, so stack depth grows with the depth of the structure being processed.
 */
object schemes {
  /** Sends one input to both `f` and `g` and pairs the results. */
  private def both[B, C1, C2](f: B => C1, g: B => C2): B => (C1, C2) = {
    (b: B) => (f(b), g(b))
  }

  /** Eliminates an `Either`: `f` handles a `Left`, `g` handles a `Right`. */
  private def either[B1, B2, C](f: B1 => C, g: B2 => C): Either[B1, B2] => C = {
    _.fold(f, g)
  }

  /** Catamorphism: peel one layer, fold each child recursively, then combine with `alg`. */
  def cata[F[_]: Functor, B](alg: F[B] => B): Fix[F] => B =
    unFix[F] andThen fmap[F, Fix[F], B](cata(alg)) andThen alg

  /** Anamorphism: expand a seed with `coalg`, unfold each child seed recursively, wrap in `Fix`. */
  def ana[F[_]: Functor, B](coalg: B => F[B]): B => Fix[F] =
    coalg andThen fmap[F, B, Fix[F]](ana(coalg)) andThen Fix[F]

  /**
   * Paramorphism: like [[cata]], but `alg` also sees each child's subtree next to its folded
   * result. Derived from `cata` by folding into `(result, rebuilt subtree)` pairs and keeping
   * only the result at the end.
   */
  def para[F[_]: Functor, A](alg: F[(A, Fix[F])] => A): Fix[F] => A = {
    cata[F, (A, Fix[F])](both(alg, fmap[F, (A, Fix[F]), Fix[F]](_._2) andThen Fix[F])) andThen(_._1)
  }

  /**
   * Apomorphism: like [[ana]], but `coalg` may finish any child early by returning
   * `Right(subtree)`, which is spliced into the result as-is; `Left(seed)` keeps unfolding.
   * A `Right` subtree is not re-traversed, so early termination really skips that structure.
   */
  def apo[F[_]: Functor, A](coalg: A => F[Either[A, Fix[F]]]): A => Fix[F] = {
    // A def (not a val): it is only evaluated when fmap forces its by-name argument,
    // so the recursive apo call does not run while apo itself is being constructed.
    def step: Either[A, Fix[F]] => Fix[F] =
      either[A, Fix[F], Fix[F]](apo[F, A](coalg), (done: Fix[F]) => done)
    coalg andThen fmap[F, Either[A, Fix[F]], Fix[F]](step) andThen Fix[F]
  }

  /**
   * Hylomorphism: unfold a seed with `coalg` and fold the result with `alg`, fused so that
   * each layer is consumed as soon as it is produced and no intermediate `Fix[F]` is built.
   * For pure `coalg`/`alg` this equals `ana(coalg) andThen cata(alg)`; if they have side
   * effects, their calls interleave layer by layer.
   */
  def hylo[F[_]: Functor, A, B](coalg: B => F[B], alg: F[A] => A): B => A =
    coalg andThen fmap[F, B, A](hylo[F, A, B](coalg, alg)) andThen alg
}
/**
 * Extractor viewing a positive `Int` `n` as `1 + (n - 1)`, i.e. `plus(1, n - 1)`.
 *
 * Zero and negative values do not match. This keeps unfolds that peel off ones with this
 * extractor (such as a cosum coalgebra) from recursing past zero without ever reaching a
 * base case; such a seed now fails the match instead of overflowing the stack.
 */
object plus {
  def unapply(i: Int): Option[(Int, Int)] =
    if (i > 0) Some((1, i - 1)) else None
}
/** Demo entry point: runs each recursion scheme on small inputs and prints the results. */
object Main {
  import schemes._

  /** Algebra for summing a list: `NilF` is 0, `ConsF` adds its head to the summed tail. */
  val sumAlg: ListF[Int] => Int = {
    case NilF => 0
    case ConsF(h, sumSoFar) => h + sumSoFar
  }
  /** Sums a `MyList.T` by folding it with `sumAlg`. */
  val sum = cata(sumAlg)

  /**
   * Coalgebra unfolding `n` into a list of ones: 0 ends the list; otherwise the `plus`
   * extractor emits a head of 1 and continues with `n - 1`. Intended for `n >= 0`.
   */
  val cosumAlg: Int => ListF[Int] = {
    case 0 => NilF
    case plus(h, remainingSum) => ConsF(h, remainingSum)
  }
  /** Unfolds `n` into a `MyList.T` of ones. */
  val cosum = ana(cosumAlg)

  /**
   * Paramorphic sum: like `sumAlg`, but each step also receives the tail subtree it came
   * from and prints it with the running sum as a trace line before adding the head.
   */
  val sumParaAlg: ListF[(Int, Fix[ListF])] => Int = {
    case NilF => 0
    case ConsF(h, (sumSoFar, listSoFar)) =>
      println(s"Trace: $listSoFar : $sumSoFar")
      h + sumSoFar
  }
  /** Sums a `MyList.T`, printing one trace line per cons cell. */
  val traceSum = para(sumParaAlg)

  /**
   * Apomorphic variant of `cosumAlg`: when the remaining value is exactly 2 it emits a final
   * head of 2 and stops unfolding by supplying an already-built empty tail (`Right`).
   */
  val earlyTerminateCoAlg: Int => ListF[Either[Int, Fix[ListF]]] = {
    case 0 => NilF
    case 2 => ConsF(2, Right(MyList.nil))
    case plus(h, remainingSum) => ConsF(h, Left(remainingSum))
  }
  val earlyTerminateCoSum = apo(earlyTerminateCoAlg)

  /**
   * Merge-sort algebra: a leaf becomes a singleton list, a node merges its two
   * (already sorted) child lists.
   */
  def merge: TreeF[List[Int]] => List[Int] = {
    // Merges two sorted lists by unfolding the pair into a ListF and converting the result.
    def mergeLists: (List[Int], List[Int]) => List[Int] = {
      val coAlg: ((List[Int], List[Int])) => ListF[(List[Int], List[Int])] = {
        case (Nil, Nil) => NilF
        case (Nil, y :: ys) => ConsF(y, (Nil, ys))
        case (x :: xs, Nil) => ConsF(x, (xs, Nil))
        // Ties take from the left list first.
        case (x :: xs, y :: ys) if x <= y => ConsF(x, (xs, y :: ys))
        // Only reached when both lists are non-empty and the left head is greater.
        case (xs, y :: ys) => ConsF(y, (xs, ys))
      }
      Function.untupled(ana[ListF, (List[Int], List[Int])](coAlg) andThen MyList.toList)
    }
    // Named val rather than a bare `{ case ... }` block, so the literal cannot be parsed as
    // an argument applied to the preceding definition.
    val mergeLayer: TreeF[List[Int]] => List[Int] = {
      case LeafF(x) => List(x)
      case NodeF(xs, ys) => mergeLists(xs, ys)
    }
    mergeLayer
  }

  /**
   * Merge-sort coalgebra: a singleton list becomes a leaf, any other list is split in half
   * into a node. `Nil` yields `NodeF(Nil, Nil)`, whose children are again `Nil`, so unfolding
   * from `Nil` never terminates; `mSort` therefore handles the empty list itself.
   */
  def unflatten: List[Int] => TreeF[List[Int]] = {
    object half {
      def unapply[A](as: List[A]): Option[(List[A], List[A])] = Some(as.splitAt(as.length / 2))
    }
    val splitLayer: List[Int] => TreeF[List[Int]] = {
      case x :: Nil => LeafF[List[Int]](x)
      case half(xs, ys) => NodeF[List[Int]](xs, ys)
    }
    splitLayer
  }

  /** Merge sort as a hylomorphism of `unflatten` and `merge`; the empty list sorts to `Nil`. */
  val mSort: List[Int] => List[Int] = {
    val sortNonEmpty = hylo[TreeF, List[Int], List[Int]](unflatten, merge)
    // Guard needed: unflatten cannot terminate on an empty list (see above).
    xs => if (xs.isEmpty) Nil else sortNonEmpty(xs)
  }

  def main(args: Array[String]): Unit = {
    List(
      sum(MyList(1,2,3)),
      traceSum(MyList(1,2,3)),
      MyList.toList(cosum(6)),
      MyList.toList(earlyTerminateCoSum(6)),
      mSort(List(4,3,2,1))
    ).foreach(println)
    // > run
    // [info] Running rxvl.Main
    // Trace: Fix(NilF) : 0
    // Trace: Fix(ConsF(3,Fix(NilF))) : 3
    // Trace: Fix(ConsF(2,Fix(ConsF(3,Fix(NilF))))) : 5
    // 6
    // 6
    // List(1, 1, 1, 1, 1, 1)
    // List(1, 1, 1, 1, 2)
    // List(1, 2, 3, 4)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment