Last active: August 29, 2015 13:57
-
-
Save runarorama/9422453 to your computer and use it in GitHub Desktop.
Suggestion for the new representation of `IO` in Scalaz.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scalaz._
import \/._
import Free._
import scalaz.concurrent.Strategy
import scalaz.syntax.monad._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.CountDownLatch
import scala.annotation.tailrec
import scala.util.control.NonFatal
/**
 * An experimental encoding of `IO` as a free monad over the instruction set
 * `OI`. Programs are assembled with the combinators on `IO` and its companion
 * and are only executed by the (unsafe) interpreter `IO#run`.
 */
object Experiment {

  /**
   * A single instruction of an `IO` program. The type parameter is the
   * "rest of the program" the interpreter continues with once the
   * instruction has been carried out.
   */
  sealed trait OI[A] {
    def map[B](f: A => B): OI[B]
  }

  /** Evaluate the thunk `a` on the interpreting thread. */
  final case class Step[A](a: () => A) extends OI[A] {
    def map[B](f: A => B): OI[B] = Step(() => f(a()))
  }

  /** Abort with `e`. There is no continuation, so mapping leaves it unchanged. */
  final case class Error[A](e: Throwable) extends OI[A] {
    def map[B](f: A => B): OI[B] = Error(e)
  }

  /** Continue with the parallelization strategy supplied by the interpreter. */
  final case class Strat[A](k: Strategy => A) extends OI[A] {
    def map[B](f: A => B): OI[B] = Strat(k andThen f)
  }

  /**
   * Start `a` and every element of `as` using strategy `s`. `k` receives the
   * result of the first one to finish, together with `IO`s that yield the
   * results of the remaining ones (in their original order).
   */
  final case class Choose[Z, A](s: Strategy,
                                a: IO[Z],
                                as: Seq[IO[Z]],
                                k: (Z, Seq[IO[Z]]) => A) extends OI[A] {
    def map[B](f: A => B): OI[B] =
      Choose(s, a, as, (z: Z, zs: Seq[IO[Z]]) => f(k(z, zs)))
  }

  object OI {
    implicit val oiFunctor: Functor[OI] = new Functor[OI] {
      def map[A, B](oi: OI[A])(f: A => B): OI[B] = oi map f
    }
  }

  // The representation of IO
  private type Rep[A] = Free[OI, A]

  /** A description of a (possibly parallel) computation producing an `A`. */
  class IO[A](private val rep: Rep[A]) {
    import IO._

    /**
     * Sequence `f` after this program. If `f` throws, the exception surfaces
     * when the program is stepped (and is captured by `attempt`).
     */
    def flatMap[B](f: A => IO[B]): IO[B] =
      new IO(rep flatMap (a => f(a).rep))

    /** Transform the result; an exception thrown by `f` becomes a failure. */
    def map[B](f: A => B): IO[B] =
      flatMap(a => Try(f(a)))

    /**
     * Run this and another task in parallel, combining their results with
     * the given function. If either side fails, the combined task fails with
     * that side's error.
     */
    def and[B, C](b: IO[B])(f: (A, B) => C): IO[C] = {
      type Outcome = Throwable \/ (A \/ B)
      // Each side runs under `attempt`, so a failing branch still completes
      // normally on its worker thread. Its error is re-raised below as an
      // `Error` instruction, which keeps it visible to `attempt`/`handle` on
      // the combined task.
      val mine: IO[Outcome] = this.map[A \/ B](-\/(_)).attempt
      val theirs: IO[Outcome] = b.map[A \/ B](\/-(_)).attempt
      def reraise[X](r: Throwable \/ X): IO[X] = r.fold(e => fail[X](e), x => now(x))
      new IO(Suspend[OI, C](Strat(s =>
        Suspend[OI, C](Choose[Outcome, Rep[C]](s, mine, Vector(theirs), {
          case (-\/(e), _) =>
            fail[C](e).rep
          case (\/-(-\/(a)), Seq(other)) =>
            // `this` finished first, so `other` yields the right-hand outcome.
            other.flatMap(r => reraise(r)).map(e => e.fold(_ => ???, f(a, _))).rep
          case (\/-(\/-(bv)), Seq(other)) =>
            // `b` finished first, so `other` yields the left-hand outcome.
            other.flatMap(r => reraise(r)).map(e => e.fold(f(_, bv), _ => ???)).rep
        })))))
    }

    /**
     * Capture a failure of this program as a `-\/` value instead of
     * propagating it. A failure here is an `Error` instruction, or a non-fatal
     * exception thrown while resuming the program or evaluating a `Step` or
     * `Strat` continuation.
     *
     * Failures inside `and` are captured the same way. A hand-built `Choose`
     * whose first finisher fails is still raised directly by `run`.
     */
    def attempt: IO[Throwable \/ A] = {
      def guard(next: => Rep[A]): Rep[Throwable \/ A] =
        try go(next) catch { case NonFatal(e) => Return[OI, Throwable \/ A](-\/(e)) }

      def go(r: Rep[A]): Rep[Throwable \/ A] = {
        // An exception thrown while resuming is treated like an `Error` node.
        val resumed: OI[Rep[A]] \/ A =
          try r.resume catch { case NonFatal(e) => -\/(Error[Rep[A]](e)) }
        resumed match {
          case -\/(Error(e)) => Return[OI, Throwable \/ A](-\/(e))
          case -\/(Step(thunk)) => Suspend[OI, Throwable \/ A](Step(() => guard(thunk())))
          case -\/(Strat(k)) => Suspend[OI, Throwable \/ A](Strat(s => guard(k(s))))
          case -\/(other) => Suspend[OI, Throwable \/ A](other map go)
          case \/-(a) => Return[OI, Throwable \/ A](\/-(a))
        }
      }

      new IO(go(rep))
    }

    /** Run `f` after this program, passing the failure (if any); then re-raise or return. */
    def onFinish(f: Option[Throwable] => IO[Unit]): IO[A] =
      attempt.flatMap {
        case -\/(e) => f(Some(e)) *> fail(e)
        case \/-(r) => f(None) *> now(r)
      }

    /** Recover from failures matched by `f` with a pure value. */
    def handle(f: PartialFunction[Throwable, A]): IO[A] =
      handleWith(f andThen now)

    /** Recover from failures matched by `f` with another program; others are re-raised. */
    def handleWith(f: PartialFunction[Throwable, IO[A]]): IO[A] =
      attempt flatMap {
        case -\/(e) => f.lift(e) getOrElse fail(e)
        case \/-(a) => now(a)
      }

    /** Fall back to `t` if this program fails for any reason. */
    def or(t: IO[A]): IO[A] =
      attempt flatMap {
        case -\/(_) => t
        case \/-(a) => now(a)
      }

    /** UNSAFE! Run this program, returning a non-fatal failure as `-\/`. */
    def attemptRun(s: Strategy = Strategy.DefaultStrategy): Throwable \/ A =
      try \/-(run(s)) catch { case NonFatal(t) => -\/(t) }

    // Disabled draft of a callback-based runner, kept for reference. It refers
    // to an `Async` instruction that `OI` does not define, so it cannot compile.
    /** UNSAFE! */
    /*final def listen(cb: (Throwable \/ A) => Trampoline[Unit])
        (S: Strategy = Strategy.DefaultStrategy): Trampoline[Unit] = {
      def go(r: Rep[A]): Trampoline[Unit] =
        stepRep(r).resume match {
          case \/-(a) => cb(right(a))
          case -\/(Error(e)) => cb(left(e))
          case -\/(Step(s)) => go(s())
          case -\/(Async(k)) => k(r => go(r))
          case -\/(Strat(k)) => go(k(S))
        }
      go(rep)
    }*/

    import java.util.concurrent.atomic.AtomicBoolean

    /**
     * UNSAFE! Eagerly evaluate leading `Step` instructions until `cancel` is
     * set or a non-`Step` instruction (or a result) is reached.
     */
    def stepInterruptibly(cancel: AtomicBoolean): IO[A] = {
      @tailrec def go(r: Rep[A]): Rep[A] =
        if (cancel.get) r
        else r.resume match {
          case -\/(Step(thunk)) => go(thunk())
          case _ => r
        }
      new IO(go(rep))
    }

    /** UNSAFE! Eagerly evaluate leading `Step` instructions. */
    def step: IO[A] = new IO(stepRep(rep))

    /** UNSAFE! Interpret this program, blocking the calling thread until it completes. */
    def run(S: Strategy = Strategy.DefaultStrategy): A = {
      // A tail-recursive loop rather than `foldMap` into `Id`. With `Id`,
      // `foldMap` nests one call per suspension, so long programs would
      // overflow the stack.
      @tailrec def loop(r: Rep[A]): A = r.resume match {
        case -\/(instruction) => loop(interpret(instruction, S))
        case \/-(a) => a
      }
      loop(rep)
    }
  }

  object IO {
    /** Utility function - evaluate `a`, returning a non-fatal exception as a failure. */
    def Try[A](a: => A): IO[A] =
      try now(a) catch { case NonFatal(e) => fail(e) }

    /** An `IO` that always fails with the given error. */
    def fail[A](e: Throwable): IO[A] =
      new IO(Suspend[OI, A](Error(e)))

    /** An `IO` that always succeeds with the given value. */
    def now[A](a: A): IO[A] =
      new IO(Return[OI, A](a))

    /** An `IO` that evaluates the given expression each time it is run. */
    def delay[A](a: => A): IO[A] =
      new IO(Suspend[OI, A](Step(() => Return[OI, A](a))))

    /**
     * An `IO` that suspends evaluation of the given `IO`.
     * A trampolining primitive, helpful for recursive definitions.
     */
    def suspend[A](a: => IO[A]): IO[A] =
      new IO(Suspend[OI, A](Step(() => Try(a).join.rep)))

    /**
     * Evaluate the given expression asynchronously in a new logical thread.
     */
    def apply[A](a: => A): IO[A] =
      fork(delay(a))

    /**
     * Explicitly fork the given `IO` as a new logical thread.
     */
    def fork[A](a: => IO[A]): IO[A] =
      new IO(Suspend[OI, A](Strat(_ => a.and(now(()))((r, _) => r).rep)))

    /**
     * Turn a callback-accepting function into an `IO`. Helpful when working with
     * APIs that expect explicit registering of callbacks.
     *
     * The interpreting thread blocks until the callback fires. Only the first
     * invocation of the callback is used; later invocations are ignored.
     */
    def async[A](register: ((Throwable \/ A) => Unit) => Unit): IO[A] =
      suspend {
        val result = new AtomicReference[Rep[A]]()
        val latch = new CountDownLatch(1)
        def complete(v: Rep[A]): Unit =
          if (result.compareAndSet(null, v)) latch.countDown()
        register {
          case -\/(e) => complete(Suspend[OI, A](Error(e)))
          case \/-(a) => complete(Return[OI, A](a))
        }
        latch.await()
        new IO(result.get)
      }

    /**
     * Locally change the parallelization strategy for the given `IO` by
     * rewriting its `Strat` instructions to ignore the interpreter's strategy.
     */
    def withStrategy[A](s: Strategy)(io: IO[A]): IO[A] =
      new IO(io.rep.mapSuspension(new (OI ~> OI) {
        def apply[T](oi: OI[T]): OI[T] = oi match {
          case Strat(k) => Strat(_ => k(s))
          case other => other
        }
      }))

    implicit val ioMonad: Monad[IO] = new Monad[IO] {
      def bind[A, B](io: IO[A])(f: A => IO[B]): IO[B] = io flatMap f
      def point[A](a: => A): IO[A] = delay(a)
    }

    /** Evaluate leading `Step` instructions of `r`, stopping at anything else. */
    @tailrec
    def stepRep[A](r: Rep[A]): Rep[A] = r.resume match {
      case -\/(Step(thunk)) => stepRep(thunk())
      case _ => r
    }

    /** Carry out one instruction under strategy `S`, producing the continuation. */
    private def interpret[T](oi: OI[T], S: Strategy): T = oi match {
      case c @ Choose(_, _, _, _) => runChoose(c)
      case Step(thunk) => thunk()
      case Error(e) => throw e
      case Strat(k) => k(S)
    }

    /**
     * Start every branch of `c` with `c.s`, wait for the first to finish and
     * hand its result (re-throwing if it failed) to `c.k`. The other branches
     * are passed on as `IO`s that wait for their results; a failed one becomes
     * an `Error` instruction.
     */
    private def runChoose[Z, T](c: Choose[Z, T]): T = {
      val latch = new CountDownLatch(1)
      val winner = new AtomicInteger(-1)
      val outcomes: Vector[() => Throwable \/ Z] =
        (c.a +: c.as).toVector.zipWithIndex.map { case (branch, i) =>
          c.s {
            // Every Throwable is caught so that the latch is always released.
            // Otherwise a throwing branch could leave the caller blocked
            // forever. The error is re-thrown on the interpreting thread
            // below, not swallowed.
            val outcome: Throwable \/ Z =
              try \/-(branch.run(c.s))
              catch { case t: Throwable => -\/(t) }
            if (winner.compareAndSet(-1, i)) latch.countDown()
            outcome
          }
        }
      latch.await()
      val first = winner.get
      val firstResult: Z = outcomes(first)() match {
        case -\/(e) => throw e
        case \/-(z) => z
      }
      val others: Seq[IO[Z]] = outcomes.zipWithIndex.collect {
        case (pending, i) if i != first =>
          new IO[Z](Suspend[OI, Z](Step[Rep[Z]](() => pending() match {
            case -\/(e) => Suspend[OI, Z](Error[Rep[Z]](e))
            case \/-(z) => Return[OI, Z](z)
          })))
      }
      c.k(firstResult, others)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.