// Reentrant Lock
// GitHub gist by ChristopherDavenport (id 3e2112eb3cc4683e812cc52024f22ab8), last active October 11, 2021.
import cats._
import cats.syntax.all._
import cats.data._
import cats.effect._
import cats.effect.syntax.all._
import cats.effect.std.Semaphore
/**
 * An effectful mutual-exclusion lock.
 *
 * `lock` / `unlock` are the raw acquire / release pair; `permit` exposes the lock as a
 * `Resource` so that release is tied to the resource scope.
 */
trait Lock[F[_]] { underlying =>
  def lock: F[Unit]
  def unlock: F[Unit]
  def permit: Resource[F, Unit]

  /**
   * Translates every operation of this lock into `G` through `fk`.
   *
   * The `MonadCancel` evidence is what `Resource#mapK` requires to translate `permit`.
   */
  def mapK[G[_]](fk: F ~> G)(implicit F: MonadCancel[F, _], G: MonadCancel[G, _]): Lock[G] = {
    final class Translated extends Lock[G] {
      def permit: Resource[G, Unit] = underlying.permit.mapK(fk)
      def unlock: G[Unit] = fk(underlying.unlock)
      def lock: G[Unit] = fk(underlying.lock)
    }
    new Translated
  }
}
/**
 * Constructors for [[Lock]]: a plain mutex backed by a `Semaphore`, and a reentrant lock
 * whose callers identify themselves with a `Unique.Token`.
 */
object Lock {
  import scala.collection.immutable.Queue

  /** A non-reentrant mutex backed by a single-permit `Semaphore`. */
  def simple[F[_]: Concurrent]: F[Lock[F]] =
    Semaphore[F](1).map { s =>
      new Lock[F] {
        def lock: F[Unit] = s.acquire
        def unlock: F[Unit] = s.release
        def permit: Resource[F, Unit] = s.permit
      }
    }

  /** One acquisition attempt; `gate` is completed when the lock is handed to it. */
  private case class Request[F[_]](unique: Unique.Token, gate: Deferred[F, Unit]) {
    def sameUnique(that: Unique.Token) = that === unique
    def sameUnique(that: Request[F]) = that.unique === unique
    def wait_ = gate.get
    def complete = gate.complete(())
  }
  private object Request {
    def create[F[_]: Concurrent](token: Unique.Token): F[Request[F]] =
      Deferred[F, Unit].map(Request(token, _))
  }

  /**
   * @param current the request currently owning the lock, if any
   * @param holds   how many times the owner has acquired without releasing
   *                (0 exactly when `current` is empty)
   * @param waiting blocked requests, served first-in first-out
   */
  private case class State[F[_]](current: Option[Request[F]], holds: Int, waiting: Queue[Request[F]])

  private class KleisliReentrantLock[F[_]: Concurrent](ref: Ref[F, State[F]])
      extends Lock[Kleisli[F, Unique.Token, *]] {

    def lock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
      Concurrent[F].uncancelable { (poll: Poll[F]) =>
        Request.create[F](token).flatMap { request =>
          ref.modify {
            case State(Some(holder), holds, waiting) if holder.sameUnique(request) =>
              // Reentrant acquire: count it so that only the matching final unlock releases.
              State(Some(holder), holds + 1, waiting) -> Applicative[F].unit
            case State(Some(holder), holds, waiting) =>
              State(Some(holder), holds, waiting.enqueue(request)) ->
                poll(request.wait_).onCancel(abandon(request))
            case State(None, _, _) =>
              State[F](Some(request), 1, Queue.empty) -> Applicative[F].unit
          }.flatten
        }
      }
    }

    def unlock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
      ref.modify {
        case State(Some(holder), holds, waiting) if holder.sameUnique(token) =>
          if (holds > 1) State(Some(holder), holds - 1, waiting) -> Applicative[F].unit
          else handOff(waiting)
        case s @ State(Some(holder), _, _) =>
          s -> new Exception(s"Not The Current Holder of this lock held by: ${holder.unique} you are $token").raiseError[F, Unit]
        case s @ State(None, _, _) =>
          s -> new Exception(s"Cannot Release a Lock you do not current hold").raiseError[F, Unit]
      }.flatten.uncancelable
    }

    def permit: Resource[Kleisli[F, Unique.Token, *], Unit] =
      Resource.make(lock)(_ => unlock)

    /** Fully releases the lock, granting it to the longest-waiting request if any. */
    private def handOff(waiting: Queue[Request[F]]): (State[F], F[Unit]) =
      waiting.dequeueOption match {
        case Some((next, rest)) => State[F](Some(next), 1, rest) -> next.complete.void
        case None               => State[F](None, 0, Queue.empty) -> Applicative[F].unit
      }

    /** Cleanup for a waiter that was cancelled while blocked in `lock`. */
    private def abandon(request: Request[F]): F[Unit] =
      ref.modify {
        case State(Some(holder), _, waiting) if holder == request =>
          // unlock already handed us the lock but we were cancelled before returning:
          // pass it on instead of leaving it owned by a fiber that will never unlock.
          handOff(waiting)
        case State(current, holds, waiting) =>
          State(current, holds, waiting.filterNot(_ == request)) -> Applicative[F].unit
      }.flatten
  }

  /**
   * A reentrant lock keyed by the `Unique.Token` supplied to each operation.
   *
   * A token that already owns the lock may acquire it again without blocking; it must call
   * `unlock` once per successful `lock` before the lock passes to the next waiter.
   * Unlocking with a token that does not own the lock raises an error.
   */
  def reentrant[F[_]: Concurrent]: F[Lock[Kleisli[F, Unique.Token, *]]] =
    Ref[F].of(State[F](None, 0, Queue.empty)).map(new KleisliReentrantLock(_))

  private def fromLocal[A](ioLocal: IOLocal[A]): Kleisli[IO, A, *] ~> IO = new (Kleisli[IO, A, *] ~> IO) {
    def apply[B](fa: Kleisli[IO, A, B]): IO[B] = ioLocal.get.flatMap(fa.run(_))
  }

  /**
   * A reentrant `Lock[IO]` whose caller identity is read from `ioLocal` on every
   * lock / unlock call. Callers that see the same token value are treated as one holder.
   */
  def ioLocal(ioLocal: IOLocal[Unique.Token]): IO[Lock[IO]] =
    reentrant[IO].map(lockK => lockK.mapK(fromLocal(ioLocal)))
}
import cats._
import cats.syntax.all._
import cats.data._
import cats.effect._
import cats.effect.syntax.all._
import scala.collection.immutable.Queue
import scala.concurrent.duration.FiniteDuration
/**
 * A lock with a shared read side and an exclusive write side, each exposed as a [[Lock]].
 */
trait ReadWriteLock[F[_]] { outer =>
  def readLock: Lock[F]
  def writeLock: Lock[F]

  /** Translates both sides of this lock into `G` through `fk`. */
  def mapK[G[_]](fk: F ~> G)(implicit F: MonadCancel[F, _], G: MonadCancel[G, _]): ReadWriteLock[G] = {
    final class Translated extends ReadWriteLock[G] {
      def writeLock: Lock[G] = outer.writeLock.mapK(fk)
      def readLock: Lock[G] = outer.readLock.mapK(fk)
    }
    new Translated
  }
}
/**
 * A reentrant read/write lock keyed by the `Unique.Token` supplied to each operation.
 *
 * Any number of tokens may hold the read side at once; the write side is exclusive.
 * Releasing the last read hold hands off to the oldest waiting writer, while fully releasing
 * the write side admits every waiting reader as one batch. Every successful `lock` must be
 * paired with one `unlock` by the same token.
 */
object ReadWriteLock {
  private case class Request[F[_]](unique: Unique.Token, gate: Deferred[F, Unit]) {
    def sameUnique(that: Unique.Token) = that === unique
    def sameUnique(that: Request[F]) = that.unique === unique
    def wait_ = gate.get
    def complete = gate.complete(())
  }
  private object Request {
    def create[F[_]: Concurrent](token: Unique.Token): F[Request[F]] =
      Deferred[F, Unit].map(Request(token, _))
  }

  private sealed trait Current[F[_]]
  private object Current {
    /** Read holders; a token appears once per outstanding read acquisition. */
    final case class Reads[F[_]](running: Queue[Request[F]]) extends Current[F]
    /** The exclusive writer and how many acquisitions (write or nested read) it still owes. */
    final case class Write[F[_]](running: Request[F], holds: Int) extends Current[F]
  }

  // Invariants: `current` empty => both queues empty;
  // `current` is Reads => readWaiting is non-empty only while writeWaiting is non-empty.
  private case class State[F[_]](
    current: Option[Current[F]],
    writeWaiting: Queue[Request[F]],
    readWaiting: Queue[Request[F]]
  )

  private class ReadWriteLockImpl[F[_]: Concurrent](ref: Ref[F, State[F]])
      extends ReadWriteLock[Kleisli[F, Unique.Token, *]] {

    /** Adds `reads` to the running readers and wakes them all. */
    private def admitReads(
      reads: Queue[Request[F]],
      running: Queue[Request[F]],
      writes: Queue[Request[F]]
    ): (State[F], F[Unit]) =
      State[F](Some(Current.Reads(running ++ reads)), writes, Queue.empty) ->
        reads.toList.traverse_(_.complete)

    /** Makes `next` the writer and wakes it. */
    private def grantWrite(
      next: Request[F],
      writes: Queue[Request[F]],
      reads: Queue[Request[F]]
    ): (State[F], F[Unit]) =
      State[F](Some(Current.Write(next, 1)), writes, reads) -> next.complete.void

    /** Last read hold released: favour the oldest waiting writer. */
    private def releaseReads(writes: Queue[Request[F]], reads: Queue[Request[F]]): (State[F], F[Unit]) =
      writes.dequeueOption match {
        case Some((next, rest))     => grantWrite(next, rest, reads)
        case None if reads.nonEmpty => admitReads(reads, Queue.empty, Queue.empty)
        case None                   => State[F](None, Queue.empty, Queue.empty) -> Applicative[F].unit
      }

    /** Write lock fully released: favour admitting all waiting readers as one batch. */
    private def releaseWrite(writes: Queue[Request[F]], reads: Queue[Request[F]]): (State[F], F[Unit]) =
      if (reads.nonEmpty) admitReads(reads, Queue.empty, writes)
      else writes.dequeueOption match {
        case Some((next, rest)) => grantWrite(next, rest, Queue.empty)
        case None               => State[F](None, Queue.empty, Queue.empty) -> Applicative[F].unit
      }

    /** Cleanup for a reader cancelled while blocked; releases the hold if it was already granted. */
    private def abandonRead(req: Request[F]): F[Unit] =
      ref.modify {
        case State(Some(Current.Reads(running)), writes, reads) if running.contains(req) =>
          val remaining = running.filterNot(_ == req)
          if (remaining.isEmpty) releaseReads(writes, reads)
          else State[F](Some(Current.Reads(remaining)), writes, reads) -> Applicative[F].unit
        case State(current, writes, reads) =>
          State(current, writes, reads.filterNot(_ == req)) -> Applicative[F].unit
      }.flatten

    /** Cleanup for a writer cancelled while blocked; releases the lock if it was already granted. */
    private def abandonWrite(req: Request[F]): F[Unit] =
      ref.modify {
        case State(Some(Current.Write(holder, _)), writes, reads) if holder == req =>
          releaseWrite(writes, reads)
        case State(Some(Current.Reads(running)), writes, reads) =>
          val remaining = writes.filterNot(_ == req)
          // Readers were only queued behind writers; with none left they may join now.
          if (remaining.isEmpty && reads.nonEmpty) admitReads(reads, running, remaining)
          else State[F](Some(Current.Reads(running)), remaining, reads) -> Applicative[F].unit
        case State(current, writes, reads) =>
          State(current, writes.filterNot(_ == req), reads) -> Applicative[F].unit
      }.flatten

    private class ReadLock extends Lock[Kleisli[F, Unique.Token, *]] {
      def lock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
        Request.create[F](token).flatMap { req =>
          Concurrent[F].uncancelable { (poll: Poll[F]) =>
            ref.modify {
              case State(None, _, _) =>
                // No one with any locks
                State[F](Some(Current.Reads(Queue(req))), Queue.empty, Queue.empty) -> Applicative[F].unit
              case State(Some(Current.Reads(running)), writes, reads) if running.exists(req.sameUnique(_)) =>
                // Reentrant read: never queue behind waiting writers, or we would wait on ourselves.
                State[F](Some(Current.Reads(running.enqueue(req))), writes, reads) -> Applicative[F].unit
              case State(Some(Current.Write(holder, holds)), writes, reads) if holder.sameUnique(req) =>
                // The writer may also read; count it against the write hold.
                State[F](Some(Current.Write(holder, holds + 1)), writes, reads) -> Applicative[F].unit
              case State(Some(Current.Reads(running)), writes, reads) if writes.isEmpty =>
                // Currently reading with no waiting writers: join the readers.
                State[F](Some(Current.Reads(running.enqueue(req))), writes, reads) -> Applicative[F].unit
              case State(current, writes, reads) =>
                State(current, writes, reads.enqueue(req)) ->
                  poll(req.wait_).onCancel(abandonRead(req))
            }.flatten
          }
        }
      }

      // Unlock out of Reads favors Writes
      def unlock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
        ref.modify {
          case s @ State(Some(Current.Reads(running)), writes, reads) =>
            // Remove exactly one hold belonging to this token.
            val (before, from) = running.span(r => !r.sameUnique(token))
            from.dequeueOption match {
              case None =>
                s -> new Exception(s"Token $token holds no read lock, cannot unlock a read").raiseError[F, Unit]
              case Some((_, after)) =>
                val remaining = before ++ after
                if (remaining.isEmpty) releaseReads(writes, reads)
                else State[F](Some(Current.Reads(remaining)), writes, reads) -> Applicative[F].unit
            }
          case State(Some(Current.Write(holder, holds)), writes, reads) if holder.sameUnique(token) =>
            if (holds > 1) State[F](Some(Current.Write(holder, holds - 1)), writes, reads) -> Applicative[F].unit
            else releaseWrite(writes, reads)
          case s @ State(Some(Current.Write(_, _)), _, _) =>
            s -> new Exception(s"Lock is held by write, cannot unlock a read").raiseError[F, Unit]
          case s @ State(None, _, _) =>
            s -> new Exception(s"No Lock presently, cannot unlock when no lock is held").raiseError[F, Unit]
        }.flatten.uncancelable
      }

      def permit: Resource[Kleisli[F, Unique.Token, *], Unit] = Resource.make(lock)(_ => unlock)
    }

    private class WriteLock extends Lock[Kleisli[F, Unique.Token, *]] {
      def lock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
        Request.create[F](token).flatMap { req =>
          Concurrent[F].uncancelable { (poll: Poll[F]) =>
            ref.modify {
              case State(None, _, _) =>
                // No one with any locks
                State[F](Some(Current.Write(req, 1)), Queue.empty, Queue.empty) -> Applicative[F].unit
              case State(Some(Current.Write(holder, holds)), writes, reads) if holder.sameUnique(req) =>
                // Re-entrant write
                State[F](Some(Current.Write(holder, holds + 1)), writes, reads) -> Applicative[F].unit
              case s @ State(Some(Current.Reads(running)), _, _) if running.exists(req.sameUnique(_)) =>
                // Waiting could never succeed: our own read hold keeps the readers from draining.
                s -> new Exception(s"Cannot acquire write lock while holding a read lock").raiseError[F, Unit]
              case State(current, writes, reads) =>
                State(current, writes.enqueue(req), reads) ->
                  poll(req.wait_).onCancel(abandonWrite(req))
            }.flatten
          }
        }
      }

      // Favor Batch Reads on Write Unlocks
      def unlock: Kleisli[F, Unique.Token, Unit] = Kleisli { (token: Unique.Token) =>
        ref.modify {
          case State(Some(Current.Write(holder, holds)), writes, reads) if holder.sameUnique(token) =>
            if (holds > 1) State[F](Some(Current.Write(holder, holds - 1)), writes, reads) -> Applicative[F].unit
            else releaseWrite(writes, reads)
          case s @ State(None, _, _) =>
            s -> new Exception(s"Cannot unlock write lock when no lock is held").raiseError[F, Unit]
          case s @ State(Some(Current.Reads(_)), _, _) =>
            s -> new Exception(s"Cannot Unlock Write Lock when Read holds lock").raiseError[F, Unit]
          case s @ State(Some(Current.Write(_, _)), _, _) =>
            s -> new Exception(s"Another Write Holds Lock presently, cannot unlock").raiseError[F, Unit]
        }.flatten.uncancelable
      }

      def permit: Resource[Kleisli[F, Unique.Token, *], Unit] = Resource.make(lock)(_ => unlock)
    }

    val readLock: Lock[Kleisli[F, Unique.Token, *]] = new ReadLock
    val writeLock: Lock[Kleisli[F, Unique.Token, *]] = new WriteLock
  }

  def reentrant[F[_]: Concurrent]: F[ReadWriteLock[Kleisli[F, Unique.Token, *]]] =
    Concurrent[F].ref(State[F](None, Queue.empty, Queue.empty)).map(
      new ReadWriteLockImpl(_)
    )

  private def fromLocal[A](ioLocal: IOLocal[A]): Kleisli[IO, A, *] ~> IO = new (Kleisli[IO, A, *] ~> IO) {
    def apply[B](fa: Kleisli[IO, A, B]): IO[B] = ioLocal.get.flatMap(fa.run(_))
  }

  /** A `ReadWriteLock[IO]` whose caller identity is read from `ioLocal` on every operation. */
  def ioLocal(ioLocal: IOLocal[Unique.Token]): IO[ReadWriteLock[IO]] =
    reentrant[IO].map(rwLockK => rwLockK.mapK(fromLocal(ioLocal)))
}
/** Smoke test: one token takes the reentrant lock twice, nested, then prints a message. */
object Test extends ResourceApp.Simple {
  def run: Resource[IO, Unit] =
    Resource.eval(IO(println(1))).flatMap { _ =>
      Resource.eval(Lock.reentrant[IO]).flatMap { kleisliLock =>
        Resource.eval(Unique[IO].unique).flatMap { token =>
          val bound = kleisliLock.mapK(Kleisli.applyK(token))
          val nested = bound.permit.use(_ => bound.permit.use(_ => IO(println("Made it!"))))
          Resource.eval(nested)
        }
      }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment