Created April 29, 2021 10:17
-
-
Save tg44/aa1d279b247d74e0ebca1489dc643410 to your computer and use it in GitHub Desktop.
AdaptiveQueueSource akka streams extension
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package akka.streams | |
import akka.Done | |
import akka.stream.OverflowStrategies.{Backpressure, DropBuffer, DropHead, DropNew, DropTail, Fail} | |
import akka.stream.{ | |
Attributes, | |
BufferOverflowException, | |
Outlet, | |
OverflowStrategy, | |
QueueOfferResult, | |
SourceShape, | |
StreamDetachedException, | |
} | |
import akka.stream.impl.Buffer | |
import akka.stream.impl.Stages.DefaultAttributes | |
import akka.stream.scaladsl.{Source, SourceQueueWithComplete} | |
import akka.stream.stage.{GraphStageLogic, GraphStageWithMaterializedValue, OutHandler, StageLogging} | |
import akka.streams.AdaptiveQueueSource.SourceQueueWithCompleteAndSize | |
import scala.concurrent.{Future, Promise} | |
/** Factories and supporting types for [[AdaptiveQueueSource]].
  *
  * Besides a FIFO `queue` variant, this offers a `priorityQueue` variant whose buffer emits
  * elements according to an `Ordering` instead of arrival order.
  */
object AdaptiveQueueSource {

  /** Messages sent from the materialized queue handle into the stage via its async callback. */
  sealed trait Input[+T]

  /** Request to buffer `elem`; `promise` is completed with the outcome of the offer. */
  final case class Offer[+T](elem: T, promise: Promise[QueueOfferResult]) extends Input[T]

  /** Request to complete the stream once everything buffered has been emitted. */
  case object Completion extends Input[Nothing]

  /** Request to fail the stream with `ex`. */
  final case class Failure(ex: Throwable) extends Input[Nothing]

  /** `max` argument passed to Akka's `Buffer.apply(size, max)` by [[queue]].
    *
    * NOTE(review): chosen far above any realistic buffer size, presumably so Akka always picks
    * its fixed-size buffer implementation — confirm against the Akka version in use.
    */
  private val FixedBufferThreshold: Int = 1000000000

  /** Creates a queue source that emits buffered elements highest-priority first.
    *
    * The buffer is a [[FixedSizePriorityBuffer]], so the element that is greatest according to
    * the implicit `Ordering[T]` is emitted next.
    *
    * @param bufferSize       number of elements buffered before `overflowStrategy` applies
    * @param overflowStrategy how to treat an offer that arrives while the buffer is full
    * @tparam T element type; requires an `Ordering`
    * @return a source materializing a [[SourceQueueWithCompleteAndSize]] handle
    */
  def priorityQueue[T: Ordering](
      bufferSize: Int,
      overflowStrategy: OverflowStrategy,
  ): Source[T, SourceQueueWithCompleteAndSize[T]] =
    Source.fromGraph(
      new AdaptiveQueueSource[T](() => new FixedSizePriorityBuffer[T](bufferSize), overflowStrategy)
        .withAttributes(DefaultAttributes.queueSource)
    )

  /** Creates a FIFO queue source backed by Akka's internal `Buffer`.
    *
    * @param bufferSize       number of elements buffered before `overflowStrategy` applies
    * @param overflowStrategy how to treat an offer that arrives while the buffer is full
    * @tparam T element type
    * @return a source materializing a [[SourceQueueWithCompleteAndSize]] handle
    */
  def queue[T](bufferSize: Int, overflowStrategy: OverflowStrategy): Source[T, SourceQueueWithCompleteAndSize[T]] =
    Source.fromGraph(
      new AdaptiveQueueSource[T](() => Buffer[T](bufferSize, FixedBufferThreshold), overflowStrategy)
        .withAttributes(DefaultAttributes.queueSource)
    )

  /** A bounded Akka `Buffer` that dequeues in priority order rather than FIFO order.
    *
    * Backed by `scala.collection.mutable.PriorityQueue`, so `peek()`/`dequeue()` return the
    * greatest element according to `Ordering[T]`. Here "head" means the highest-priority
    * element and "tail" the lowest-priority one.
    *
    * `capacity` is only reported through `isFull`/`remainingCapacity`. `enqueue` itself does
    * not reject elements; callers must check `isFull` first.
    *
    * @param capacity number of elements the buffer reports room for
    */
  class FixedSizePriorityBuffer[T: Ordering](val capacity: Int) extends Buffer[T] {
    private val buffer = collection.mutable.PriorityQueue.empty[T]

    // Drains a copy, so the contents are listed in priority order without touching `buffer`.
    override def toString: String = s"PriorityBuffer($capacity)(${buffer.clone().dequeueAll})"

    def used: Int = buffer.size
    def isFull: Boolean = used >= capacity
    def nonFull: Boolean = used < capacity
    def remainingCapacity: Int = math.max(capacity - used, 0)
    def isEmpty: Boolean = buffer.isEmpty
    def nonEmpty: Boolean = buffer.nonEmpty
    def enqueue(elem: T): Unit = buffer.enqueue(elem)
    def peek(): T = buffer.head
    def dequeue(): T = buffer.dequeue()
    def clear(): Unit = buffer.clear()

    /** Drops the highest-priority element (the one `dequeue()` would return); no-op when empty. */
    def dropHead(): Unit =
      if (buffer.nonEmpty) {
        buffer.dequeue()
        ()
      }

    /** Drops the lowest-priority element (the one that would be dequeued last); no-op when empty.
      *
      * `PriorityQueue` cannot remove its minimum in place. Its `dropRight` returns a new
      * collection and follows iteration order, not priority order. So the queue is drained in
      * priority order and every element except the last is put back. This is O(n log n).
      * [[AdaptiveQueueSource]] only calls it on overflow under the `DropTail` strategy.
      */
    def dropTail(): Unit =
      if (buffer.nonEmpty) {
        // Vector.fill evaluates its element argument sequentially, so `kept` is in priority order.
        val kept = Vector.fill(buffer.size - 1)(buffer.dequeue())
        buffer.clear()
        buffer ++= kept
      }
  }

  /** A `SourceQueueWithComplete` that additionally exposes the state of its buffer.
    *
    * NOTE(review): implementations may read stage-owned state from the caller's thread; treat
    * these values as best-effort snapshots.
    */
  trait SourceQueueWithCompleteAndSize[T] extends SourceQueueWithComplete[T] {

    /** Number of elements currently buffered. */
    def used: Int

    /** Number of elements the buffer holds before the overflow strategy applies. */
    def capacity: Int

    /** Whether the buffer has no room left. */
    def isFull: Boolean

    /** Whether nothing is buffered. */
    def isEmpty: Boolean
  }
}
/** A queue-backed `Source` stage whose buffer implementation is supplied by the caller.
  *
  * Elements offered through the materialized
  * [[AdaptiveQueueSource.SourceQueueWithCompleteAndSize]] handle are delivered into the stage
  * via an async callback. They are buffered according to `overflowStrategy` and emitted in the
  * order the buffer's `dequeue()` returns them.
  *
  * @param queueCreator     creates a fresh buffer for each materialization
  * @param overflowStrategy applied to an offer that arrives while the buffer is full
  * @tparam T element type
  */
final class AdaptiveQueueSource[T](queueCreator: () => Buffer[T], overflowStrategy: OverflowStrategy)
    extends GraphStageWithMaterializedValue[SourceShape[T], SourceQueueWithCompleteAndSize[T]] {
  import AdaptiveQueueSource._

  val out: Outlet[T] = Outlet[T]("queueSource.out")
  override val shape: SourceShape[T] = SourceShape.of(out)

  override def createLogicAndMaterializedValue(
      inheritedAttributes: Attributes
  ): (GraphStageLogic, SourceQueueWithCompleteAndSize[T]) = {
    // Completed when the stage finishes; exposed through watchCompletion().
    val completion = Promise[Done]()

    // The logic doubles as the materialized queue handle. offer/complete/fail go through the
    // async callback; the size accessors read the buffer directly.
    val stageLogic =
      new GraphStageLogic(shape) with OutHandler with SourceQueueWithCompleteAndSize[T] with StageLogging {
        override protected def logSource: Class[_] = classOf[AdaptiveQueueSource[_]]

        val buffer: Buffer[T] = queueCreator()

        // Offer waiting for buffer room under the Backpressure strategy (at most one).
        var pendingOffer: Option[Offer[T]] = None

        // Set when complete() arrives while elements are still buffered or pending. The stage
        // then completes once the buffer drains, and new offers are answered with Dropped.
        var terminating = false

        override def postStop(): Unit = {
          val exception = new StreamDetachedException()
          // A backpressured offer may still be waiting when the stage stops (e.g. after fail()).
          // Fail it like offer() fails calls that arrive after the stage has stopped, rather
          // than leaving the caller's Future incomplete forever.
          pendingOffer.foreach(_.promise.tryFailure(exception))
          pendingOffer = None
          completion.tryFailure(exception)
        }

        /** Adds the offered element to the buffer and reports it as enqueued. */
        private def enqueueAndSuccess(offer: Offer[T]): Unit = {
          buffer.enqueue(offer.elem)
          offer.promise.success(QueueOfferResult.Enqueued)
        }

        /** Buffers `offer`, applying `overflowStrategy` when the buffer is already full. */
        private def bufferElem(offer: Offer[T]): Unit =
          if (!buffer.isFull) enqueueAndSuccess(offer)
          else
            overflowStrategy match {
              case s: DropHead =>
                log.log(
                  s.logLevel,
                  "Dropping the head element because buffer is full and overflowStrategy is: [DropHead]"
                )
                buffer.dropHead()
                enqueueAndSuccess(offer)
              case s: DropTail =>
                log.log(
                  s.logLevel,
                  "Dropping the tail element because buffer is full and overflowStrategy is: [DropTail]"
                )
                buffer.dropTail()
                enqueueAndSuccess(offer)
              case s: DropBuffer =>
                log.log(
                  s.logLevel,
                  "Dropping all the buffered elements because buffer is full and overflowStrategy is: [DropBuffer]"
                )
                buffer.clear()
                enqueueAndSuccess(offer)
              case s: DropNew =>
                log.log(
                  s.logLevel,
                  "Dropping the new element because buffer is full and overflowStrategy is: [DropNew]"
                )
                offer.promise.success(QueueOfferResult.Dropped)
              case s: Fail =>
                log.log(s.logLevel, "Failing because buffer is full and overflowStrategy is: [Fail]")
                val bufferOverflowException =
                  BufferOverflowException(s"Buffer overflow (max capacity was: ${buffer.capacity})!")
                offer.promise.success(QueueOfferResult.Failure(bufferOverflowException))
                completion.failure(bufferOverflowException)
                failStage(bufferOverflowException)
              case s: Backpressure =>
                log.log(s.logLevel, "Backpressuring because buffer is full and overflowStrategy is: [Backpressure]")
                pendingOffer match {
                  case Some(_) =>
                    // Only one offer may wait at a time; the caller must await the previous one.
                    offer.promise.failure(
                      new IllegalStateException(
                        "You have to wait for the previous offer to be resolved to send another request"
                      )
                    )
                  case None =>
                    pendingOffer = Some(offer)
                }
            }

        // All queue operations enter the stage here, so stage state is only mutated by the stage.
        private val callback = getAsyncCallback[Input[T]] {
          case Offer(_, promise) if terminating =>
            promise.success(QueueOfferResult.Dropped)
          case offer @ Offer(_, _) =>
            bufferElem(offer)
            // onPull pushes whenever the buffer is non-empty, so an outstanding pull means the
            // element just buffered can be emitted right away.
            if (isAvailable(out)) push(out, buffer.dequeue())
          case Completion =>
            if (buffer.nonEmpty || pendingOffer.nonEmpty) terminating = true
            else {
              completion.success(Done)
              completeStage()
            }
          case Failure(ex) =>
            completion.failure(ex)
            failStage(ex)
        }

        setHandler(out, this)

        // NOTE(review): newer Akka versions deprecate this overload in favour of
        // onDownstreamFinish(cause); kept for compatibility with the Akka version in use.
        override def onDownstreamFinish(): Unit = {
          pendingOffer.foreach(_.promise.success(QueueOfferResult.QueueClosed))
          pendingOffer = None
          completion.success(Done)
          completeStage()
        }

        override def onPull(): Unit =
          if (buffer.nonEmpty) {
            push(out, buffer.dequeue())
            // Pushing freed a slot, so a backpressured offer can now be accepted.
            pendingOffer.foreach(enqueueAndSuccess)
            pendingOffer = None
            if (terminating && buffer.isEmpty) {
              completion.success(Done)
              completeStage()
            }
          }

        override def watchCompletion(): Future[Done] = completion.future

        override def offer(element: T): Future[QueueOfferResult] = {
          val p = Promise[QueueOfferResult]()
          callback
            .invokeWithFeedback(Offer(element, p))
            .onComplete {
              case scala.util.Success(_) => // the callback itself completes `p`
              case scala.util.Failure(e) => p.tryFailure(e) // presumably: the stage already stopped
            }(akka.dispatch.ExecutionContexts.sameThreadExecutionContext)
          p.future
        }

        override def complete(): Unit = callback.invoke(Completion)

        override def fail(ex: Throwable): Unit = callback.invoke(Failure(ex))

        // NOTE(review): these read the stage-owned buffer from the caller's thread without
        // synchronization, so the values are best-effort snapshots.
        override def used: Int = buffer.used
        override def capacity: Int = buffer.capacity
        override def isFull: Boolean = buffer.isFull
        override def isEmpty: Boolean = buffer.isEmpty
      }

    (stageLogic, stageLogic)
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment