I have since improved and packaged the ideas in this gist; the link to the packaged version was not preserved in this copy.
Last active
March 18, 2020 17:49
-
-
Save erikerlandson/f1726bf8a9a097edac1c18f7af8d732d to your computer and use it in GitHub Desktop.
Monads for using `break` and `continue` in Scala for-comprehensions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object demo {
  import scala.util.{ Try, Success, Failure }
  import scala.annotation.tailrec

  /** Names one loop level: `break(l)` / `continue(l)` are acted on by the
    * `Breakable` generator that was created with label `l`.
    */
  trait Label

  /** Loop-control signals. They extend `Exception`, so `Try` (which catches
    * non-fatal throwables) turns them into `Failure`s that the stream stages
    * of [[Breakable]] pattern-match on.
    */
  sealed trait Control extends Exception
  case class Break(label: Label) extends Control
  case class Continue(label: Label) extends Control

  /** Throws a [[Break]] for `lab`: the stage labelled `lab` stops producing elements. */
  def break(lab: Label): Unit = { throw Break(lab) }

  /** Throws a [[Continue]] for `lab`: the stage labelled `lab` moves on to its next element. */
  def continue(lab: Label): Unit = { throw Continue(lab) }

  /** A lazily evaluated generator for for-comprehensions whose body may call
    * `break(label)` / `continue(label)`.
    *
    * Elements travel as `Try`s. A `Failure` carrying a [[Control]] signal for a
    * different label is forwarded downstream (ending this stage's output), so
    * the stage owning that label can act on it; any other failure is rethrown.
    *
    * @param stream produces the raw, unguarded elements; invoked afresh on every traversal
    * @param label  the label whose `break` / `continue` this stage handles
    * @param p      guard accumulated by `withFilter`; elements for which it is false are dropped
    */
  class Breakable[A](stream: () => Stream[Try[A]], label: Label, p: A => Boolean) {
    import Breakable._

    /** The elements as plain values; a forwarded signal or error is rethrown by `get`. */
    def toStream: Stream[A] = rawStream.map { _.get }

    def toIterator: Iterator[A] = toStream.iterator

    /** The raw elements with the `withFilter` guard applied. */
    private def rawStream: Stream[Try[A]] = guarded(stream())

    /** Applies the guard `p` to every `Success` element. A guard that throws
      * (e.g. calls `break` / `continue`) yields a `Failure` element, so
      * downstream stages treat it exactly like a signal raised by the loop body.
      * Rejected elements are skipped in a loop, using constant stack.
      */
    private def guarded(s: Stream[Try[A]]): Stream[Try[A]] = {
      @tailrec def next(rest: Stream[Try[A]]): Stream[Try[A]] =
        if (rest.isEmpty) Stream.empty[Try[A]]
        else rest.head match {
          case kept @ Success(a) =>
            Try(p(a)) match {
              case Success(true)  => kept #:: guarded(rest.tail)
              case Success(false) => next(rest.tail)
              case Failure(e)     => Failure[A](e) #:: guarded(rest.tail)
            }
          // Already-failed elements (forwarded signals) bypass the guard.
          case other => other #:: guarded(rest.tail)
        }
      next(s)
    }

    /** Lazily applies `f`, honouring `break` / `continue` for this stage's label. */
    def map[B](f: A => B): Breakable[B] =
      new Breakable(() => trMap(f, rawStream), noLabel, pTrue[B])

    private def trMap[B](f: A => B, s: Stream[Try[A]]): Stream[Try[B]] = {
      // Elements hit by `continue(label)` are skipped in a loop, so long runs of
      // them do not consume stack.
      @tailrec def next(rest: Stream[Try[A]]): Stream[Try[B]] =
        if (rest.isEmpty) Stream.empty[Try[B]]
        else rest.head.map(f) match {
          case h @ Success(_)             => h #:: trMap(f, rest.tail)
          case Failure(Continue(`label`)) => next(rest.tail)
          case Failure(Break(`label`))    => Stream.empty[Try[B]]
          case h @ Failure(_: Control)    => Stream(h) // another label's signal: forward it
          case Failure(t)                 => throw t
        }
      next(s)
    }

    /** Lazily concatenates `f(a).rawStream` for each element `a`, honouring
      * `break` / `continue` for this stage's label both in `f` and in the inner
      * elements produced by it.
      */
    def flatMap[B](f: A => Breakable[B]): Breakable[B] =
      new Breakable(() => trFlatMap(f, rawStream), noLabel, pTrue[B])

    private def trFlatMap[B](f: A => Breakable[B], s: Stream[Try[A]]): Stream[Try[B]] =
      flatStep(f, s, Stream.empty[Try[B]], false)

    /** Resumes draining `inner`, the remainder of the expansion of `outer.head`. */
    private def flatResume[B](f: A => Breakable[B], outer: Stream[Try[A]], inner: Stream[Try[B]]): Stream[Try[B]] =
      flatStep(f, outer, inner, true)

    /** Tail-recursive driver for `flatMap`.
      *
      * When `expanded` is false, `outer.head` has not yet been passed to `f`
      * and `inner` is ignored. When it is true, `inner` is what remains of
      * `f(outer.head)`. `outer.tail` is forced only after the current inner
      * stream is finished. Empty or `continue`d expansions are skipped without
      * growing the stack.
      */
    @tailrec private def flatStep[B](
        f: A => Breakable[B],
        outer: Stream[Try[A]],
        inner: Stream[Try[B]],
        expanded: Boolean): Stream[Try[B]] =
      if (expanded) {
        if (inner.isEmpty) flatStep(f, outer.tail, Stream.empty[Try[B]], false)
        else inner.head match {
          case h @ Success(_)             => h #:: flatResume(f, outer, inner.tail)
          // `continue` for this label abandons the rest of the inner stream.
          case Failure(Continue(`label`)) => flatStep(f, outer.tail, Stream.empty[Try[B]], false)
          case Failure(Break(`label`))    => Stream.empty[Try[B]]
          case h @ Failure(_: Control)    => Stream(h)
          case Failure(t)                 => throw t
        }
      } else if (outer.isEmpty) Stream.empty[Try[B]]
      else outer.head.map(f) match {
        case Success(b)                 => flatStep(f, outer, b.rawStream, true)
        case Failure(Continue(`label`)) => flatStep(f, outer.tail, Stream.empty[Try[B]], false)
        case Failure(Break(`label`))    => Stream.empty[Try[B]]
        case Failure(c: Control)        => Stream(Failure(c))
        case Failure(t)                 => throw t
      }

    /** Runs `f` on every element, honouring `break` / `continue` for this
      * stage's label; signals for other labels and other errors are rethrown.
      */
    def foreach[U](f: A => U): Unit = trForeach(f, rawStream)

    @tailrec private def trForeach[U](f: A => U, s: Stream[Try[A]]): Unit =
      if (s.nonEmpty) s.head.map(f) match {
        case Success(_)                 => trForeach(f, s.tail)
        case Failure(Continue(`label`)) => trForeach(f, s.tail)
        case Failure(Break(`label`))    => ()
        case Failure(ctrl)              => throw ctrl
      }

    /** Adds the guard `q` (after any existing guard); it is applied lazily by every traversal. */
    def withFilter(q: A => Boolean): Breakable[A] =
      new Breakable(stream, label, a => p(a) && q(a))
  }

  object Breakable {
    /** Placeholder label for stages built by `map` / `flatMap`; signals for the
      * user's labels are handled by the stage those stages were derived from.
      */
    object noLabel extends Label

    /** A guard that accepts every element. */
    def pTrue[A]: A => Boolean = (_: A) => true

    /** An unlabelled generator over `t`, which is re-evaluated on each traversal. */
    def apply[A](t: => Seq[A]) = new Breakable(() => t.toStream.map(Success(_)), noLabel, pTrue[A])

    /** A generator over `t` that handles `break(lab)` / `continue(lab)`. */
    def apply[A](t: => Seq[A], lab: Label) = new Breakable(() => t.toStream.map(Success(_)), lab, pTrue[A])
  }

  object foo extends Label
  object goo extends Label
  object moo extends Label

  /** The infinite stream v, v + 1, v + 2, ... */
  def toInf(v: Int): Stream[Int] = v #:: toInf(v + 1)

  /** Plain `map` over a finite vector. */
  def test1 = for {
    x <- Breakable(Vector(1, 2, 3), foo)
  } yield (x + 1)

  /** `map` over an infinite generator; elements are computed only when consumed. */
  def test2 = for {
    x <- Breakable(toInf(0), foo)
  } yield (x + 1)

  /** Skips odd values and stops once past 20. */
  def test3 = for {
    x <- Breakable(toInf(0), foo)
  } yield {
    if (x % 2 == 1) continue(foo)
    if (x > 20) break(foo)
    x + 1
  }

  /** Nested generators with independent `continue` / `break` labels. */
  def test4(n: Int = 100, k: Int = 2) = for {
    x <- Breakable(toInf(0), foo)
    y <- Breakable(toInf(0), goo)
  } yield {
    if (x % k == 1) continue(foo)
    if (y % k == 0) continue(goo)
    if (x > n) break(foo)
    if (y > n) break(goo)
    (x, y)
  }

  /** Nested generators over the same stream; even values break the matching loop. */
  def test5(strm: => Stream[Int]) = for {
    x <- Breakable(strm, foo)
    y <- Breakable(strm, goo)
  } yield {
    if (x % 2 == 0) break(foo)
    if (y % 2 == 0) break(goo)
    (x, y)
  }

  /** The `foreach` (no `yield`) form of a nested loop. */
  def test6 = for {
    x <- Breakable(Vector(1, 2, 3), foo)
    y <- Breakable(Vector(5, 6, 7), goo)
  } {
    if (x % 2 == 0) break(foo)
    if (y % 2 == 0) break(goo)
    println(s"${(x, y)}")
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment