-
-
Save gigamonkey/6479430 to your computer and use it in GitHub Desktop.
What I came up with while mucking around trying to understand your code.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Run this with scala <filename> | |
import java.util.concurrent.atomic.AtomicLong | |
// Dummy transaction-id provider: a shared counter, so every call to
// `nextTxId` hands out a fresh id (the first one is 1).
val txIds: AtomicLong = new AtomicLong(0L)

/** Returns the next transaction id: the counter's value after incrementing it. */
def nextTxId: Long = txIds.incrementAndGet()
/**
 * A two-phase-commit monad: a computation that is first prepared and then
 * either committed or rolled back.
 *
 * @tparam T the value produced when the transaction commits
 */
trait Transaction[+T] {

  /** Transforms the committed value; built on [[flatMap]] and [[Constant]]. */
  def map[U](fn: T => U): Transaction[U] =
    flatMap(t => Constant(fn(t)))

  /** Sequences a transaction that depends on this one's prepared value. */
  def flatMap[U](fn: T => Transaction[U]): Transaction[U] =
    FlatMapped(this, fn)

  /**
   * Attempts the prepare phase under transaction id `id`.
   *
   * @return `Right` holding a [[Prepared]] on success, or `Left` holding a
   *         transaction to retry when preparation could not complete.
   *         In a real system, this should be wrapped in a Future.
   */
  def prepare(id: Long): Either[Transaction[T], Prepared[T]]

  /**
   * Prepares and commits this transaction, retrying until preparation
   * succeeds. The first attempt uses `txId`; every retry draws a fresh id
   * from `nextTxId`.
   *
   * @return the committed value
   */
  final def run(txId: Long): T = {
    @annotation.tailrec
    def attempt(currentId: Long, pending: Transaction[T]): T =
      pending.prepare(currentId) match {
        case Right(ready)  => ready.commit.get
        case Left(retryTx) => attempt(nextTxId, retryTx)
      }

    attempt(txId, this)
  }
}
/**
 * A transaction whose prepare phase succeeded, so it can now be committed
 * or rolled back.
 *
 * In a real system `commit` and `rollback` should return Futures of the
 * types they return here.
 */
trait Prepared[+T] {

  /** The value that would be committed. */
  def value: T

  /** Completes the transaction and wraps the committed value. */
  def commit: Committed[T]

  /** Undoes the preparation and returns a transaction that may be retried. */
  def rollback: Transaction[T]
}
/**
 * A type wrapper used only to mark that a commit succeeded.
 *
 * Declared `final` because case classes are not meant to be extended: a
 * subclass would silently inherit the generated `equals`/`copy`.
 *
 * @param get the committed value
 */
final case class Committed[+T](get: T)
/**
 * The most trivial transaction: it always prepares successfully and commits
 * the fixed value `get`.
 */
case class Constant[+T](get: T) extends Transaction[T] { self =>
  // Nothing is claimed during prepare, so there is nothing to undo: commit
  // simply wraps the value, and rolling back yields this same constant.
  def prepare(id: Long) = {
    val ready = new Prepared[T] {
      def value = self.get
      def commit = Committed(self.get)
      def rollback = self
    }
    Right(ready)
  }
}
/**
 * Implementation of flatMap. It prepares `init`, feeds that prepared value
 * to `fn`, and then prepares the resulting transaction under the same id.
 *
 * Not trampolined: each nested flatMap adds stack frames during prepare,
 * so very large transactions can overflow the stack.
 */
case class FlatMapped[R, T](init: Transaction[R], fn: R => Transaction[T]) extends Transaction[T] {
  def prepare(id: Long) = init.prepare(id) match {
    case Right(firstReady) => prepareRest(id, firstReady)
    case Left(retryInit) =>
      // `init` could not even prepare, so we have not claimed anything yet.
      // Rebuild the chain around the transaction it asked us to retry.
      Left(FlatMapped(retryInit, fn))
  }

  /** Second half of prepare: apply `fn` to the prepared value and prepare the result. */
  private def prepareRest(id: Long, firstReady: Prepared[R]) = {
    val rest = fn(firstReady.value)
    rest.prepare(id) match {
      case Left(_) =>
        // The continuation failed to prepare. Release what `init` claimed and
        // retry the whole chain starting from the rolled-back `init`.
        Left(FlatMapped(firstReady.rollback, fn))
      case Right(restReady) =>
        Right(new Prepared[T] {
          def value = restReady.value
          def commit = {
            // Commit `init` first, then the continuation. A real system would
            // need a Future (or some other monad) to sequence these.
            firstReady.commit
            restReady.commit
          }
          lazy val rollback = {
            // Undo in reverse order: first the continuation, then `init`.
            restReady.rollback
            FlatMapped(firstReady.rollback, fn)
          }
        })
    }
  }
}
import java.util.concurrent.atomic.{AtomicReference => Atom} | |
// State held inside an Atomic:
//   Right(v)               -- `v` is the committed value;
//   Left((txId, old, nu))  -- transaction `txId` has prepared a change from
//                             `old` to `nu` that is not yet committed or
//                             rolled back.
type AtomicState[T] = Either[Tuple3[Long, T, T], T]

/**
 * Transactional wrapper around an AtomicReference.
 *
 * @param value the initial, committed value
 */
class Atomic[T](value: T) {
  private val state: Atom[AtomicState[T]] = new Atom[AtomicState[T]](Right(value))

  // Only two operations are exposed: read and modify.

  /** A transaction that reads the value; it is a modify that changes nothing. */
  def read: Transaction[T] = modify(identity[T])

  /**
   * A transaction that stores `fn` applied to the latest value this
   * transaction can see: the committed value, or this transaction's own
   * pending value.
   */
  def modify(fn: T => T): Transaction[T] = new AtomicAction(state, fn)
}
/**
 * The transaction behind `Atomic.read` and `Atomic.modify`. Prepare claims
 * `atom` for a transaction id. Commit then publishes the pending value, and
 * rollback restores the original value.
 *
 * @param atom the shared cell being guarded
 * @param fn   the update applied to the value this transaction sees
 */
class AtomicAction[T](atom: Atom[AtomicState[T]], fn: T => T) extends Transaction[T] {

  /** A fresh, equivalent action, handed back whenever the caller must retry. */
  private def retry = new AtomicAction(atom, fn)

  def prepare(id: Long) = {
    val current = atom.get()
    current match {
      case Right(committed) =>
        // Unclaimed: stage fn(committed), remembering `committed` for rollback.
        claimForTx(current, id, committed, fn(committed))
      case Left((owner, _, _)) if owner != id =>
        // Another transaction holds the claim; ask the caller to retry.
        Left(retry)
      case Left((_, original, pending)) =>
        // Already claimed by an earlier step of this same transaction. Build
        // on its pending value, but keep the original value for rollback.
        claimForTx(current, id, original, fn(pending))
    }
  }

  /**
   * Tries to move `atom` from `expectedState` to the claim `(id, old, nu)`.
   *
   * `compareAndSet` compares against the exact state object read earlier.
   * The claim therefore fails, and a retry is returned, if any other
   * transaction got in first.
   */
  def claimForTx(expectedState: AtomicState[T], id: Long, old: T, nu: T) = {
    val claimed = atom.compareAndSet(expectedState, Left((id, old, nu)))
    if (!claimed) Left(retry)
    else Right(new Prepared[T] {
      def value = nu

      lazy val commit = {
        // Publish the newest pending value, but only while the claim is still
        // ours. If it was already released, leave the cell alone.
        atom.get() match {
          case mine @ Left((owner, _, pending)) if owner == id =>
            atom.compareAndSet(mine, Right(pending))
          case _ => ()
        }
        Committed(value)
      }

      lazy val rollback = {
        // Restore the pre-transaction value, but only if we still own the claim.
        atom.get() match {
          case mine @ Left((owner, original, _)) if owner == id =>
            atom.compareAndSet(mine, Right(original))
          case _ => ()
        }
        retry
      }
    })
  }
}
/** Concurrency smoke test: two threads race to decrement and increment one cell. */
object Test {

  /**
   * Builds (but does not start) a thread that `count` times runs the
   * transaction "read x, then write step(x)". It folds each committed result
   * into a running extreme with `pick`, starting from `seed`, and finally
   * prints `label` followed by that extreme.
   */
  private def worker(mem: Atomic[Int], count: Int, seed: Int, step: Int => Int,
                     pick: (Int, Int) => Int, label: String): Thread =
    new Thread {
      // Explicit `(): Unit` instead of procedure syntax, which is deprecated
      // in Scala 2.13 and removed in Scala 3.
      override def run(): Unit = {
        @annotation.tailrec
        def go(cnt: Int, m: Int): Int =
          if (cnt > 0) {
            // Compose monadically -- the whole point of this exercise.
            val tx = for {
              x <- mem.read
              y <- mem.modify { _ => step(x) }
            } yield y
            go(cnt - 1, pick(m, tx.run(nextTxId)))
          } else m

        println(label + go(count, seed))
      }
    }

  /**
   * Runs `count` decrementing and `count` incrementing transactions
   * concurrently on a cell that starts at 42. Then prints the final value,
   * which should again be 42.
   */
  def test(count: Int): Unit = {
    val initial = 42
    val mem = new Atomic[Int](initial)
    // One thread only subtracts and tracks the max it sees; the other only
    // adds and tracks the min it sees.
    val sub = worker(mem, count, initial, _ - 1, _ max _, "max seen by subber: ")
    val add = worker(mem, count, initial, _ + 1, _ min _, "min seen by adder: ")
    sub.start()
    add.start()
    sub.join()
    add.join()
    println("42 is always the answer (no race conditions) ==> " + mem.read.run(-count))
  }
}

Test.test(1000000)
/**
 * Retry test: a six-step increment transaction on one cell. The fifth step
 * is wrapped so that its first prepares fail, forcing rollbacks and retries.
 */
object Test2 {
  /** How many more prepares `flakey` wrappers will fail; shared by all wrappers. */
  var flakes = 10

  /**
   * Wraps `tx` so that its prepare returns `Left(this)` (i.e. "retry me")
   * while `flakes` is positive, decrementing `flakes` each time. Once
   * `flakes` reaches zero, it delegates to `tx.prepare`.
   */
  def flakey[T](tx: Transaction[T]): Transaction[T] =
    new Transaction[T] {
      def prepare(id: Long): Either[Transaction[T], Prepared[T]] =
        if (flakes > 0) {
          flakes -= 1
          Left(this)
        } else {
          tx.prepare(id)
        }
    }

  /** Runs six increments from 0 (one of them flaky) and prints the result. */
  def test(): Unit = {
    val mem = new Atomic[Int](0)
    // `tx` is never reassigned, so it is a `val`. Only the last step's value
    // is used; the earlier steps matter only for their effect on `mem`.
    val tx = for {
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- mem.modify { _ + 1 }
      _ <- flakey(mem.modify { _ + 1 })
      f <- mem.modify { _ + 1 }
    } yield f
    val got = tx.run(0)
    println("got: " + got + "; expected: 6")
  }
}

Test2.test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment