Created June 15, 2016 17:38
-
-
Save randomstatistic/6dc13fc80ebf30b97c09ae52002895b8 to your computer and use it in GitHub Desktop.
Leaky bucket implementation in Scala
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.locks.ReentrantLock | |
import scala.concurrent.{ExecutionContext, Future, Promise} | |
import scala.concurrent.duration._ | |
/**
 * A lock-protected leaky-bucket rate limiter.
 *
 * One token drips into the bucket every `dripEvery`, up to `maxSize` tokens. Callers take
 * tokens before doing rate-limited work and block until enough tokens are available. The
 * bucket starts empty. Drips are anchored to the time of construction and advanced in whole
 * multiples of `dripEvery`, so sleeping slightly too long never makes the long-run rate drift.
 *
 * A caller that has to wait keeps the internal lock while it sleeps, so concurrent callers
 * are served one at a time.
 *
 * @param dripEvery interval between token drips; must be positive
 * @param maxSize   maximum number of tokens the bucket can hold; must be positive
 */
class LeakyBucket(dripEvery: FiniteDuration, maxSize: Int) {
  require(maxSize > 0, "A bucket must have a size > 0")

  private val dripEveryNanos: Long = dripEvery.toNanos
  // A zero or negative interval would divide by zero (or run time backwards) in refillBucket.
  require(dripEveryNanos > 0, "The drip interval must be > 0")

  private val lock = new ReentrantLock()

  /** Runs `f` while holding `lock`, releasing it even if `f` throws. */
  private def withLock[T](f: => T): T = {
    lock.lock()
    try { f } finally { lock.unlock() }
  }

  // Mutable state; only read or written while holding `lock`.
  private var bucket = 0
  private var lastFilled = System.nanoTime()

  /**
   * Adds every whole drip that has accrued since `lastFilled` to the bucket, capped at `upTo`.
   * Drips that would overflow the cap are discarded, but `lastFilled` still advances past them.
   */
  private def refillBucket(upTo: Int = maxSize): Unit = {
    // Compare nanoTime values by difference (as System.nanoTime recommends). Keep the drip
    // count as a Long: after a long idle period with a short interval it can exceed
    // Int.MaxValue, and truncating it would make the count negative.
    val drips = (System.nanoTime() - lastFilled) / dripEveryNanos
    if (drips > 0) {
      lastFilled += drips * dripEveryNanos
      bucket = if (drips >= upTo - bucket) upTo else bucket + drips.toInt
    }
  }

  /** Sleeps until `num` more drips will have accrued since `lastFilled`. */
  private def waitForRefills(num: Int): Unit = {
    val waitNanos = lastFilled + dripEveryNanos * num - System.nanoTime()
    if (waitNanos > 0) {
      // Thread.sleep takes milliseconds. Round up so we don't wake just before the drip is due,
      // which would only cost another round trip through the lock and a further sleep.
      Thread.sleep((waitNanos + 999999L) / 1000000L)
    }
  }

  /** Discards every token in the bucket, including drips that have accrued but not yet been counted. */
  def drain(): Unit = withLock {
    refillBucket()
    bucket = 0
  }

  // -- Sync apis --

  /**
   * Blocks the calling thread until `num` tokens have been taken from the bucket.
   *
   * Requests larger than `maxSize` are satisfied in rounds. Each round takes whatever is in
   * the bucket, then sleeps for more drips, until `num` tokens in total have been taken.
   *
   * @param num number of tokens to take; must be positive
   * @throws IllegalArgumentException if `num` is not positive
   * @throws InterruptedException if interrupted while sleeping; tokens this call has already
   *                              taken are not put back
   */
  def awaitToken(num: Int = 1): Unit = {
    require(num > 0, "Must request at least one token")
    withLock {
      assert(bucket >= 0)
      refillBucket()
      // Loop rather than recurse, so arbitrarily large requests use constant stack depth.
      var stillNeeded = num
      while (bucket < stillNeeded) {
        stillNeeded -= bucket
        bucket = 0
        waitForRefills(stillNeeded min maxSize)
        refillBucket()
      }
      bucket -= stillNeeded
    }
  }

  /** Waits for `num` tokens, then evaluates `f` on the calling thread and returns its result. */
  def rateLimited[T](num: Int = 1)(f: => T): T = {
    awaitToken(num)
    f
  }

  /**
   * An iterator over `0 until size` that takes `tokens` tokens before producing each element.
   * Tokens are taken lazily, as each element is pulled.
   */
  def iterator(size: Int, tokens: Int = 1): Iterator[Int] =
    Iterator.range(0, size).map(i => rateLimited(tokens)(i))

  // -- Async apis --

  /**
   * Takes `num` tokens on a thread from `ec`. The returned future completes once the tokens
   * have been taken. The wait is wrapped in `blocking` so that pools which support it (such as
   * ExecutionContext.global) can compensate for the parked thread.
   */
  def getToken(num: Int)(implicit ec: ExecutionContext): Future[Unit] =
    Future(scala.concurrent.blocking(awaitToken(num)))(ec)

  /** Evaluates `f` on `ec` once `num` tokens have been taken. */
  def rateLimitedAsync[T](num: Int)(f: => T)(implicit ec: ExecutionContext): Future[T] =
    getToken(num).map(_ => f)

  /**
   * Returns a future holding `f`'s result that completes only after `num` tokens have been taken.
   *
   * Note: this doesn't block the parameter future from executing; only work chained off the
   * returned future is delayed.
   */
  def rateLimitedFuture[T](num: Int)(f: Future[T])(implicit ec: ExecutionContext): Future[T] =
    getToken(num).flatMap(_ => f)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.Executors | |
import org.scalatest.{FunSpec, Matchers} | |
import scala.concurrent.{Await, ExecutionContext, Future} | |
import scala.concurrent.duration._ | |
/**
 * Timing-based tests for LeakyBucket. Expected durations carry tolerances of a few
 * milliseconds, so they assume the test JVM is not heavily loaded.
 */
class TestLeakyBucket extends FunSpec with Matchers with org.scalatest.BeforeAndAfterAll {
  // Dedicated pool for the async tests. It is shut down in afterAll, because threads from
  // Executors.newFixedThreadPool's default factory are non-daemon and would outlive the suite.
  private val pool = Executors.newFixedThreadPool(10)
  implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)

  override protected def afterAll(): Unit = {
    pool.shutdownNow()
    super.afterAll()
  }

  /** Wall-clock durations recorded by `time`, keyed by timer name. */
  val times = scala.collection.mutable.HashMap[String, FiniteDuration]()

  /** Evaluates `f` and records how long it took under `name`, even if `f` throws. */
  def time[T](name: String)(f: => T): T = {
    val t1 = System.nanoTime()
    try { f } finally { times += (name -> (System.nanoTime() - t1).nanos) }
  }

  /**
   * Creates a bucket and runs `block` against it. A positive `initialDelay` lets the new
   * bucket accumulate drips first; otherwise the bucket is drained so the test starts empty.
   */
  def bucketTest(bucketRate: FiniteDuration, bucketSize: Int, initialDelay: FiniteDuration)(block: (LeakyBucket) => Unit): Unit = {
    val bucket = new LeakyBucket(bucketRate, bucketSize)
    if (initialDelay > 0.seconds) Thread.sleep(initialDelay.toMillis) else bucket.drain()
    block(bucket)
  }

  describe("100ms rate") {
    it("should consume a half-full bucket quickly") {
      bucketTest(100.millis, 10, 500.millis) { bucket =>
        time("half-full bucket") {
          bucket.awaitToken(5)
        }
        times("half-full bucket").toMillis should be < 10L
      }
    }
    it("should consume a full bucket quickly") {
      bucketTest(100.millis, 10, 1100.millis) { bucket =>
        time("full bucket") {
          bucket.awaitToken(10)
        }
        times("full bucket").toMillis should be < 10L
      }
    }
    it("should consume from an empty bucket at the expected rate") {
      bucketTest(100.millis, 10, 0.millis) { bucket =>
        Range(1, 10).foreach { i =>
          val timerName = "empty bucket run " + i
          time(timerName) {
            bucket.awaitToken()
          }
          times(timerName).toMillis should be(100L +- 15)
        }
      }
    }
    it("asking a full bucket for more than it has should not be quick") {
      bucketTest(100.millis, 10, 1100.millis) { bucket =>
        time("full bucket overflow") {
          bucket.awaitToken(11)
        }
        times("full bucket overflow").toMillis should be(100L +- 15)
      }
    }
  }

  describe("10ms rate") {
    it("should consume from an empty bucket at the expected rate") {
      bucketTest(10.millis, 100, 0.millis) { bucket =>
        time("throughput") {
          Range(0, 100).foreach(_ => bucket.awaitToken())
        }
        times("throughput").toMillis should be (1000L +- 25)
      }
    }
    it("gets the expected rate with concurrent consumers") {
      bucketTest(10.millis, 100, 0.millis) { bucket =>
        time("concurrent throughput") {
          val futures = Vector.fill(100)(bucket.getToken(1))
          // Await.result rather than Await.ready, so a failed future fails the test.
          Await.result(Future.sequence(futures), 2.seconds)
        }
        times("concurrent throughput").toMillis should be (1000L +- 25)
      }
    }
  }

  describe("rate limited actions") {
    it ("should limit synchronous rate") {
      bucketTest(10.millis, 1, 0.millis) { bucket =>
        val results = time("rate limited loop") {
          for (i <- Range(0, 10)) yield {
            bucket.rateLimited() {
              i // some thrilling computation
            }
          }
        }
        times("rate limited loop").toMillis should be (100L +- 5)
        results should contain theSameElementsInOrderAs Range(0, 10)
      }
    }
    it("should rate limit async blocks") {
      bucketTest(50.millis, 1, 0.millis) { bucket =>
        val f = bucket.rateLimitedAsync(1) {
          1
        }
        val result = time("future completion") {
          Await.result(f, 1.seconds)
        }
        times("future completion").toMillis should be (50L +- 5)
        result should be(1)
      }
    }
    it("should rate limit future completion") {
      bucketTest(10.millis, 1, 0.millis) { bucket =>
        val f = bucket.rateLimitedFuture(5)(Future.successful(1))
        val result = time("future completion") {
          Await.result(f, 1.seconds)
        }
        times("future completion").toMillis should be (50L +- 5)
        result should be(1)
      }
    }
  }
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.