Last active
October 2, 2020 22:36
-
-
Save johnynek/6473724 to your computer and use it in GitHub Desktop.
A two-phase commit transaction monad. See: https://en.wikipedia.org/wiki/Two-phase_commit_protocol
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Run this file as a Scala script: scala <filename>
/**
 * A two-phase commit monad.
 *
 * `query` is phase one. It tries to prepare every step under a transaction id and
 * returns either `Right(prepared)`, ready to commit, or `Left(retry)`, a
 * transaction to attempt again. `run` drives that loop to completion.
 */
trait Transaction[+T] {
  /** Transforms the eventual result; expressed as a `flatMap` into a `Constant`. */
  def map[U](fn: T => U): Transaction[U] = flatMap(t => Constant(fn(t)))

  /** Sequences a transaction that depends on this one's result. */
  def flatMap[U](fn: T => Transaction[U]): Transaction[U] = FlatMapped(this, fn)

  // In a real system, this should be wrapped in a Future
  def query(id: Long): Either[Transaction[T], Prepared[T]]

  /**
   * Queries this transaction with `inid`. Whenever a retry comes back, it queries
   * that retry with the next id (`inid + 1`, `inid + 2`, ...), immediately and
   * with no back-off. Once a query prepares successfully, it commits and returns
   * the committed value.
   */
  final def run(inid: Long): T = {
    @annotation.tailrec
    def attempt(id: Long, tx: Transaction[T]): T =
      tx.query(id) match {
        case Right(ready) => ready.commit.get
        case Left(retry)  => attempt(id + 1L, retry)
      }
    attempt(inid, this)
  }
}
/**
 * A transaction step whose first phase succeeded.
 * The caller must now either commit it or roll it back.
 */
trait Prepared[+T] {
  /** The value this step will yield once committed. */
  def init: T

  // In a real system, commit and rollback should be wrapped in Futures.

  /** Phase two: finalize the step. */
  def commit: Committed[T]

  /** Abandon the prepared state and return a transaction that can be retried. */
  def rollback: Transaction[T]
}
/** Marker wrapper showing that a commit happened; `get` is the committed value. */
case class Committed[+A](get: A)
/**
 * The most trivial instance: a transaction that always prepares immediately
 * with `get`. Committing yields `get`; rolling back yields this same transaction.
 */
case class Constant[+T](get: T) extends Transaction[T] { self =>
  def query(l: Long) = {
    // Nothing is acquired, so preparation cannot fail and the id `l` is unused.
    val ready = new Prepared[T] {
      def init = self.get
      def commit = Committed(self.get)
      def rollback = self
    }
    Right(ready)
  }
}
/**
 * Sequential composition of two transactions; this is how `flatMap` is implemented.
 *
 * Not trampolined: preparing a nested chain recurses once per level, so very
 * large transactions can fail (presumably by exhausting the stack).
 */
case class FlatMapped[R,T](init: Transaction[R], fn: R => Transaction[T]) extends Transaction[T] {
  def query(id: Long) = init.query(id) match {
    case Right(firstReady) =>
      // The first step is prepared; now try the step that depends on its value.
      fn(firstReady.init).query(id) match {
        case Right(secondReady) => Right(joined(firstReady, secondReady))
        case Left(_) =>
          // The second step is blocked: roll back the first step, then retry the whole chain.
          Left(FlatMapped(firstReady.rollback, fn))
      }
    case Left(retryFirst) =>
      // The first step didn't prepare; retry it, followed by the rest of the chain.
      Left(FlatMapped(retryFirst, fn))
  }

  /** Pairs two prepared steps. Commit runs first then second; rollback runs second then first. */
  private def joined(first: Prepared[R], second: Prepared[T]): Prepared[T] = new Prepared[T] {
    def init = second.init
    def commit = {
      // See how we need a Future (or some Monad) to sequence these
      first.commit
      second.commit
    }
    lazy val rollback = {
      // See how we need a Future (or some Monad) to sequence these
      second.rollback
      FlatMapped(first.rollback, fn)
    }
  }
}
import java.util.concurrent.atomic.{AtomicReference => Atom} | |
/** A shared cell, backed by an AtomicReference, whose reads and updates are transactions. */
class Atomic[T](init: T) {
  // Right(value) while idle. While a transaction holds it, Left((ownerId, pending, rollbackValue)).
  private val cell = new Atom[Either[(Long, T, T), T]](Right(init))

  /** A transaction that yields the current value and writes it back unchanged. */
  def read: Transaction[T] = map(identity[T])

  /** A transaction that replaces the value with `fn(value)` and yields the new value. */
  def map(fn: T => T): Transaction[T] = new Map(cell, fn)
}
/**
 * A transactional read-modify-write of one atomic cell.
 *
 * The cell holds `Right(value)` when free. While the transaction with id
 * `ownerId` holds it, the cell holds `Left((ownerId, pending, rollbackValue))`.
 * - A query with another id while the cell is held returns a retry.
 * - A query with the owner's id applies `fn` on top of the pending value, so
 *   several steps of one transaction can touch the same cell.
 *
 * NOTE(review): within this file, this class shadows `scala.Predef.Map`.
 */
class Map[T](atom: Atom[Either[(Long, T, T), T]], fn: T => T) extends Transaction[T] {

  /**
   * If the cell is held by `id`, sets it to `Right(choose(pending, rollbackValue))`.
   * Otherwise it leaves the cell alone. A failed CAS re-reads the cell and tries again.
   */
  @annotation.tailrec
  private def release(id: Long, choose: (T, T) => T): Unit = {
    val seen = atom.get
    seen match {
      case Left((owner, pending, saved)) if owner == id =>
        if (!atom.compareAndSet(seen, Right(choose(pending, saved)))) release(id, choose)
      case _ => ()
    }
  }

  /** Commit: publish the pending value, which may already include later steps of the same id. */
  private def unset(id: Long): Unit = release(id, (pending, _) => pending)

  /** Rollback: restore the value saved when this id first took the cell. */
  private def rollbackId(id: Long): Unit = release(id, (_, saved) => saved)

  /** Wraps a successful claim by `id` with result `myRes` as a prepared step. */
  private def prep(id: Long, myRes: T) = {
    val ready = new Prepared[T] {
      def init = myRes
      lazy val commit = {
        unset(id)
        Committed(myRes)
      }
      lazy val rollback = {
        rollbackId(id)
        new Map(atom, fn)
      }
    }
    Right(ready)
  }

  /**
   * Computes `fn(base)` and tries to CAS the cell from `expected` to a hold by
   * `id`, remembering `rollbackv`. If the cell changed in the meantime, returns a retry.
   */
  private def tryApply(expected: Either[(Long, T, T), T], id: Long, base: T, rollbackv: T) = {
    val proposed = fn(base)
    val claimed = atom.compareAndSet(expected, Left((id, proposed, rollbackv)))
    if (claimed) prep(id, proposed) else Left(new Map(atom, fn))
  }

  def query(id: Long) = atom.get match {
    case free @ Right(current) =>
      // Start of request: take the cell and remember `current` for rollback.
      tryApply(free, id, current, current)
    case held @ Left((owner, pending, saved)) if owner == id =>
      // Part of the same request: build on the pending value, keep the original rollback value.
      tryApply(held, id, pending, saved)
    case _ =>
      // Held by a different transaction: ask the caller to retry.
      Left(new Map(atom, fn))
  }
}
object Test {
  /**
   * Concurrency smoke test. Two threads share one cell that starts at 42:
   * - one thread decrements it `count` times;
   * - the other increments it `count` times.
   * Each change is a read-then-write transaction. If the transactions are
   * isolated, the final value is 42 again.
   *
   * @param count number of transactions each thread runs
   */
  def test(count: Int): Unit = {
    val mem = new Atomic[Int](42)
    // Two threads must never share a transaction id. Otherwise one thread treats
    // the other's hold on `mem` as its own. Transaction.run bumps the id on every
    // retry, so the two id ranges start far apart in Long space. They are also
    // computed in Long: the old `2*count + cnt` overflowed Int for large counts.
    val sub = worker(mem, count, idBase = 0L, delta = -1,
      label = "max seen by subber: ", pick = (seen, r) => seen max r)
    val add = worker(mem, count, idBase = Long.MaxValue / 2, delta = 1,
      label = "min seen by adder: ", pick = (seen, r) => seen min r)
    sub.start()
    add.start()
    sub.join()
    add.join()
    println("42 is always the answer (no race conditions) ==> " + mem.read.run(-count))
  }

  /**
   * Builds a thread, not yet started. The thread runs `count` transactions; each
   * reads `mem` and writes back the value read plus `delta`, using id
   * `idBase + cnt`. It folds the results with `pick`, starting from 42, and
   * prints `label` followed by the fold.
   */
  private def worker(mem: Atomic[Int], count: Int, idBase: Long, delta: Int,
                     label: String, pick: (Int, Int) => Int): Thread =
    new Thread {
      override def run(): Unit = {
        @annotation.tailrec
        def go(cnt: Int, seen: Int): Int =
          if (cnt > 0) {
            // Compose monadically: read, then write a value based on the read.
            val r = (for {
              x <- mem.read
              y <- mem.map { _ => x + delta }
            } yield y).run(idBase + cnt)
            go(cnt - 1, pick(seen, r))
          } else seen
        println(label + go(count, 42))
      }
    }
}
Test.test(count = 1000000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment