Skip to content

Instantly share code, notes, and snippets.

@alexandru
Created October 5, 2016 08:19
Show Gist options
  • Select an option

  • Save alexandru/00e572743ee887328c2027c69e4dcf4f to your computer and use it in GitHub Desktop.

Select an option

Save alexandru/00e572743ee887328c2027c69e4dcf4f to your computer and use it in GitHub Desktop.
import monix.execution.Cancelable
import monix.execution.atomic.Atomic
import monix.execution.atomic.PaddingStrategy.LeftRight128
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.{Future, Promise}
/** The `TaskSemaphore` is an asynchronous semaphore implementation that
  * limits the parallelism on task execution.
  *
  * The following example instantiates a semaphore with a
  * maximum parallelism of 10:
  *
  * {{{
  *   val semaphore = TaskSemaphore(maxParallelism = 10)
  *
  *   def makeRequest(r: HttpRequest): Task[HttpResponse] = ???
  *
  *   // For such a task no more than 10 requests
  *   // are allowed to be executed in parallel.
  *   val task = semaphore.greenLight(makeRequest(???))
  * }}}
  *
  * Internally an atomic `State` holds the number of permits currently
  * handed out plus a FIFO queue of waiting promises. A released permit
  * is handed directly to the oldest waiter, if any, in which case the
  * active count stays unchanged.
  */
final class TaskSemaphore private (maxParallelism: Int) extends Serializable {
  import TaskSemaphore.State

  require(maxParallelism > 0, "parallelism > 0")

  private[this] val stateRef =
    Atomic.withPadding(TaskSemaphore.initialState, LeftRight128)

  /** Returns the number of active tasks that are holding on
    * to the available permits.
    */
  def activeCount: Int =
    stateRef.get.activeCount

  /** Creates a new task ensuring that the given source
    * acquires an available permit from the semaphore before
    * it is being executed.
    *
    * The returned task also takes care of resource handling,
    * releasing its permit after being complete. If the task is
    * cancelled while still waiting for a permit, it is only removed
    * from the wait queue, so no permit that was never acquired gets
    * released.
    */
  def greenLight[A](fa: Task[A]): Task[A] =
    Task.unsafeCreate { (s, conn, cb) =>
      // Identifies this acquisition in the wait queue, so that a
      // cancellation can un-register it instead of releasing a permit
      val waiter = Promise[Unit]()
      val permit = acquire(waiter)
      // Triggered both on cancel and on finish; relies on `Cancelable(...)`
      // running its callback at most once, so only one undo happens
      val c = Cancelable(() => cancelOrRelease(waiter))
      // On cancel: leave the queue, or give back the held permit
      conn.push(c)
      val source = Task.fromFuture(permit).flatMap { _ =>
        // On finish trigger a release
        fa.doOnFinish(_ => Task.eval(c.cancel()))
      }
      Task.unsafeStartNow(source, s, conn, cb)
    }

  /** Internal. Releases a held permit, handing it over to the
    * oldest waiter if there is one, otherwise returning it to the pool.
    */
  @tailrec private def release(): Unit = {
    val current = stateRef.get
    if (current.promises.isEmpty) {
      val update = State(current.activeCount - 1, current.promises)
      if (!stateRef.compareAndSet(current, update))
        release() // retry
    } else {
      val (next, rest) = current.promises.dequeue
      // The permit is transferred to `next`, so activeCount is unchanged
      val update = State(current.activeCount, rest)
      if (!stateRef.compareAndSet(current, update))
        release() // retry
      else {
        next.trySuccess(())
        ()
      }
    }
  }

  /** Internal. Undoes the acquisition identified by `waiter`.
    *
    * If `waiter` is still queued, no permit was ever granted to it, so
    * it is merely removed from the queue. Otherwise a permit is held
    * (granted immediately by [[acquire]] or handed over by [[release]])
    * and it gets released.
    */
  @tailrec private def cancelOrRelease(waiter: Promise[Unit]): Unit = {
    val current = stateRef.get
    if (current.promises.exists(_ eq waiter)) {
      val update = State(current.activeCount, current.promises.filterNot(_ eq waiter))
      if (!stateRef.compareAndSet(current, update))
        cancelOrRelease(waiter) // retry
    } else {
      release()
    }
  }

  /** Internal. Triggers a permit acquisition, returning a future
    * that will complete when a permit gets acquired.
    *
    * If a permit is available right away, `waiter` is left untouched
    * and an already completed future is returned. Otherwise `waiter`
    * is appended to the wait queue and its future is returned.
    */
  @tailrec private def acquire(waiter: Promise[Unit]): Future[Unit] = {
    val current = stateRef.get
    if (current.activeCount < maxParallelism) {
      if (!stateRef.compareAndSet(current, current.activateOne()))
        acquire(waiter) // retry
      else
        TaskSemaphore.availablePermit
    } else {
      if (!stateRef.compareAndSet(current, current.addPromise(waiter)))
        acquire(waiter) // retry
      else
        waiter.future
    }
  }
}
object TaskSemaphore {
  /** Builds a [[TaskSemaphore]].
    *
    * @param maxParallelism the maximum number of tasks that may hold
    *        a permit, and thus execute, at the same time
    */
  def apply(maxParallelism: Int): TaskSemaphore =
    new TaskSemaphore(maxParallelism)

  /** Already completed future, reused whenever a permit is granted
    * without having to wait.
    */
  private final val availablePermit: Future[Unit] =
    Future.successful(())

  /** Starting state: no permits handed out and nobody waiting. */
  private final val initialState: State =
    State(activeCount = 0, promises = Queue.empty[Promise[Unit]])

  /** Immutable snapshot of a semaphore's state.
    *
    * @param activeCount number of permits currently handed out
    * @param promises    waiters in arrival order (FIFO), each completed
    *                    when a permit is handed over to it
    */
  private final case class State(
    activeCount: Int,
    promises: Queue[Promise[Unit]]) {

    /** Returns a new state with one more permit handed out. */
    def activateOne(): State =
      State(activeCount + 1, promises)

    /** Returns a new state with `p` appended to the wait queue. */
    def addPromise(p: Promise[Unit]): State =
      State(activeCount, promises.enqueue(p))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment