Created October 5, 2016 08:19
Save alexandru/00e572743ee887328c2027c69e4dcf4f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import monix.execution.Cancelable | |
| import monix.execution.atomic.Atomic | |
| import monix.execution.atomic.PaddingStrategy.LeftRight128 | |
| import scala.annotation.tailrec | |
| import scala.collection.immutable.Queue | |
| import scala.concurrent.{Future, Promise} | |
/** The `TaskSemaphore` is an asynchronous semaphore implementation that
  * limits the parallelism on task execution.
  *
  * The following example instantiates a semaphore with a
  * maximum parallelism of 10:
  *
  * {{{
  *   val semaphore = TaskSemaphore(maxParallelism = 10)
  *
  *   def makeRequest(r: HttpRequest): Task[HttpResponse] = ???
  *
  *   // For such a task no more than 10 requests
  *   // are allowed to be executed in parallel.
  *   val task = semaphore.greenLight(makeRequest(???))
  * }}}
  *
  * @param maxParallelism the maximum number of tasks allowed to hold
  *        a permit at the same time; must be strictly positive
  */
final class TaskSemaphore private (maxParallelism: Int) extends Serializable {
  import TaskSemaphore.State

  require(maxParallelism > 0, "parallelism > 0")

  private[this] val stateRef =
    Atomic.withPadding(TaskSemaphore.initialState, LeftRight128)

  /** Returns the number of active tasks that are holding on
    * to the available permits.
    */
  def activeCount: Int =
    stateRef.get.activeCount

  /** Creates a new task ensuring that the given source
    * acquires an available permit from the semaphore before
    * it is being executed.
    *
    * The returned task also takes care of resource handling,
    * releasing its permit after being complete. If the task is
    * cancelled while it is still waiting for a permit, it is only
    * removed from the wait queue. It never returns a permit it did
    * not acquire, so waiting tasks cannot be let in too early.
    */
  def greenLight[A](fa: Task[A]): Task[A] =
    Task.unsafeCreate { (s, conn, cb) =>
      // `None` when a permit was granted immediately, otherwise the
      // promise that `release()` completes once a permit is handed over
      val waiter = acquire()
      val permit: Future[Unit] =
        waiter.fold(TaskSemaphore.availablePermit)(_.future)

      // Triggered on cancellation (via `conn`) and on finish (via
      // `doOnFinish`). NOTE: relies on Monix's `Cancelable(f)` invoking
      // `f` at most once, so the two triggers cannot release twice.
      val c = Cancelable(() => releaseOrAbandon(waiter))
      conn.push(c)

      val source = Task.fromFuture(permit).flatMap { _ =>
        fa.doOnFinish(_ => Task.eval(c.cancel()))
      }
      Task.unsafeStartNow(source, s, conn, cb)
    }

  /** Internal. Invoked when a task created by [[greenLight]] finishes or
    * gets cancelled. If the task is still in the wait queue it is removed
    * and nothing is released. Otherwise it holds a permit, which is
    * returned to the pool.
    */
  private def releaseOrAbandon(waiter: Option[Promise[Unit]]): Unit =
    waiter match {
      case Some(p) if removeWaiter(p) =>
        () // never obtained a permit, so there is nothing to give back
      case _ =>
        release()
    }

  /** Internal. Atomically removes `p` from the wait queue.
    *
    * @return `true` if `p` was still queued and got removed, `false` if
    *         it had already been dequeued by [[release]] (meaning a permit
    *         was transferred to its owner)
    */
  @tailrec private def removeWaiter(p: Promise[Unit]): Boolean = {
    val current: State = stateRef.get
    if (!current.promises.exists(_ eq p)) false
    else {
      val update = current.copy(promises = current.promises.filterNot(_ eq p))
      if (stateRef.compareAndSet(current, update)) true
      else removeWaiter(p) // retry
    }
  }

  /** Internal. Releases a permit, returning it to the pool.
    *
    * If tasks are waiting, the permit is handed directly to the oldest
    * waiter and the active count stays unchanged. Otherwise the active
    * count is decremented.
    */
  @tailrec private def release(): Unit = {
    val current: State = stateRef.get
    if (current.promises.isEmpty) {
      val update = current.copy(activeCount = current.activeCount - 1)
      if (!stateRef.compareAndSet(current, update))
        release() // retry
    } else {
      val (next, rest) = current.promises.dequeue
      if (!stateRef.compareAndSet(current, current.copy(promises = rest)))
        release() // retry
      else
        next.trySuccess(()) // complete outside of the CAS loop
    }
  }

  /** Internal. Triggers a permit acquisition.
    *
    * @return `None` if a permit was acquired right away, or `Some(p)`
    *         if the caller was enqueued. `p` is completed by [[release]]
    *         when a permit is handed over.
    */
  @tailrec private def acquire(): Option[Promise[Unit]] = {
    val current: State = stateRef.get
    if (current.activeCount < maxParallelism) {
      if (stateRef.compareAndSet(current, current.activateOne())) None
      else acquire() // retry
    } else {
      val p = Promise[Unit]()
      if (stateRef.compareAndSet(current, current.addPromise(p))) Some(p)
      else acquire() // retry
    }
  }
}
object TaskSemaphore {
  /** Constructs a [[TaskSemaphore]].
    *
    * @param maxParallelism how many tasks are allowed to run in
    *        parallel under the returned semaphore
    */
  def apply(maxParallelism: Int): TaskSemaphore =
    new TaskSemaphore(maxParallelism)

  /** Already-completed future shared by every immediate acquisition. */
  private final val availablePermit: Future[Unit] =
    Future.successful(())

  /** Starting point for a new semaphore: no permits held, no waiters. */
  private final val initialState: State =
    State(activeCount = 0, promises = Queue.empty)

  /** Immutable snapshot of a semaphore.
    *
    * @param activeCount number of permits currently held
    * @param promises    pending acquisitions, in arrival order
    */
  private final case class State(
    activeCount: Int,
    promises: Queue[Promise[Unit]]) {

    /** This snapshot with one additional permit held. */
    def activateOne(): State =
      State(activeCount + 1, promises)

    /** This snapshot with `p` appended to the pending acquisitions. */
    def addPromise(p: Promise[Unit]): State =
      State(activeCount, promises.enqueue(p))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.