Last active
January 24, 2019 20:29
-
-
Save Timshel/5d3c6e56734043525e9a64acc03a4806 to your computer and use it in GitHub Desktop.
RateLimiter (limits calls per time window and concurrent calls)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import akka.stream.{ActorMaterializer, OverflowStrategy} | |
import akka.stream.scaladsl.{Keep, Sink, Source} | |
import scala.concurrent.{ExecutionContext, Future, Promise} | |
import scala.concurrent.duration.FiniteDuration | |
/**
 * Runs asynchronous calls through an Akka Streams pipeline that starts at most `limit` calls per
 * sliding window of length `time`, with at most `parallelism` calls whose futures are still running.
 *
 * Waiting calls sit in a buffer of `bufferSize` entries. When the buffer overflows, the OLDEST
 * waiting call is dropped (`OverflowStrategy.dropHead`), and the future handed to its caller never
 * completes.
 *
 * @param limit       maximum number of calls started within any window of length `time`
 * @param time        length of the sliding window
 * @param parallelism maximum number of calls in flight at the same time
 * @param bufferSize  number of calls that can wait before the oldest is dropped
 */
class RateLimiter(
  val limit: Int,
  val time: FiniteDuration,
  val parallelism: Int,
  val bufferSize: Int
)(
  implicit
  materializer: ActorMaterializer
) {

  /**
   * `input` is the actor that receives queued tasks.
   * `out` is the `Sink.ignore` future, which completes when the stream terminates.
   */
  val (input, out) = Source.actorRef[() => Future[_]](bufferSize, OverflowStrategy.dropHead)
    .via(new SlidingThrottle(limit, time))
    .mapAsync(parallelism) { task => task() }
    .toMat(Sink.ignore)(Keep.both)
    .run()

  /**
   * Queues `call` and returns a future of its eventual result.
   *
   * `call` is by-name: it is evaluated only once the throttle releases it. A synchronous exception
   * thrown while evaluating it is reported through the returned future.
   */
  def enqueue[T](call: => Future[T]): Future[T] = {
    import scala.util.control.NonFatal

    val promise = Promise[T]()
    val task: () => Future[_] = { () =>
      val started: Future[T] =
        try call
        catch { case NonFatal(e) => Future.failed[T](e) }
      promise.completeWith(started)
      // mapAsync fails the whole stream on a failed future (and on a thrown exception), which would
      // silently stall every later call. Give it a future that only signals completion, so one
      // failing call cannot take the limiter down; the caller still sees the failure via `promise`.
      started.recover { case NonFatal(_) => () }(materializer.executionContext)
    }
    input ! task
    promise.future
  }
}
/**
 * A [[RateLimiter]] whose returned futures are bounded by `timeout`. The bound is applied through
 * `FutureHelpers.RichFuture.withTimeout`, which is defined outside this file.
 *
 * NOTE(review): `bufferSize` sits in the implicit parameter list, so any implicit `Int` in scope at
 * the construction site would silently replace the computed default. Pass it explicitly when in doubt.
 * NOTE(review): nothing removes a timed-out call from the queue; presumably it still runs, and
 * counts against `limit`, once the throttle releases it.
 * NOTE(review): `ex` and `system` are presumably required implicitly by `withTimeout` — confirm.
 */
class RateLimiterWithTimeout(
  limit: Int,
  time: FiniteDuration,
  timeout: FiniteDuration,
  parallelism: Int
)(
  implicit
  bufferSize: Int = RateLimiterWithTimeout.bufferSize(limit, time, timeout),
  ex: ExecutionContext,
  system: akka.actor.ActorSystem,
  materializer: ActorMaterializer
) extends RateLimiter(limit, time, parallelism, bufferSize)(materializer) {

  /** Queues `call` like [[RateLimiter.enqueue]], then bounds the resulting future by `timeout`. */
  override def enqueue[T](call: => Future[T]): Future[T] = {
    import FutureHelpers.RichFuture

    val throttled: Future[T] = super.enqueue(call)
    throttled.withTimeout(timeout)
  }
}
object RateLimiterWithTimeout {

  /**
   * Default queue size for a [[RateLimiterWithTimeout]].
   *
   * The size is 90% of the number of calls the limiter can start during one `timeout` period,
   * i.e. `timeout / time * limit`. Calls queued beyond that would presumably time out before they
   * run. The limiter drops the oldest queued call on overflow. The result is never lower than `limit`.
   *
   * Example: limit 100, time 1.second, timeout 1.minute => bufferSize = 5400
   */
  def bufferSize(limit: Int, time: FiniteDuration, timeout: FiniteDuration): Int = {
    val callsDuringTimeout: Double = (timeout / time) * limit
    val scaled: Int = (callsDuringTimeout * 0.9d).toInt
    if (scaled > limit) scaled else limit
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.TimeoutException | |
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec} | |
import org.scalatest.concurrent.ScalaFutures | |
import org.scalatest.time.SpanSugar._ | |
import scala.concurrent.{Await, Future} | |
/**
 * Timing-based tests for [[RateLimiter]] and [[RateLimiterWithTimeout]].
 *
 * Timestamps are measured in milliseconds from the start of each test. `delay` is the scheduling
 * jitter tolerated between calls expected to start together.
 */
class RateLimiterSpec extends WordSpec with Matchers with ScalaFutures with BeforeAndAfterAll {
  import scala.concurrent.ExecutionContext.Implicits.global

  implicit val system: akka.actor.ActorSystem = akka.actor.ActorSystem()
  implicit val materializer: akka.stream.ActorMaterializer = akka.stream.ActorMaterializer()

  /** Converts a ScalaTest `Span` into the `FiniteDuration` expected by the rate limiters. */
  def fd(span: org.scalatest.time.Span): scala.concurrent.duration.FiniteDuration = {
    new scala.concurrent.duration.FiniteDuration(span.length, span.unit)
  }

  /**
   * Asserts that consecutive timestamps in `times` are less than `delay` ms apart.
   * This is strict (`foreach`), unlike a discarded lazy `Iterator.map`, whose assertions never run.
   * It is also safe for groups of any size.
   */
  private def assertClustered(times: Seq[Long], delay: Long): Unit =
    times.zip(times.drop(1)).foreach { case (a, b) => assert(b - a < delay) }

  "RateLimiter" should {
    "enqueue" in {
      val queue = new RateLimiter(10, fd(1.second), 20, 100)
      val res = queue.enqueue(Future(12))
      assert(res.isReadyWithin(200.millis))
      assert(res.futureValue === 12)
    }

    "limit" in {
      val queue = new RateLimiter(1, fd(2.seconds), 2, 100)
      (0 until 1).foreach { _ => queue.enqueue(Future { 1 }) }
      val res = queue.enqueue(Future { 2 })
      // The window is full, so the second call must not complete during the first second...
      val _ = intercept[TimeoutException] { Await.result(res, 1.second) }
      // ...but must run once the first call leaves the 2-second window.
      assert(res.isReadyWithin(2.seconds))
      assert(res.futureValue === 2)
    }

    "limit burst" in {
      val queue = new RateLimiter(10, fd(2.seconds), 20, 100)
      val start = System.currentTimeMillis
      val delay = 20
      val values = Future.sequence(
        (0 until 3 * queue.limit).map { i => queue.enqueue(Future { i -> (System.currentTimeMillis - start) }) }
      )
      assert(values.isReadyWithin(queue.time * 3))
      values.futureValue.sortBy(_._1).grouped(queue.limit).zipWithIndex.foreach { case (batch, index) =>
        // Batch `index` of `limit` calls starts one window after the previous batch...
        assert(batch.head._2 < index * queue.time.toMillis + (index + 1) * delay)
        // ...and its calls run back to back.
        assertClustered(batch.map(_._2), delay.toLong)
      }
    }

    "limit parallelism" in {
      val queue = new RateLimiter(10, fd(3.seconds), 5, 100)
      val start = System.currentTimeMillis
      val sleep = 1.second
      val delay = 20
      val values = Future.sequence(
        (0 until 3 * queue.limit).map { i =>
          queue.enqueue(Future {
            // `blocking` lets the global pool grow, so `parallelism` sleeps can really overlap
            // even on machines with fewer cores than `parallelism`.
            scala.concurrent.blocking { Thread.sleep(sleep.toMillis) }
            i -> (System.currentTimeMillis - start)
          })
        }
      )
      assert(values.isReadyWithin(queue.time * 3))
      values.futureValue.sortBy(_._1).grouped(queue.limit).zipWithIndex.foreach { case (windowBatch, indexL) =>
        windowBatch.grouped(queue.parallelism).zipWithIndex.foreach { case (parallelBatch, indexP) =>
          val limit = indexL * queue.time.toMillis + (indexP + 1) * sleep.toMillis + (indexL + indexP + 1) * delay
          assert(parallelBatch.head._2 < limit)
          assertClustered(parallelBatch.map(_._2), delay.toLong)
        }
      }
    }
  }

  "RateLimiterWithTimeout" should {
    "timeout" in {
      val queue = new RateLimiterWithTimeout(10, fd(1.minute), fd(1.second), 100)
      // A never-completing call: exercises the timeout without parking a pool thread for a minute.
      val res = queue.enqueue(scala.concurrent.Promise[Unit]().future)
      val _ = intercept[TimeoutException] {
        Await.result(res, 1.hour)
      }
    }
  }

  override def afterAll(): Unit = {
    materializer.shutdown()
    // An explicit bound: `futureValue` would use the 150 ms default patience and can fail spuriously.
    Await.result(system.terminate(), fd(10.seconds))
    ()
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import akka.stream.impl.fusing.GraphStages.SimpleLinearGraphStage | |
import akka.stream.stage._ | |
import akka.stream._ | |
import scala.concurrent.duration.{ FiniteDuration, _ } | |
/**
 * Throttle stage that emits at most `max` elements within any sliding window of length `per`.
 *
 * The emission times of recent elements are kept in a queue, oldest first. When an element arrives
 * and the window already holds `max` emissions, the element is held back. A timer re-examines it
 * when the oldest emission leaves the window.
 *
 * The stage structure (timer logic, shared in/out handler) appears adapted from Akka's built-in
 * `Throttle` stage, which uses a token bucket rather than a sliding window.
 * NOTE(review): `SimpleLinearGraphStage` is Akka internal API and may change between versions.
 *
 * @param max maximum number of elements emitted within any window of length `per`; must be > 0
 * @param per window length; must be positive and at least `max` nanoseconds
 */
class SlidingThrottle[T](max: Int, per: FiniteDuration) extends SimpleLinearGraphStage[T] {
  require(max > 0, "max must be > 0")
  require(per.toNanos > 0, "per time must be > 0")
  require(per.toNanos >= max, "Rates larger than 1 unit / nanosecond are not supported")

  private val nanosPer = per.toNanos
  private val timerName: String = "ThrottleTimer"

  override def createLogic(inheritedAttributes: Attributes): GraphStageLogic = new TimerGraphStageLogic(shape) {
    // Set when upstream finishes while an element is waiting on the timer; the stage then completes
    // right after pushing that element.
    var willStop = false
    // System.nanoTime stamps of emissions that may still be inside the window, oldest first.
    var emittedTimes = scala.collection.immutable.Queue.empty[Long]
    // Element held back until the window has room; only meaningful while the timer is active.
    var currentElement: T = _

    /** Pushes `elem` downstream, records its emission time, and completes if upstream is done. */
    def pushThenLog(elem: T): Unit = {
      push(out, elem)
      emittedTimes = emittedTimes.enqueue(System.nanoTime)
      if (willStop) completeStage()
    }

    /** Holds `elem` back and re-examines it after `nanos` nanoseconds (clamped to be non-negative). */
    def schedule(elem: T, nanos: Long): Unit = {
      currentElement = elem
      scheduleOnce(timerName, math.max(0L, nanos).nanos)
    }

    /** Emits `elem` if the window has room; otherwise waits until its oldest emission expires. */
    def receive(elem: T): Unit = {
      val now = System.nanoTime
      emittedTimes = emittedTimes.dropWhile { t => t + nanosPer < now }
      // The queue is non-empty in the else branch because `max > 0`.
      if (emittedTimes.length < max) pushThenLog(elem)
      else schedule(elem, emittedTimes.head + nanosPer - System.nanoTime)
    }

    // This scope is here just to not retain an extra reference to the handler below.
    // We can't put this code into preRestart() because setHandler() must be called before that.
    // (The blank line above keeps this block from parsing as an argument to the previous member.)
    {
      val handler = new InHandler with OutHandler {
        // If an element is still waiting on the timer, defer completion until it has been pushed.
        override def onUpstreamFinish(): Unit =
          if (isAvailable(out) && isTimerActive(timerName)) willStop = true
          else completeStage()

        override def onPush(): Unit = receive(grab(in))

        override def onPull(): Unit = pull(in)
      }
      setHandler(in, handler)
      setHandler(out, handler)
      // After this point, we no longer need the `handler` so it can just fall out of scope.
    }

    override protected def onTimer(key: Any): Unit = {
      val elem = currentElement
      currentElement = null.asInstanceOf[T] // release the held reference before re-examining it
      receive(elem)
    }
  }

  override def toString = "Throttle"
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment