Last active
May 26, 2020 15:37
-
-
Save LMnet/fc95146418caeeb8796decadc2181e65 to your computer and use it in GitHub Desktop.
StreamPriorityUtils
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.Applicative | |
import cats.data.NonEmptyList | |
import cats.effect.Concurrent | |
import cats.effect.concurrent.{Deferred, Ref} | |
import cats.implicits._ | |
import fs2.concurrent.{InspectableQueue, SignallingRef} | |
import fs2.{Chunk, CompositeFailure, Scope, Stream} | |
object StreamPriorityUtils {

  /**
    * Nondeterministically merges several streams into one, preferring the streams listed first.
    *
    * The order of the arguments matters: streams further to the left have higher priority.
    * Whenever the leftmost stream has buffered data, the resulting stream emits that data. When the
    * leftmost stream has nothing buffered (for example because it is slow or idle), data is taken
    * from the next stream, and so on.
    *
    * Every input stream runs concurrently and writes into its own bounded buffer, so data from all
    * inputs is prefetched before the prioritized retrieval picks from it.
    *
    * Termination:
    *  - the output completes once every input stream has completed and all buffers are drained;
    *  - if an input stream fails, the other inputs are interrupted, data that is already buffered is
    *    still emitted, and then the failure is raised (several failures are combined into a
    *    [[fs2.CompositeFailure]]);
    *  - if the output is terminated early (e.g. by a downstream `take` or an interruption), the
    *    input streams are signalled to stop instead of staying blocked on full buffers.
    *
    * @param streamsWithBuffers input streams in descending priority, each paired with the size of
    *                           its buffer (passed to `InspectableQueue.bounded`, and also used as the
    *                           maximum chunk size taken from that buffer at once)
    */
  def parJoinPrioritized[F[_]: Concurrent, A](
      streamsWithBuffers: (Stream[F, A], Int)*
  ): Stream[F, A] = {
    val F = Concurrent[F]

    Stream.eval {
      for {
        // None: still running; Some(None): finished normally; Some(Some(err)): finished with err.
        streamsDone <- SignallingRef[F, Option[Option[Throwable]]](None)
        // Number of inner streams (plus a startup token, see `startAll`) that have not finished yet.
        running <- Ref.of[F, Long](0L)
      } yield {

        // Marks the join as finished, which interrupts every inner stream. If `rslt` carries an
        // error, it is combined with any error that was already recorded.
        def stop(rslt: Option[Throwable]): F[Unit] =
          streamsDone.update {
            case rslt0 @ Some(Some(err0)) =>
              rslt.fold[Option[Option[Throwable]]](rslt0) { err =>
                Some(Some(new CompositeFailure(err0, NonEmptyList.of(err))))
              }
            case _ => Some(rslt)
          }

        val incrementRunning: F[Unit] = running.update(_ + 1)

        // Whoever brings the counter down to zero marks the join as finished.
        val decrementRunning: F[Unit] =
          running.modify { n =>
            val now = n - 1
            now -> (if (now == 0) stop(None) else F.unit)
          }.flatten

        // Forks `inner` so that it feeds `buffer` until it completes, fails, or the join is marked
        // as finished; a failure is recorded via `stop`. A lease on `outerScope` is acquired before
        // the fork and cancelled once the inner stream terminates, so the outer scope's resources
        // stay available while the inner stream runs on another fiber.
        def runInner(
            inner: Stream[F, A],
            outerScope: Scope[F],
            buffer: InspectableQueue[F, A]
        ): F[Unit] =
          F.uncancelable {
            outerScope.lease.flatMap {
              case Some(lease) =>
                incrementRunning >>
                  F.start {
                    inner
                      .evalMap(buffer.enqueue1)
                      // Must come after the enqueue so that an `enqueue1` blocked on a full buffer
                      // is interrupted too; otherwise the fiber may hang enqueuing its last element.
                      .interruptWhen(streamsDone.map(_.nonEmpty))
                      .compile
                      .drain
                      .attempt
                      .flatMap { r =>
                        lease.cancel.flatMap { cancelResult =>
                          (CompositeFailure.fromResults(r, cancelResult) match {
                            case Right(()) => F.unit
                            case Left(err) => stop(Some(err))
                          }) >> decrementRunning
                        }
                      }
                  }.void
              case None =>
                F.raiseError[Unit](new Throwable("Outer scope is closed during inner stream startup"))
            }
          }

        // Error recorded by the join, if any (meaningful once the join is finished).
        val signalResult: F[Option[Throwable]] = streamsDone.get.map(_.flatten)

        // Creates one bounded buffer per input stream and forks the stream into it, keeping the
        // priority order. `running` holds one extra count for the whole startup: otherwise a stream
        // that completes before its successors are forked could bring the counter to zero and stop
        // the join before they emit anything. Releasing the token also finishes the join right away
        // when there are no inputs at all.
        val startAll: Stream[F, List[(InspectableQueue[F, A], Int)]] =
          Stream.getScope[F].evalMap { outerScope =>
            F.bracket(incrementRunning) { _ =>
              streamsWithBuffers.toList.traverse { case (stream, bufferSize) =>
                InspectableQueue.bounded[F, A](bufferSize).flatMap { buffer =>
                  runInner(stream, outerScope, buffer).as((buffer, bufferSize))
                }
              }
            }(_ => decrementRunning)
          }

        val loop: Stream[F, A] = startAll.flatMap { streamBuffers =>
          // Takes up to `bufferSize` elements from the highest-priority non-empty buffer, if any.
          def getNext: F[Option[Chunk[A]]] =
            streamBuffers.collectFirstSomeM { case (buffer, bufferSize) =>
              buffer.tryDequeueChunk1(bufferSize)
            }

          // Completes once any buffer has data or the join is marked as finished. Waiting for data
          // alone would hang when the last input completes (or one fails) while all buffers are
          // empty. The watcher fibers are always cancelled, even if the wait itself is cancelled.
          def awaitDataOrDone: F[Unit] =
            Deferred[F, Unit].flatMap { wakeUp =>
              // Only the first watcher can complete the Deferred; later attempts fail and are ignored.
              val signal: F[Unit] = wakeUp.complete(()).attempt.void
              val watchers: List[F[Unit]] =
                streamsDone.discrete.find(_.isDefined).compile.drain ::
                  streamBuffers.map { case (buffer, _) => buffer.peek1.void }
              F.bracket(watchers.traverse(watcher => F.start(watcher >> signal)))(_ => wakeUp.get)(
                fibers => fibers.traverse_(_.cancel)
              )
            }

          // Next chunk to emit, or None once the join is finished and every buffer is drained.
          def nextChunk: F[Option[Chunk[A]]] =
            streamsDone.get.flatMap {
              // Finishing: flush whatever is still buffered, then terminate.
              case Some(_) => getNext
              case None =>
                getNext.flatMap {
                  case found @ Some(_) => F.pure(found)
                  case None            => awaitDataOrDone >> nextChunk
                }
            }

          Stream
            .repeatEval(nextChunk)
            .unNoneTerminate
            .flatMap(chunk => Stream.chunk(chunk).covary[F])
        }

        loop
          .onComplete {
            Stream.eval(stop(None) >> signalResult).flatMap {
              case Some(err) => Stream.raiseError[F](err)
              case None      => Stream.empty.covaryAll[F, A]
            }
          }
          // Runs however the output ends, including early termination by the consumer, so the
          // inner streams are always told to stop.
          .onFinalize(stop(None))
          .scope
      }
    }.flatten
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment