Skip to content

Instantly share code, notes, and snippets.

@ChristopherDavenport
Created January 22, 2022 03:21
Show Gist options
  • Save ChristopherDavenport/badb298201829a64c88a523c69aa3979 to your computer and use it in GitHub Desktop.
Save ChristopherDavenport/badb298201829a64c88a523c69aa3979 to your computer and use it in GitHub Desktop.
UnsafeDeferred - a Deferred with an unsafe (side-effecting) set and a safe (effectful) get.
import scala.annotation.tailrec
import scala.collection.immutable.LongMap
import java.util.concurrent.atomic.AtomicReference
import cats.effect._
import cats.syntax.all._
/** A deferred value whose read side is effectful but whose write side is not.
  *
  * Reading (`get` / `tryGet`, inherited from `DeferredSource`) stays inside `F`.
  * Completing, however, is a plain side-effecting call: `complete` returns a bare
  * `Boolean` rather than `F[Boolean]`, so it runs immediately when invoked.
  *
  * @tparam F the effect type readers wait in
  * @tparam A the type of the value being published
  */
trait UnsafeDeferred[F[_], A] extends cats.effect.kernel.DeferredSource[F, A] {

  /** Publishes `a` immediately, outside of `F`.
    *
    * @param a the value to publish
    * @return whether this call published the value (in `AsyncDeferred`, only
    *         the first successful completion returns `true`)
    */
  def complete(a: A): Boolean

}
object UnsafeDeferred {

  /** Allocates a new, empty [[UnsafeDeferred]], suspending the allocation of its
    * shared mutable state in `F`.
    */
  def apply[F[_], A](implicit F: Async[F]): F[UnsafeDeferred[F, A]] =
    F.delay[UnsafeDeferred[F, A]](new AsyncDeferred[F, A])

  /** Allocates a new, empty [[UnsafeDeferred]] immediately, as a side effect.
    * Prefer [[apply]] when already working inside `F`.
    */
  def unsafe[F[_], A](implicit F: Async[F]): UnsafeDeferred[F, A] =
    new AsyncDeferred[F, A]

  // Internal state machine.
  // - `Unset` accumulates the callbacks of waiting readers, keyed by id.
  // - `Set` holds the published value; once reached it never changes.
  sealed abstract private class State[A]
  private object State {
    final case class Set[A](a: A) extends State[A]
    final case class Unset[A](readers: LongMap[A => Unit], nextId: Long) extends State[A]

    // First id handed out to a registered reader; ids then grow by one.
    val initialId = 1L
    // Returned when no reader was registered because the value was already set.
    // Ids start at `initialId`, so deleting this id never removes a real reader.
    val dummyId = 0L
  }

  /** [[UnsafeDeferred]] backed by a single `AtomicReference` updated with CAS
    * retry loops.
    *
    * - `get` waits in `F` until a value is published. If the waiting fiber is
    *   canceled, its callback is removed again.
    * - `complete` performs the state transition eagerly (outside of `F`). It then
    *   invokes every registered reader callback synchronously on the calling
    *   thread, in registration order.
    */
  final class AsyncDeferred[F[_], A](implicit F: Async[F]) extends UnsafeDeferred[F, A] {
    // shared mutable state
    private[this] val ref = new AtomicReference[State[A]](
      State.Unset(LongMap.empty, State.initialId)
    )

    /** Returns the published value, semantically blocking until it is available. */
    def get: F[A] = {
      // Side-effectful: registers `awakeReader` and returns its id. If the value
      // is already set, the reader is invoked right away and `dummyId` returned.
      def addReader(awakeReader: A => Unit): Long = {
        @tailrec
        def loop(): Long =
          ref.get match {
            case State.Set(a) =>
              awakeReader(a)
              State.dummyId // never used
            case s @ State.Unset(readers, nextId) =>
              val updated = State.Unset(
                readers + (nextId -> awakeReader),
                nextId + 1
              )
              if (!ref.compareAndSet(s, updated)) loop()
              else nextId
          }
        loop()
      }

      // Side-effectful: unregisters reader `id`. Once the value is set, the
      // readers have already been notified, so there is nothing to remove.
      def deleteReader(id: Long): Unit = {
        @tailrec
        def loop(): Unit =
          ref.get match {
            case State.Set(_) => ()
            case s @ State.Unset(readers, _) =>
              val updated = s.copy(readers = readers - id)
              if (!ref.compareAndSet(s, updated)) loop()
          }
        loop()
      }

      F.defer {
        ref.get match {
          case State.Set(a) =>
            // Fast path: no need to register a callback.
            F.pure(a)
          case State.Unset(_, _) =>
            F.async[A] { cb =>
              val resume = (a: A) => cb(Right(a))
              F.delay(addReader(awakeReader = resume)).map { id =>
                // Finalizer run on cancellation: drop our callback so it is
                // neither retained nor invoked later.
                F.delay(deleteReader(id)).some
              }
            }
        }
      }
    }

    /** Returns the published value if present, without waiting. */
    def tryGet: F[Option[A]] =
      F.delay {
        ref.get match {
          case State.Set(a) => Some(a)
          case State.Unset(_, _) => None
        }
      }

    /** Publishes `a` immediately (NOT suspended in `F`) and wakes all waiters.
      *
      * @return `true` if this call moved the state from unset to set,
      *         `false` if a value had already been published
      */
    def complete(a: A): Boolean = {
      // LongMap iterators return values in unsigned key order. That is the
      // arrival order of readers, because ids come from a monotonically
      // increasing counter. An empty map simply yields nothing.
      def notifyReaders(readers: LongMap[A => Unit]): Unit =
        readers.valuesIterator.foreach(awake => awake(a))

      // Side-effectful and returns a plain Boolean: the CAS and the reader
      // notifications happen here, on the calling thread.
      @tailrec
      def loop(): Boolean =
        ref.get match {
          case State.Set(_) =>
            false
          case s @ State.Unset(readers, _) =>
            if (!ref.compareAndSet(s, State.Set(a))) loop()
            else {
              notifyReaders(readers)
              true
            }
        }

      loop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment