Skip to content

Instantly share code, notes, and snippets.

@davidpeklak
Last active August 29, 2015 14:11
Show Gist options
  • Save davidpeklak/4bf54412c92a8f0f0ed3 to your computer and use it in GitHub Desktop.
Save davidpeklak/4bf54412c92a8f0f0ed3 to your computer and use it in GitHub Desktop.
TransMem
import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scalaz.{\/-, -\/, \/}
import scalaz.{Nondeterminism, StateT}
import scalaz.concurrent.Task
object TM {

  /** A snapshot of the transactional value, taken when a transaction starts. */
  case class Baseline[A](baseline: A)

  /**
   * A proposed change: `update` was computed starting from `baseline`.
   * Other transactions may have committed in between, so the change is reconciled
   * with the then-current value by a [[Resolve]] function.
   */
  case class Transaction[A](baseline: Baseline[A], update: A)

  /**
   * Resolves a transaction against the current value.
   *
   * The first argument is the value currently stored. The second is a transaction whose
   * `update` was computed from its `baseline`, possibly while other changes were committed
   * in parallel. The result is either the new value to store (right) or a failure (left).
   *
   * `TransMem.set` may call this more than once for a single transaction (it retries when
   * another thread commits first), so it should not perform side effects.
   *
   * @tparam E the type representing a failure to resolve
   * @tparam A the type of the transactional data
   */
  type Resolve[E, A] = (A, Transaction[A]) => E \/ A

  /** A shared cell whose updates are merged through a [[Resolve]] function instead of overwritten. */
  sealed trait TransMem[E, A] {

    /** Takes a snapshot of the current value, to be used as a transaction's baseline. */
    def get: Baseline[A]

    /**
     * Commits `transaction`, resolving it against the current value.
     *
     * @return the value stored by this commit, or the resolver's failure
     *         (in which case nothing is stored)
     */
    def set(transaction: Transaction[A]): E \/ A
  }

  object TransMem {

    /**
     * Creates a TransMem backed by an `AtomicReference` (compare-and-set, no locks).
     *
     * @param init    the initial value
     * @param resolve merges a transaction into the current value
     */
    def apply[E, A](init: A, resolve: Resolve[E, A]): TransMem[E, A] = new TransMem[E, A] {
      private val ref: AtomicReference[A] = new AtomicReference[A](init)

      def get: Baseline[A] = Baseline(ref.get)

      def set(transaction: Transaction[A]): E \/ A = {
        // Optimistic retry: resolve against the value just read and store the result only if
        // the reference still holds exactly that value (compareAndSet compares references).
        // If another thread committed in between, resolve again against the new value.
        @tailrec
        def attempt(): E \/ A = {
          val current = ref.get
          resolve(current, transaction) match {
            case failure @ -\/(_) => failure
            case success @ \/-(resolved) =>
              if (ref.compareAndSet(current, resolved)) success
              else attempt()
          }
        }
        attempt()
      }
    }
  }
}
object Usage {
  import TM._

  /** A trivial piece of transactional state. */
  case class Counter(count: Int)

  /**
   * Resolves a counter transaction by applying its delta (update minus baseline) to the
   * current count. Concurrent increments and decrements therefore add up instead of
   * overwriting each other. This resolver never fails.
   */
  def counterResolve(current: Counter, trans: Transaction[Counter]): Unit \/ Counter = {
    val delta = trans.update.count - trans.baseline.baseline.count
    \/-(Counter(current.count + delta))
  }

  /**
   * Reads a baseline, waits `sleep` ms so that other threads can overlap, and then commits
   * `baseline + i`. The result of `set` is discarded because `counterResolve` never fails.
   */
  def changeCounter(ctm: TransMem[Unit, Counter], i: Int, sleep: Long): Unit = {
    val baseline = ctm.get
    Thread.sleep(sleep)
    val update = Counter(baseline.baseline.count + i)
    ctm.set(Transaction(baseline, update))
  }

  val ctm: TransMem[Unit, Counter] = TransMem(Counter(0), counterResolve)

  /** Pool of 3 threads that applies the counter updates; `run()` shuts it down. */
  val e = Executors.newFixedThreadPool(3)

  /** Submits `changeCounter(ctm, i, sleep)` to the pool `e`. */
  def changeCounterAsync(i: Int, sleep: Long): Unit = {
    e.execute(new Runnable {
      def run(): Unit = changeCounter(ctm, i, sleep)
    })
  }

  /**
   * Submits eight overlapping updates whose deltas sum to zero, waits for them to finish,
   * and prints the final counter.
   *
   * This shuts down `e`, so a second call fails when submitting tasks to the pool.
   */
  def run(): Unit = {
    changeCounterAsync(2, 2000)
    changeCounterAsync(-6, 2000)
    changeCounterAsync(4, 2000)
    changeCounterAsync(3, 2000)
    changeCounterAsync(-4, 2000)
    changeCounterAsync(-3, 2000)
    changeCounterAsync(6, 2000)
    changeCounterAsync(-2, 2000)
    // Without shutdown(), awaitTermination always blocks for the whole timeout, and the
    // pool's worker threads are never stopped.
    e.shutdown()
    // The bound is generous: it assumes all 8 tasks run one after another (8 x 2 s), plus 1 s.
    if (!e.awaitTermination(2 * 8 + 1, TimeUnit.SECONDS))
      println("Timed out waiting for counter updates.")
    println(ctm.get)
  }
}
object ST {
  import TM._

  /** A state transformer whose effects run in a scalaz `Task`. */
  type StateTask[S, A] = StateT[Task, S, A]

  /**
   * A Nondeterminism instance for `StateTask[S, ?]` (written with a type lambda that fixes `S`).
   *
   * `mapBoth` runs both computations from the same starting state. Each result state is
   * committed as a transaction to a fresh TransMem, and `resolve` merges the two.
   * The merged state becomes the output state.
   *
   * NOTE(review): `chooseAny` is unimplemented (`???`). Only `mapBoth` and combinators that
   * go through it (such as `both`) are usable.
   */
  def stateTaskNondet[S](resolve: Resolve[Throwable, S]): Nondeterminism[({type l[a] = StateTask[S, a]})#l] =
    new Nondeterminism[({type l[a] = StateTask[S, a]})#l] {
      val M = StateT.stateTMonadState[S, Task]
      val ND = implicitly[Nondeterminism[Task]]

      def point[A](a: => A): StateTask[S, A] = M.point(a)

      def bind[A, B](fa: StateTask[S, A])(f: (A) => StateTask[S, B]): StateTask[S, B] = M.bind(fa)(f)

      def chooseAny[A](head: StateTask[S, A], tail: Seq[StateTask[S, A]]): StateTask[S, (A, Seq[StateTask[S, A]])] = ???

      override def mapBoth[A, B, C](a: StateTask[S, A], b: StateTask[S, B])(f: (A, B) => C): StateTask[S, C] =
        StateT(s => {
          val transS = TransMem[Throwable, S](s, resolve)
          val baseline = transS.get

          // Runs `t` from the shared baseline and commits its result state as a transaction.
          // The task fails if `resolve` rejects that state.
          def wrapTask[X](t: StateTask[S, X]): Task[X] =
            for {
              sx <- t(baseline.baseline)
              (updateS, x) = sx
              _ <- transS.set(Transaction(baseline, updateS)) match {
                case -\/(thr) => Task.fail(thr)
                // The value is already computed; Task.now avoids forking a thread for it.
                case \/-(fs) => Task.now(fs)
              }
            } yield x

          // After both tasks have committed, transS holds the merged state.
          ND.mapBoth(wrapTask(a), wrapTask(b))(f).map(c => (transS.get.baseline, c))
        })
    }
}
object STUsage {
  import TM._
  import ST._

  /**
   * Merges log lists that grow by prepending (newest entry first).
   *
   * Finds the entries the transaction's update added on top of its own baseline, and
   * prepends them to the current list. Entries committed concurrently by other
   * transactions are therefore kept. This resolver never fails.
   */
  def resolve(current: List[String], trans: Transaction[List[String]]): Throwable \/ List[String] = {
    // Reverse to oldest-first, so the baseline is a prefix of the update.
    val revBaseline = trans.baseline.baseline.reverse
    val revUpdate = trans.update.reverse
    // Compare the update with its own baseline, not with `current`: otherwise the
    // update's entries are dropped whenever it does not extend its baseline.
    val sharedCount = (revUpdate zip revBaseline).takeWhile { case (u, b) => u == b }.size
    val added = revUpdate.drop(sharedCount).reverse
    \/-(added ::: current)
  }

  val nondet = stateTaskNondet[List[String]](resolve)

  /** Logs an entry and yields 3. */
  val wt3: StateTask[List[String], Int] = StateT(l => Task({
    ("Produce the 3" :: l, 3)
  }))

  /** After sleeping `sleep` ms in a Task, logs the addition and yields `a + b`. */
  def add(a: Int, b: Int, sleep: Long): StateTask[List[String], Int] = StateT(l => {
    def theWork(): (List[String], Int) = {
      println(s"Waiting to add $a to $b.")
      Thread.sleep(sleep)
      println(s"Adding $a to $b.")
      (s"Adding $a to $b." :: l, a + b)
    }
    Task(theWork())
  })

  /**
   * Produces 3, then runs two additions through `nondet.both`, whose log states are merged
   * by `resolve`. Nothing here runs it: apply it to an initial state (e.g. `work2(Nil)`)
   * and run the resulting Task.
   */
  val work2: StateTask[List[String], Int] = for {
    a <- wt3
    bc <- nondet.both(add(a, 3, 3000), add(a, 6, 100))
    (b, c) = bc
  } yield {
    b + c
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment