Created
August 29, 2015 03:49
-
-
Save charleso/189cec9eba01079f76ac to your computer and use it in GitHub Desktop.
Playing around with different ways to break out of FoldM
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.ambiata.origami | |
import scalaz._, Scalaz._ | |
/**
 * A left fold over elements of type `T` in which every step runs in the effect `M`
 * and whose final result is a `U`.
 *
 * A fold is described by an initial state (`start`), a step function (`fold`) and a
 * finaliser (`end`). The combinators below experiment with different ways of ending
 * a fold before the input has been exhausted.
 *
 * @tparam M effect in which the fold runs
 * @tparam T type of the elements being consumed
 * @tparam U type of the final result
 */
trait FoldM[M[_], T, U] { self =>
  /** Type of the intermediate state threaded through `fold`. */
  type S

  /** Initial state. */
  def start: M[S]

  /** Combines the current state `s` with the next element `t`. */
  def fold(s: S, t: T): M[S]

  /** Converts the final state into the result. */
  def end(s: S): M[U]

  /**
   * Stops, but takes over the whole monad.
   *
   * Feeds at most `i` elements to this fold. When a further element arrives, the
   * result of `self.end` is emitted on the left of the `EitherT`, which
   * short-circuits the remaining steps. Left and right both carry a `U`.
   */
  def take(i: Int)(implicit m: Monad[M]): FoldM[EitherT[M, U, ?], T, U] = new FoldM[EitherT[M, U, ?], T, U] {
    // (number of elements consumed so far, state of the underlying fold)
    type S = (Int, self.S)
    def start =
      EitherT.right(self.start.map(s0 => (0, s0)))
    def fold(s: S, t: T) =
      if (s._1 < i) EitherT.right(self.fold(s._2, t).map(next => (s._1 + 1, next)))
      else EitherT.left(self.end(s._2))
    def end(s: S) =
      EitherT.right(self.end(s._2))
  }

  /**
   * Takes but doesn't stop consuming.
   *
   * Feeds the first `i` elements to this fold; every later element is still
   * visited but leaves the state untouched.
   */
  def take2(i: Int)(implicit m: Monad[M]): FoldM[M, T, U] = new FoldM[M, T, U] {
    // (number of elements consumed so far, state of the underlying fold)
    type S = (Int, self.S)
    def start =
      self.start.map(s0 => (0, s0))
    def fold(s: S, t: T) =
      if (s._1 < i) self.fold(s._2, t).map(next => (s._1 + 1, next))
      else s.pure[M]
    def end(s: S) =
      self.end(s._2)
  }

  /**
   * Feeds elements to this fold for as long as `p` holds. The first element that
   * fails `p` marks the state as stopped; that element and all later ones are ignored.
   *
   * The state is a mutable [[Break]] that is updated in place, so a driver such as
   * `runBreak` can observe the `stopped` flag and cut the traversal short.
   */
  def takeWhile(p: T => Boolean)(implicit m: Monad[M]): FoldM[M, T, U] = new FoldM[M, T, U] {
    type S = Break[self.S]
    def start =
      self.start.map(x => Break(x, false))
    def fold(s: S, t: T) =
      if (s.stopped) s.pure[M]
      else {
        // The predicate is only evaluated while the fold is still running.
        s.stopped = !p(t)
        if (s.stopped) s.pure[M]
        else self.fold(s.value, t).map { x => s.value = x; s }
      }
    def end(s: S) =
      self.end(s.value)
  }

  /** Runs this fold and `f` side by side over the same elements, pairing their results. */
  def zip[V](f: FoldM[M, T, V])(implicit ap: Apply[M]) = new FoldM[M, T, (U, V)] {
    type S = (self.S, f.S)
    def start = ap.tuple2(self.start, f.start)
    def fold(s: S, t: T) = ap.tuple2(self.fold(s._1, t), f.fold(s._2, t))
    def end(s: S) = ap.tuple2(self.end(s._1), f.end(s._2))
  }

  /** Runs the fold over every element of `l`, in order, then finalises the state. */
  def iterate(l: List[T])(implicit m: Monad[M]): M[U] = {
    val folded: M[S] = l.foldLeft(start) { (ms, t) => ms.flatMap(s => fold(s, t)) }
    folded.flatMap(end)
  }

  /**
   * Runs the fold over `l`, but stops visiting elements as soon as the state is a
   * [[Break]] whose `stopped` flag is set.
   *
   * A state that is not a `Break` is never considered stopped, so for such folds
   * every element of `l` is visited.
   */
  def runBreak(l: List[T])(implicit m: Monad[M]): M[U] = {
    // `S` is abstract here, so the only way to see the flag is a runtime type test.
    // A typed pattern (rather than a cast) keeps folds with other state types working.
    def isStopped(s: S): Boolean = s match {
      case b: Break[_] => b.stopped
      case _           => false
    }
    def go(s: S, l: List[T]): M[S] =
      if (isStopped(s))
        s.pure[M]
      else l match {
        case Nil => s.pure[M]
        case t :: ts => fold(s, t).flatMap(s => go(s, ts))
      }
    start.flatMap(s => go(s, l).flatMap(end))
  }
  // foldableM.foldMBreak(ft)(self.asInstanceOf[FoldM[M, T, U] { type S = V \/ V }])

  /** Translates this fold into the effect `N` using the natural transformation `nat`. */
  def into[N[_]](implicit nat: M ~> N) = new FoldM[N, T, U] {
    type S = self.S
    def start = nat(self.start)
    def fold(s: S, t: T) = nat(self.fold(s, t))
    def end(s: S) = nat(self.end(s))
  }
}
/**
 * Mutable, but allows for efficient breaking.
 *
 * @param value   the wrapped state; reassigned in place as the fold progresses
 * @param stopped set to `true` once the fold should not consume any more elements
 */
case class Break[A](var value: A, var stopped: Boolean)
/** Small driver that exercises the different early-termination strategies of `FoldM`. */
object FoldMMain {

  /** Lifts an `Identity` value into any monad `M` via `point`. */
  implicit def IN[M[_] : Monad]: Identity ~> M = new (Identity ~> M) {
    def apply[A](i: Identity[A]): M[A] = Monad[M].point(i.value)
  }

  /** Fold that maps every element with `f` and combines the results using `Monoid[U]`. */
  def foldMap[M[_]: Monad, T, U: Monoid](f: T => U): FoldM[M, T, U] = new FoldM[M, T, U] {
    type S = U
    def start = Monoid[U].zero.pure[M]
    def fold(s: S, t: T) = (s |+| f(t)).pure[M]
    def end(s: S) = s.pure[M]
  }

  /** Prints the result of each experiment, one per line. */
  def main(args: Array[String]): Unit = {
    // Effect used when a short-circuiting `take` has to be zipped with a plain fold.
    type E[A] = EitherT[Identity, Int, A]

    // Plain folds, no early termination.
    val m = foldMap[Identity, Int, Int](_ + 1)
    println(m.iterate(List(1,2,3)).value)
    val m2 = foldMap[Identity, Int, List[Int]](x => List(x))
    println(m2.iterate(List(1,2,3)).value)

    // `take`: short-circuits through EitherT; both sides hold a result, hence `merge`.
    println(m2.take(2).iterate(List(1,2,3,4,5)).run.value.merge)
    println(m2.take(10).iterate(List(1,2,3,4,5)).run.value.merge)
    println(m.take(10).zip(m.take(2)).iterate(List(1,2,3,4,5)).run.value.merge)
    println(m.into(IN[E]).zip(m.take(2)).iterate(List(1,2,3,4,5)).run.value.merge)

    // `take2` / `takeWhile`: stay in the original monad and keep consuming the input.
    println(m.take2(3).zip(m.take2(2)).iterate(List(1,2,3,4,5)).value)
    println(m.takeWhile(_ < 3).zip(m.take2(2)).iterate(List(1,2,3,4,5)).value)

    // `runBreak`: stops traversing once the Break state is flagged as stopped.
    println(m.takeWhile(_ < 3).runBreak(List(1,2,3,4,5)).value)
    // We can't do this without manually wiring Break through
    // println(m.takeWhile(_ < 3).zip(m.take2(2)).runBreak(List(1,2,3,4,5)).value)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment