Created
March 20, 2019 08:41
-
-
Save monadplus/90c56dc235717e68eefe30c626c3f55b to your computer and use it in GitHub Desktop.
Exercise: implement cats.Deferred from scratch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Original code from https://github.com/typelevel/cats-effect/blob/1846813109b1e78c5bf36e6e179d7a91419e01d0/core/shared/src/main/scala/cats/effect/concurrent/Deferred.scala#L164
// Exercise: implement Deferred.
// Bear in mind:
// - always use compareAndSet when updating the atomic reference;
// - don't break referential transparency (RT) when touching `ref` (e.g. wrap access in F.delay / F.suspend);
// - `get` should be cancelable;
// - `complete` should return an F[Unit] that does not block the current execution context;
// - the `register` and `unregister` helpers should be tail-recursive (annotate them with @tailrec).
// Helpers
/** Opaque key identifying one registered reader callback.
  *
  * The class declares no fields and does not override `equals`/`hashCode`,
  * so every `new Id` compares equal only to itself.
  */
private final class Id()
/** Contents of a deferred's atomic reference: either still waiting (`Unset`)
  * or holding the completed value (`Set`).
  */
private sealed abstract class State[A]

private object State {

  /** Not completed yet.
    *
    * `waiting` maps each pending reader's [[Id]] to the callback that delivers
    * the value to it (typically built as `a => cb(Right(a))`).
    */
  final case class Unset[A](waiting: LinkedMap[Id, A => Unit]) extends State[A]

  /** Completed with value `a`. */
  final case class Set[A](a: A) extends State[A]
}
/** Exercise skeleton: replace each `???` with a working implementation.
  *
  * NOTE(review): this class has the same name as the solution class declared
  * later in this file; only one of the two can be compiled at a time.
  */
private final class ConcurrentDeferred[F[_], A](ref: AtomicReference[State[A]])(implicit F: Concurrent[F])
    extends TryableDeferred[F, A] {

  /** Should fail if a value is already set; otherwise complete every waiting task. */
  override def complete(a: A): F[Unit] = ???

  /** Should never wait: `None` while the deferred is not completed, the value otherwise.
    */
  override def tryGet: F[Option[A]] = ???

  /** Should return the value if set, or wait until the deferred is completed.
    *
    * Hint: use `F.cancelable` and register the callback in `ref`.
    */
  override def get: F[A] = ???
}
// Solution:
/** A `TryableDeferred` whose state lives in a lock-free `AtomicReference[State[A]]`.
  *
  * While unset, the reference holds the callbacks of pending readers. `complete`
  * swaps it to `State.Set(a)` and then notifies those readers. Every update goes
  * through a `compareAndSet` retry loop, so concurrent readers and writers never
  * overwrite each other's changes.
  *
  * NOTE(review): an exercise skeleton with this same class name is declared
  * earlier in this file; only one of the two can be compiled at a time.
  */
private final class ConcurrentDeferred[F[_], A](ref: AtomicReference[State[A]])(implicit F: Concurrent[F])
    extends TryableDeferred[F, A] {

  /** Returns the value if already set; otherwise waits until `complete` is called.
    *
    * The wait is cancelable. Cancelling removes this reader's callback from the
    * waiting map so it is never invoked.
    */
  override def get: F[A] =
    F.suspend {
      ref.get() match {
        case State.Set(a) =>
          F.pure(a)
        case State.Unset(_) =>
          F.cancelable[A] { cb =>
            val id = unsafeRegister(cb)

            // Cancellation token. It drops our callback from the waiting map.
            // If the deferred was completed in the meantime, there is nothing
            // left to remove.
            @tailrec
            def unregister(): Unit =
              ref.get() match {
                case State.Set(_) => ()
                case s @ State.Unset(waiting) =>
                  val updated = State.Unset(waiting - id)
                  if (ref.compareAndSet(s, updated)) ()
                  else unregister()
              }

            F.delay(unregister())
          }
      }
    }

  /** Non-waiting read: `Some(value)` once completed, `None` before that. */
  override def tryGet: F[Option[A]] =
    F.delay {
      ref.get() match {
        case State.Set(a)   => Some(a)
        case State.Unset(_) => None
      }
    }

  /** Sets the value to `a` and notifies every pending reader.
    *
    * The returned action fails with an `IllegalStateException` (a
    * `RuntimeException`) if the deferred has already been completed.
    */
  override def complete(a: A): F[Unit] =
    F.suspend(unsafeComplete(a))

  // CAS loop performing the Unset -> Set transition. Must run inside F.suspend.
  @tailrec
  private[this] def unsafeComplete(a: A): F[Unit] =
    ref.get() match {
      case State.Set(_) =>
        // Raised as an error value rather than thrown, so failure does not
        // depend on the caller's exception-capturing behavior.
        F.raiseError[Unit](
          new IllegalStateException("Attempting to complete a Deferred that has already been completed"))
      case s @ State.Unset(waiting) =>
        if (ref.compareAndSet(s, State.Set(a))) {
          val readers = waiting.values
          if (readers.isEmpty) F.unit
          else notifyReadersLoop(a, readers)
        } else {
          unsafeComplete(a)
        }
    }

  // Chains, in iteration order, one `F.start`-ed fiber per reader callback.
  // The fibers are started but never joined, so `complete` does not run or
  // wait for the reader code itself.
  private[this] def notifyReadersLoop(a: A, iterable: Iterable[A => Unit]): F[Unit] =
    iterable.foldLeft(F.unit) { (acc, f) =>
      val task = mapUnit(F.start(F.delay(f(a))))
      F.flatMap(acc)(_ => task)
    }

  // Discards the result of `fb`, keeping only its effect.
  private[this] def mapUnit[B](fb: F[B]): F[Unit] = F.map(fb)(_ => ())

  /** Stores `cb` in the waiting map and returns the [[Id]] it was stored under.
    *
    * If the deferred turns out to be completed before registration succeeds,
    * `cb` is invoked immediately with the value instead of being stored. The
    * returned id then refers to nothing, which is harmless for unregistering.
    */
  private[this] def unsafeRegister(cb: Either[Throwable, A] => Unit): Id = {
    val id = new Id

    @tailrec
    def register(): Option[A] =
      ref.get() match {
        case State.Set(a) => Some(a)
        case s @ State.Unset(waiting) =>
          val updated = State.Unset(waiting.updated(id, (a: A) => cb(Right(a))))
          if (ref.compareAndSet(s, updated)) None
          else register()
      }

    register().foreach(a => cb(Right(a)))
    id
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment