Last active: August 29, 2015 14:11
-
-
Save davidpeklak/4bf54412c92a8f0f0ed3 to your computer and use it in GitHub Desktop.
TransMem
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.concurrent.{TimeUnit, Executors} | |
import scala.annotation.tailrec | |
import scalaz.{\/-, -\/, \/} | |
import java.util.concurrent.atomic.AtomicReference | |
object TM {

  /** Snapshot of a transactional value, taken when a transaction starts. */
  case class Baseline[A](baseline: A)

  /**
   * A change computed against a snapshot.
   *
   * @param baseline the snapshot the change was derived from
   * @param update   the value the transaction produced from that snapshot
   */
  case class Transaction[A](baseline: Baseline[A], update: A)

  /**
   * Resolves concurrent changes to an A.
   *
   * The first argument is the value currently held by the memory; the second is a
   * transaction that was computed in parallel against a (possibly older) baseline.
   * The result is the merged value on the right, or a resolution failure on the left.
   *
   * @tparam E represents a failure to resolve
   * @tparam A the type of transactional data
   */
  type Resolve[E, A] = (A, Transaction[A]) => E \/ A

  /** A single cell whose writes are merged into the current value through a [[Resolve]]. */
  sealed trait TransMem[E, A] {

    /** Returns a snapshot of the current value. */
    def get: Baseline[A]

    /** Commits `transaction`, returning the value that was stored or the resolve failure. */
    def set(transaction: Transaction[A]): E \/ A
  }

  object TransMem {

    /** Creates a cell holding `init` whose `set` merges transactions with `resolve`. */
    def apply[E, A](init: A, resolve: Resolve[E, A]): TransMem[E, A] = new TransMem[E, A] {
      val cell: AtomicReference[A] = new AtomicReference[A](init)

      def get: Baseline[A] = Baseline(cell.get)

      def set(transaction: Transaction[A]): E \/ A = {
        // Optimistic loop: merge against the value we observed and publish it only if the
        // cell still holds that exact reference; otherwise merge again against the newer one.
        // Because of these retries, `resolve` can be invoked more than once per `set`.
        @tailrec
        def attempt(): E \/ A = {
          val observed = cell.get
          resolve(observed, transaction) match {
            case \/-(merged) if !cell.compareAndSet(observed, merged) => attempt()
            case outcome => outcome
          }
        }

        attempt()
      }
    }
  }
}
object Usage {
  import TM._

  /** The value kept in the transactional memory. */
  case class Counter(count: Int)

  /**
   * Resolves a counter transaction by replaying its delta on the current value.
   *
   * The delta is `update - baseline`. It is added to `current`, so increments made
   * concurrently by other transactions are preserved. This resolver never fails.
   */
  def counterResolve(current: Counter, trans: Transaction[Counter]): Unit \/ Counter = {
    val diff = trans.update.count - trans.baseline.baseline.count
    \/-(Counter(current.count + diff))
  }

  /**
   * Takes a snapshot of `ctm`, sleeps for `sleep` ms, then commits `snapshot + i`.
   *
   * The sleep widens the window in which other writers change the value, so the
   * commit usually has to be resolved against a newer value. The result of `set`
   * is ignored because `counterResolve` never fails.
   */
  def changeCounter(ctm: TransMem[Unit, Counter], i: Int, sleep: Long): Unit = {
    val baseline = ctm.get
    Thread.sleep(sleep)
    val update = Counter(baseline.baseline.count + i)
    ctm.set(Transaction(baseline, update))
    ()
  }

  val ctm: TransMem[Unit, Counter] = TransMem(Counter(0), counterResolve)

  val e = Executors.newFixedThreadPool(3)

  /** Submits `changeCounter(ctm, i, sleep)` to the shared pool `e`. */
  def changeCounterAsync(i: Int, sleep: Long): Unit = {
    e.execute(new Runnable {
      def run(): Unit = changeCounter(ctm, i, sleep)
    })
  }

  /**
   * Runs the demo: submits eight concurrent changes, waits for them, and prints the result.
   *
   * The deltas sum to 0, so a final count of 0 shows that no update was lost.
   * This method shuts the pool down, so it can be run only once.
   */
  def run(): Unit = {
    val deltas = List(2, -6, 4, 3, -4, -3, 6, -2)
    deltas.foreach(delta => changeCounterAsync(delta, 2000))
    // Without shutdown(), awaitTermination always waits out the full timeout, and the
    // pool's non-daemon worker threads stay alive afterwards.
    e.shutdown()
    e.awaitTermination(2 * deltas.size + 1, TimeUnit.SECONDS)
    println(ctm.get)
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object ST {
  import TransMem._

  /** A state transition over `S` whose effects run in a scalaz `Task`. */
  type StateTask[S, A] = StateT[Task, S, A]

  /**
   * Builds a `Nondeterminism` instance for `StateTask[S, _]`.
   *
   * `mapBoth` runs both branches from the same input state and merges the state each
   * branch produces through a `TransMem` that uses `resolve`.
   *
   * @param resolve merges one branch's final state into the state committed so far;
   *                a left result fails the combined Task with that Throwable
   */
  def stateTaskNondet[S](resolve: Resolve[Throwable, S])
      : Nondeterminism[({type l[a] = StateTask[S, a]})#l] =
    new Nondeterminism[({type l[a] = StateTask[S, a]})#l] {
      val M = StateT.stateTMonadState[S, Task]
      val ND = implicitly[Nondeterminism[Task]]

      def point[A](a: => A): StateTask[S, A] = M.point(a)

      def bind[A, B](fa: StateTask[S, A])(f: (A) => StateTask[S, B]): StateTask[S, B] =
        M.bind(fa)(f)

      // NOTE(review): not implemented. Calling it, or any Nondeterminism combinator that
      // relies on it instead of mapBoth, throws scala.NotImplementedError.
      def chooseAny[A](head: StateTask[S, A], tail: Seq[StateTask[S, A]])
          : StateTask[S, (A, Seq[StateTask[S, A]])] = ???

      /**
       * Runs `a` and `b` from the same input state and combines their results with `f`.
       *
       * Both branches are combined via Task's Nondeterminism instance. Each branch's final
       * state is committed into a TransMem seeded with the input state, and the merged
       * state is the resulting state.
       */
      override def mapBoth[A, B, C](a: StateTask[S, A], b: StateTask[S, B])(f: (A, B) => C)
          : StateTask[S, C] =
        StateT(s => {
          val transS = TransMem[Throwable, S](s, resolve)
          val baseline = transS.get

          // Runs one branch from the shared baseline and commits its final state. The
          // value is already computed at that point, so Task.now is used; Task.apply
          // would fork onto a thread pool for nothing.
          def wrapTask[X](t: StateTask[S, X]): Task[X] =
            t(baseline.baseline).flatMap { case (updateS, x) =>
              transS.set(Transaction(baseline, updateS)) match {
                case -\/(thr) => Task.fail(thr)
                case \/-(_) => Task.now(x)
              }
            }

          val ta: Task[A] = wrapTask(a)
          val tb: Task[B] = wrapTask(b)
          ND.mapBoth(ta, tb)(f).map(c => (transS.get.baseline, c))
        })
    }
}
object STUsage {
  import TransMem._
  import ST._

  /**
   * Merges two logs that grow by prepending entries (newest first), as `wt3` and `add` do.
   *
   * `k` is the length of the common oldest-first prefix of `current` and the transaction's
   * baseline. The newest `size - k` entries of `trans.update` are placed on top of `current`.
   * This is equivalent to reversing both lists, appending, and reversing back.
   * This resolver never fails.
   *
   * NOTE(review): assumes `current` and `trans.update` both extend the baseline; if
   * `current` diverged from it, baseline entries from the update would be re-added.
   */
  def resolve(current: List[String], trans: Transaction[List[String]]): Throwable \/ List[String] = {
    val sharedOldest = current.reverse
      .zip(trans.baseline.baseline.reverse)
      .takeWhile { case (c, b) => c == b }
      .size
    \/-(trans.update.dropRight(sharedOldest) ::: current)
  }

  val nondet = stateTaskNondet[List[String]](resolve)

  /** Logs a message and yields 3. */
  val wt3: StateTask[List[String], Int] = StateT(l => Task(("Produce the 3" :: l, 3)))

  /** After sleeping `sleep` ms, logs the addition and yields `a + b`. */
  def add(a: Int, b: Int, sleep: Long): StateTask[List[String], Int] = StateT(l => {
    def theWork(): (List[String], Int) = {
      println(s"Waiting to add $a to $b.")
      Thread.sleep(sleep)
      println(s"Adding $a to $b.")
      (s"Adding $a to $b." :: l, a + b)
    }
    Task(theWork())
  })

  /** Produces 3, then runs two additions via `nondet.both` and sums their results. */
  val work2 = for {
    a <- wt3
    bc <- nondet.both(add(a, 3, 3000), add(a, 6, 100))
    (b, c) = bc
  } yield {
    b + c
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.