Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Last active March 18, 2020 17:49
Show Gist options
  • Save erikerlandson/f1726bf8a9a097edac1c18f7af8d732d to your computer and use it in GitHub Desktop.
Save erikerlandson/f1726bf8a9a097edac1c18f7af8d732d to your computer and use it in GitHub Desktop.
Monads for using break and continue in Scala for-comprehensions
object demo {
  import scala.annotation.tailrec
  import scala.util.{ Failure, Success, Try }
  import scala.util.control.NoStackTrace

  /** Names a loop so that `break`/`continue` can say which loop they target. */
  trait Label

  /**
   * Signal thrown by `break`/`continue`.
   *
   * It deliberately extends `Exception` rather than `ControlThrowable`: the
   * combinators below rely on `Try` to capture it, and `Try` only captures
   * non-fatal throwables. `NoStackTrace` skips stack-trace capture, since one of
   * these is thrown for every `continue` and nothing here inspects the trace.
   */
  sealed trait Control extends Exception with NoStackTrace
  case class Break(label: Label) extends Control
  case class Continue(label: Label) extends Control

  /** Stops the loop labelled `lab`. Never returns normally. */
  def break(lab: Label): Nothing = throw Break(lab)

  /** Abandons the current iteration of the loop labelled `lab`. Never returns normally. */
  def continue(lab: Label): Nothing = throw Continue(lab)

  /**
   * A lazily evaluated sequence whose `map`/`flatMap`/`foreach` honour
   * `break(label)` and `continue(label)` thrown from the loop body or a guard.
   *
   * Signals for any other label are passed on as the final `Failure` element of
   * the result, so an enclosing `Breakable` carrying that label can act on them.
   *
   * @param stream called once per traversal to produce the elements; a `Failure`
   *               element carries an exception raised while producing it
   * @param label  the label whose signals this instance interprets
   * @param p      guard accumulated by `withFilter`; rejected elements are skipped
   */
  class Breakable[A](stream: () => Stream[Try[A]], label: Label, p: A => Boolean) {
    import Breakable._

    /** The elements as a plain stream; signals for other labels are rethrown when reached. */
    def toStream: Stream[A] = trMap((a: A) => a, rawStream).map(_.get)

    def toIterator: Iterator[A] = toStream.iterator

    /**
     * `stream()` with the `withFilter` guard applied. If the guard throws (for
     * example it calls `continue`), the exception is kept as a `Failure` element
     * so the combinators treat it exactly like one thrown by the loop body.
     */
    private def rawStream: Stream[Try[A]] =
      stream().flatMap { t =>
        t.flatMap(a => Try(p(a))) match {
          case Success(keep) => if (keep) Stream(t) else Stream.empty[Try[A]]
          case Failure(e)    => Stream[Try[A]](Failure(e))
        }
      }

    def map[B](f: A => B): Breakable[B] =
      new Breakable[B](() => trMap(f, rawStream), noLabel, pTrue[B])

    /** Lazily applies `f` along `s`, acting on this instance's `label`. */
    private def trMap[B](f: A => B, s: Stream[Try[A]]): Stream[Try[B]] =
      nextMapped(f, s) match {
        case Some((h @ Success(_), at)) => h #:: trMap(f, at.tail)
        case Some((h, _))               => Stream(h) // foreign signal: last element
        case None                       => Stream.empty[Try[B]]
      }

    /**
     * Evaluates `f` along `s` until it yields something to emit, looping (not
     * recursing) over elements discarded by `continue(label)`.
     *
     * @return the value to emit paired with the stream positioned at its source
     *         element (the caller takes that tail lazily), or `None` once `s` is
     *         exhausted or `break(label)` is signalled
     */
    @tailrec private def nextMapped[B](f: A => B, s: Stream[Try[A]]): Option[(Try[B], Stream[Try[A]])] = {
      if (s.isEmpty) {
        None
      } else {
        s.head.map(f) match {
          case Failure(Continue(`label`))             => nextMapped(f, s.tail)
          case Failure(Break(`label`))                => None
          case h @ (Success(_) | Failure(_: Control)) => Some((h, s))
          case Failure(t)                             => throw t
        }
      }
    }

    def flatMap[B](f: A => Breakable[B]): Breakable[B] =
      new Breakable[B](() => trFlatMap(f, Stream.empty[Try[B]], () => rawStream), noLabel, pTrue[B])

    /**
     * Lazily concatenates the inner streams `f` yields along the outer stream,
     * acting on this instance's `label` wherever it is signalled.
     *
     * @param inner remainder of the current inner stream
     * @param outer outer elements not yet visited; a thunk so the outer stream's
     *              tail is only forced once the current inner stream is finished
     */
    private def trFlatMap[B](
        f: A => Breakable[B],
        inner: Stream[Try[B]],
        outer: () => Stream[Try[A]]): Stream[Try[B]] =
      nextFlat(f, inner, outer) match {
        case Some((h @ Success(_), in, out)) => h #:: trFlatMap(f, in.tail, out)
        case Some((h, _, _))                 => Stream(h) // foreign signal: last element
        case None                            => Stream.empty[Try[B]]
      }

    /**
     * Advances through inner and outer elements until one is ready to emit.
     * Outer elements that contribute nothing (an empty inner stream, or a
     * `continue(label)`) are looped over here instead of recursed through, so
     * long runs of them cannot exhaust the call stack.
     *
     * @return the element to emit, the inner stream positioned at it, and the
     *         remaining outer elements; or `None` when everything is exhausted or
     *         `break(label)` is signalled
     */
    @tailrec private def nextFlat[B](
        f: A => Breakable[B],
        inner: Stream[Try[B]],
        outer: () => Stream[Try[A]]): Option[(Try[B], Stream[Try[B]], () => Stream[Try[A]])] = {
      if (inner.nonEmpty) {
        inner.head match {
          case Failure(Continue(`label`))             => nextFlat(f, Stream.empty[Try[B]], outer)
          case Failure(Break(`label`))                => None
          case h @ (Success(_) | Failure(_: Control)) => Some((h, inner, outer))
          case Failure(t)                             => throw t
        }
      } else {
        val s = outer()
        if (s.isEmpty) {
          None
        } else {
          s.head.map(f) match {
            case Success(b)                 => nextFlat(f, b.rawStream, () => s.tail)
            case Failure(Continue(`label`)) => nextFlat(f, Stream.empty[Try[B]], () => s.tail)
            case Failure(Break(`label`))    => None
            case Failure(c: Control)        => Some((Failure[B](c), Stream.empty[Try[B]], outer))
            case Failure(t)                 => throw t
          }
        }
      }
    }

    def foreach[U](f: A => U): Unit = trForeach(f, rawStream)

    /** Applies `f` along `s`, skipping on `continue(label)` and stopping on `break(label)`. */
    @tailrec private def trForeach[U](f: A => U, s: Stream[Try[A]]): Unit =
      if (s.nonEmpty) s.head.map(f) match {
        case Success(_) | Failure(Continue(`label`)) => trForeach(f, s.tail)
        case Failure(Break(`label`))                 => ()
        // Signals for other labels (and genuine errors) propagate to the caller.
        case Failure(other)                          => throw other
      }

    /** Adds `q` to the guard; the result shares this instance's stream and label. */
    def withFilter(q: A => Boolean): Breakable[A] =
      new Breakable(stream, label, a => p(a) && q(a))
  }

  object Breakable {
    /** Label carried by the results of `map` and `flatMap`. */
    object noLabel extends Label

    /** Guard that accepts every element. */
    def pTrue[A]: A => Boolean = (_: A) => true

    /** Wraps `t` (re-evaluated on each traversal) under the label `noLabel`. */
    def apply[A](t: => Seq[A]): Breakable[A] = apply[A](t, noLabel)

    /** Wraps `t` (re-evaluated on each traversal) under the label `lab`. */
    def apply[A](t: => Seq[A], lab: Label): Breakable[A] =
      new Breakable[A](() => t.toStream.map(Success(_)), lab, pTrue[A])
  }

  object foo extends Label
  object goo extends Label
  object moo extends Label

  /** The integers `v, v + 1, v + 2, ...` as an infinite lazy stream. */
  def toInf(v: Int): Stream[Int] = v #:: toInf(v + 1)

  // Elements 2, 3, 4.
  def test1 = for {
    x <- Breakable(Vector(1, 2, 3), foo)
  } yield (x + 1)

  // Infinite: 1, 2, 3, ...
  def test2 = for {
    x <- Breakable(toInf(0), foo)
  } yield (x + 1)

  // Odd x are skipped and the loop stops once x > 20: 1, 3, ..., 21.
  def test3 = for {
    x <- Breakable(toInf(0), foo)
  } yield {
    if (x % 2 == 1) continue(foo)
    if (x > 20) break(foo)
    x + 1
  }

  // With k = 2: pairs of even x and odd y, both at most n.
  def test4(n: Int = 100, k: Int = 2) = for {
    x <- Breakable(toInf(0), foo)
    y <- Breakable(toInf(0), goo)
  } yield {
    if (x % k == 1) continue(foo)
    if (y % k == 0) continue(goo)
    if (x > n) break(foo)
    if (y > n) break(goo)
    (x, y)
  }

  // The outer loop stops at the first even x, the inner loop at the first even y.
  def test5(strm: => Stream[Int]) = for {
    x <- Breakable(strm, foo)
    y <- Breakable(strm, goo)
  } yield {
    if (x % 2 == 0) break(foo)
    if (y % 2 == 0) break(goo)
    (x, y)
  }

  // foreach form: prints (1,5), then breaks out of both loops.
  def test6 = for {
    x <- Breakable(Vector(1, 2, 3), foo)
    y <- Breakable(Vector(5, 6, 7), goo)
  } {
    if (x % 2 == 0) break(foo)
    if (y % 2 == 0) break(goo)
    println(s"${(x, y)}")
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment