Skip to content

Instantly share code, notes, and snippets.

@LMnet
Last active May 26, 2020 15:37
Show Gist options
  • Save LMnet/fc95146418caeeb8796decadc2181e65 to your computer and use it in GitHub Desktop.
Save LMnet/fc95146418caeeb8796decadc2181e65 to your computer and use it in GitHub Desktop.
StreamPriorityUtils
import cats.Applicative
import cats.data.NonEmptyList
import cats.effect.Concurrent
import cats.effect.concurrent.{Deferred, Ref}
import cats.implicits._
import fs2.concurrent.{InspectableQueue, SignallingRef}
import fs2.{Chunk, CompositeFailure, Scope, Stream}
object StreamPriorityUtils {

  /**
    * Nondeterministically merges streams to a single stream.
    * Order of the streams is important: left streams will have more priority.
    * If there is some data in the leftmost stream, the resulting stream will produce this data.
    * When leftmost stream hangs, resulting stream will try to get data from the next stream and so on.
    *
    * Every stream has a buffer for messages,
    * to prefetch some data from all input streams before prioritized getting data.
    *
    * The resulting stream terminates once every input stream has finished and all buffers are drained.
    * If an input stream fails, the other input streams are interrupted, the data that is already
    * buffered is emitted, and then the resulting stream fails with that error (several errors are
    * combined into a `CompositeFailure`).
    *
    * @param streamsWithBuffers pairs of (input stream, size of its prefetch buffer),
    *                           ordered from the highest priority to the lowest
    * @return a stream emitting the elements of all input streams, preferring data of the leftmost ones
    */
  def parJoinPrioritized[F[_]: Concurrent, A](
    streamsWithBuffers: (fs2.Stream[F, A], Int)*
  ): fs2.Stream[F, A] = {
    val F = Concurrent[F]

    Stream.eval[F, fs2.Stream[F, A]] {
      // streamsDone: None while running, Some(None) on a clean stop, Some(Some(err)) on a failure
      SignallingRef(None: Option[Option[Throwable]]).flatMap[fs2.Stream[F, A]] { streamsDone =>
        // allDone: becomes true when the input is finished and every buffer is drained
        SignallingRef(false).flatMap { allDone =>
          // The counter starts at 1: this extra slot is held while the input streams are being forked
          // and is released only after all of them are registered. Without it a stream that finishes
          // before the next one is started would drop the counter to 0 and stop the whole join too early.
          Ref.of(1L).map { running =>

            // stops the join evaluation
            // all the streams will be terminated. If err is supplied, that will get attached to any error currently present
            def stop(rslt: Option[Throwable]): F[Unit] = {
              streamsDone.update {
                case rslt0 @ Some(Some(err0)) =>
                  rslt.fold[Option[Option[Throwable]]](rslt0) { err =>
                    Some(Some(new CompositeFailure(err0, NonEmptyList.of(err))))
                  }
                case _ => Some(rslt)
              }
            }

            val incrementRunning: F[Unit] = running.update(_ + 1)

            // the last one to decrement (input stream or the launcher slot) stops the join
            val decrementRunning: F[Unit] = {
              running.modify { n =>
                val now = n - 1
                now -> (if (now == 0) stop(None) else F.unit)
              }.flatten
            }

            // runs inner stream
            // each stream is forked.
            // terminates when streamsDone is set
            // if fails, the failure is recorded with `stop`
            // note that supplied scope's resources must be leased before the inner stream forks the execution to another thread
            // and that it must be released once the inner stream terminates or fails.
            def runInner(inner: Stream[F, A], outerScope: Scope[F], buffer: InspectableQueue[F, A]): F[Unit] = {
              F.uncancelable {
                outerScope.lease.flatMap {
                  case Some(lease) =>
                    incrementRunning >>
                      F.start {
                        inner
                          .evalMap { buffer.enqueue1 }
                          .interruptWhen(streamsDone.map(_.nonEmpty)) // must be AFTER enqueue to the sync queue, otherwise the process may hang to enq last item while being interrupted
                          .compile
                          .drain
                          .attempt
                          .flatMap { r =>
                            lease.cancel.flatMap { cancelResult =>
                              (CompositeFailure.fromResults(r, cancelResult) match {
                                case Right(()) => F.unit
                                case Left(err) =>
                                  stop(Some(err))
                              }) >> decrementRunning
                            }
                          }
                      }.void
                  case None =>
                    F.raiseError(
                      new Throwable("Outer scope is closed during inner stream startup"))
                }
              }
            }

            // reads the currently recorded result of the join (does not wait for anything):
            // Some(err) if any stream has failed so far, None otherwise
            def signalResult: F[Option[Throwable]] = {
              streamsDone.get.map(_.flatten)
            }

            // creating an InspectableQueue as a buffer for each input stream and forking every input stream
            val buffers = fs2.Stream.getScope[F].evalMap { outerScope =>
              val launchAll: F[List[(InspectableQueue[F, A], Int)]] =
                streamsWithBuffers.toList.traverse { case (stream, bufferSize) =>
                  InspectableQueue.bounded[F, A](bufferSize).flatMap { streamBuffer =>
                    runInner(stream, outerScope, streamBuffer).as((streamBuffer, bufferSize))
                  }
                }
              // every stream is registered now, so the launcher slot can be released;
              // for an empty input this is what stops the join
              launchAll.flatMap { launched =>
                decrementRunning.as(launched)
              }
            }

            //starting a loop for retrieving data
            val loop: fs2.Stream[F, A] = buffers.flatMap[F, A] { streamBuffers: List[(InspectableQueue[F, A], Int)] =>

              // trying to get a next chunk from the buffers, starting with the topmost
              def getNext: F[Option[Chunk[A]]] = {
                streamBuffers.collectFirstSomeM { case (buffer, bufferSize) =>
                  buffer.tryDequeueChunk1(bufferSize)
                }
              }

              // Suspends until some buffer has an element or the join is stopped (normally or with an error).
              // Watching `streamsDone` here is essential: otherwise the consumer stays parked on empty
              // buffers forever when the input streams finish or fail while it is waiting.
              val awaitActivity: F[Unit] = {
                Deferred[F, Unit].flatMap { wakeUp =>
                  // `complete` fails when called twice, several watchers may fire, hence `attempt`
                  val notify: F[Unit] = wakeUp.complete(()).attempt.void
                  val dataWatchers: List[F[Unit]] = streamBuffers.map { case (buffer, _) =>
                    buffer.peek1.flatMap { _ => notify }
                  }
                  val doneWatcher: F[Unit] =
                    streamsDone.discrete.dropWhile(_.isEmpty).take(1).compile.drain.flatMap { _ => notify }

                  (doneWatcher :: dataWatchers).traverse { watcher =>
                    F.start(watcher)
                  }.flatMap { fibers =>
                    wakeUp.get >> fibers.map(_.cancel).sequence.void
                  }
                }
              }

              // While the join is running: returns the next chunk, waiting for it if the buffers are empty.
              // When the join is stopped: returns the remaining buffered data chunk by chunk
              // and None when nothing is left, so no prefetched data is lost.
              def nextChunk: F[Option[Chunk[A]]] = {
                streamsDone.get.flatMap {
                  case Some(_) => getNext
                  case None =>
                    getNext.flatMap[Option[Chunk[A]]] {
                      case Some(chunk) => F.pure[Option[Chunk[A]]](Some(chunk))
                      case None        => awaitActivity >> nextChunk
                    }
                }
              }

              fs2.Stream.repeatEval[F, Option[Chunk[A]]](nextChunk).flatMap {
                case Some(chunk) => fs2.Stream.chunk(chunk).covary[F]
                case None        => fs2.Stream.eval(allDone.set(true)).drain.covaryAll[F, A]
              }
            }

            // NOTE(review): `onComplete` is not a finalizer, so when the consumer terminates the result
            // early (e.g. with `take`) `stop` is presumably never called and the forked input streams
            // keep running until their buffers are full — confirm and consider a bracket-based shutdown.
            loop.interruptWhen(allDone).onComplete {
              fs2.Stream.eval {
                stop(None) >> signalResult
              }.flatMap {
                case Some(err) => fs2.Stream.raiseError[F](err)
                case None      => fs2.Stream.empty.covaryAll[F, A]
              }
            }.scope
          }
        }
      }
    }.flatten
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment