Skip to content

Instantly share code, notes, and snippets.

@johnynek
Last active October 2, 2020 22:36
Show Gist options
  • Save johnynek/6473724 to your computer and use it in GitHub Desktop.
Save johnynek/6473724 to your computer and use it in GitHub Desktop.
A two-phase commit Transaction Monad. See: http://en.wikipedia.org/wiki/Two-phase_commit_protocol
// Run this with scala <filename>
/**
 * A two-phase commit monad.
 *
 * A transaction is first asked to prepare via `query`; once it reports a
 * `Prepared` state it can be committed. Transactions compose with `map` and
 * `flatMap`, so several steps can be prepared and committed as one unit.
 */
trait Transaction[+T] {
  /**
   * Attempt to prepare this transaction on behalf of request `id`.
   * In a real system, this should be wrapped in a Future.
   *
   * @return `Right` with the prepared state when commit is possible, or
   *         `Left` with the transaction to try again when it is not
   */
  def query(id: Long): Either[Transaction[T], Prepared[T]]

  /** Sequence `fn` after this transaction; both steps are combined by `FlatMapped`. */
  def flatMap[U](fn: T => Transaction[U]): Transaction[U] = FlatMapped(this, fn)

  /** Transform the result with `fn` by lifting its output into a `Constant`. */
  def map[U](fn: T => U): Transaction[U] = flatMap(value => Constant(fn(value)))

  /**
   * Query repeatedly until the transaction prepares, then commit it.
   *
   * The first attempt uses `inid`; every retry queries the transaction
   * returned by the failed attempt with the previous id plus one.
   *
   * @param inid the request id used for the first attempt
   * @return the committed value
   */
  final def run(inid: Long): T = {
    @annotation.tailrec
    def attempt(requestId: Long, current: Transaction[T]): T =
      current.query(requestId) match {
        case Right(ready) => ready.commit.get
        case Left(retry)  => attempt(requestId + 1L, retry)
      }
    attempt(inid, this)
  }
}
/**
 * Represents the state when commit is possible.
 *
 * A value of this type is what `Transaction.query` returns on the `Right`
 * side; the holder then either calls `commit` or `rollback`.
 */
trait Prepared[+T] {
// The value this step will yield; FlatMapped reads it before committing in
// order to build the next step of the chain.
def init: T
// In a real system, these should be wrapped in Futures
// Finalize the step and return its value wrapped in Committed.
def commit: Committed[T]
// Abandon the step and return a transaction that can be queried again.
def rollback: Transaction[T]
}
/**
 * A type wrapper just used to mark success.
 * `get` is the value produced by the commit; `Transaction.run` unwraps it.
 */
case class Committed[+T](get: T)
/**
 * The most trivial instance: a transaction that already holds its value.
 * It is always prepared, regardless of the request id it is queried with.
 */
case class Constant[+T](get: T) extends Transaction[T] { outer =>
  def query(l: Long) =
    Right(new Prepared[T] {
      // Nothing was changed, so rolling back just hands back this transaction.
      def rollback = outer
      // Committing simply wraps the stored value.
      def commit = Committed(outer.get)
      def init = outer.get
    })
}
/**
 * Implementation of flatMap (non-trampolined, so super big transactions can fail).
 *
 * Prepares `init` first, feeds its prepared value to `fn` to obtain the second
 * step, and prepares that step with the same request id. The combined
 * transaction is prepared only when both steps are.
 *
 * @param init the first step of the chain
 * @param fn   builds the second step from the first step's prepared value
 */
case class FlatMapped[R,T](init: Transaction[R], fn: R => Transaction[T]) extends Transaction[T] {
  /**
   * @param id the request id, passed unchanged to both steps
   * @return `Right` with a combined prepared state when both steps prepared;
   *         otherwise `Left` with a transaction that restarts from the first step
   */
  def query(id: Long) = init.query(id) match {
    case Left(ninit) =>
      // The first step didn't work: retry the chain from whatever it handed back.
      Left(FlatMapped(ninit, fn))
    case Right(iprep) =>
      fn(iprep.init).query(id) match {
        case Left(_) =>
          // The second step failed, so undo the first one and start over.
          Left(FlatMapped(iprep.rollback, fn))
        case Right(rprep) => Right(new Prepared[T] {
          def init = rprep.init
          // Memoised like `rollback` below, so asking for the result twice
          // does not commit the two underlying steps a second time.
          lazy val commit = {
            // See how we need a Future (or some Monad) to sequence these
            iprep.commit
            rprep.commit
          }
          lazy val rollback = {
            // See how we need a Future (or some Monad) to sequence these
            rprep.rollback
            FlatMapped(iprep.rollback, fn)
          }
        })
      }
  }
}
import java.util.concurrent.atomic.{AtomicReference => Atom}
/**
 * Wrapper for an atomic reference that can take part in transactions.
 *
 * The cell holds `Right(value)` when no request owns it, and
 * `Left((requestId, currentValue, rollbackValue))` while a request holds it.
 */
class Atomic[T](init: T) {
  private val state =
    new Atom[Either[(Long, T, T), T]](Right(init))

  /** A transaction that replaces the stored value with `fn` applied to it. */
  def map(fn: T => T): Transaction[T] =
    new Map(state, fn)

  /** A transaction that yields the stored value without changing it. */
  def read: Transaction[T] =
    new Map(state, (current: T) => current)
}
/**
 * A transaction that applies `fn` to the value stored in `atom`.
 *
 * The cell is `Right(value)` when free and
 * `Left((requestId, currentValue, rollbackValue))` while a request holds it.
 * All state changes go through compare-and-set on the exact object that was
 * read, so a concurrent change makes the attempt fail instead of being lost.
 *
 * @param atom the shared cell
 * @param fn   the update to apply to the current value
 */
class Map[T](atom: Atom[Either[(Long, T, T), T]], fn: T => T) extends Transaction[T] {
  /**
   * Free the cell if request `id` still holds it, retrying while the
   * compare-and-set loses a race. Keeps the updated value when `keepNew` is
   * true (commit) and restores the rollback value otherwise.
   */
  @annotation.tailrec
  private def release(id: Long, keepNew: Boolean): Unit = {
    val seen = atom.get
    seen match {
      case Left((owner, updated, previous)) if owner == id =>
        val restored = if (keepNew) updated else previous
        if (!atom.compareAndSet(seen, Right(restored))) release(id, keepNew)
      case _ =>
        // Already free, or held by a different request: nothing to do.
        ()
    }
  }

  /** The prepared state handed out once request `id` holds the cell with `result`. */
  private def locked(id: Long, result: T) =
    Right(new Prepared[T] {
      def init = result
      lazy val commit = {
        release(id, keepNew = true)
        Committed(result)
      }
      lazy val rollback = {
        release(id, keepNew = false)
        new Map(atom, fn)
      }
    })

  /**
   * Apply `fn` to `current` and try to swap the observed state `seen` for a
   * held state owned by `id`; `fallback` is the value a rollback restores.
   */
  private def attempt(seen: Either[(Long, T, T), T], id: Long, current: T, fallback: T) = {
    val updated = fn(current)
    val claimed = atom.compareAndSet(seen, Left((id, updated, fallback)))
    if (claimed) locked(id, updated) else Left(new Map(atom, fn))
  }

  /**
   * Try to take (or keep) the cell for request `id`.
   *
   * @return `Right` with the prepared state on success, or `Left` with a fresh
   *         copy of this transaction when another request holds the cell or
   *         the compare-and-set fails
   */
  def query(id: Long) =
    atom.get match {
      // Start of request: the current value doubles as the rollback value.
      case free @ Right(value) =>
        attempt(free, id, value, value)
      // Part of the same request: build on its value, keep its rollback value.
      case held @ Left((owner, value, fallback)) if owner == id =>
        attempt(held, id, value, fallback)
      case _ =>
        Left(new Map(atom, fn))
    }
}
/** Two-thread demo: concurrent transactional increments and decrements of one shared cell. */
object Test {
  /**
   * Starts one thread that subtracts 1 and one that adds 1, each `count`
   * times, waits for both, and prints the extreme value each thread saw
   * followed by the final value of the cell.
   *
   * @param count number of transactions each thread runs
   */
  def test(count: Int): Unit = {
    val mem = new Atomic[Int](42)

    /**
     * Builds a thread that runs `count` read-then-update transactions, each
     * changing the cell by `delta`, and prints `label` followed by the fold of
     * every committed result under `pick`, starting from 42.
     * Request ids start at `idOffset + cnt` so the two threads use separate ranges.
     */
    def worker(label: String, delta: Int, idOffset: Long, pick: (Int, Int) => Int): Thread =
      new Thread {
        override def run(): Unit = {
          @annotation.tailrec
          def go(cnt: Int, m: Int): Int =
            if (cnt > 0) {
              // Compose monadically
              val r = (for {
                x <- mem.read
                y <- mem.map { _ => x + delta }
              } yield y).run(idOffset + cnt)
              go(cnt - 1, pick(m, r))
            }
            else m
          println(label + go(count, 42))
        }
      }

    /** Thread that just does subtraction */
    val sub = worker("max seen by subber: ", -1, 0L, _ max _)
    /** Thread that just does addition; ids are offset (in Long arithmetic) so they do not collide with sub */
    val add = worker("min seen by adder: ", 1, 2L * count, _ min _)

    sub.start()
    add.start()
    sub.join()
    add.join()
    println("42 is always the answer (no race conditions) ==> " + mem.read.run(-count.toLong))
  }
}
// Script entry point: run the demo with one million transactions per thread.
Test.test(1000000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment