Created
May 24, 2019 16:46
-
-
Save ejconlon/cafad644948388e1807e7b68508be0bb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.annotation.tailrec | |
import scala.collection.immutable.SortedMap | |
import scala.collection.mutable | |
import scala.language.higherKinds | |
import scala.util.Random | |
import scala.util.control.TailCalls.{TailRec, done} | |
/**
 * One step of a Kleisli-style pipeline from `A` to `B` over the effect `F`.
 *
 * A step is either a plain function ([[KArrow.KMap]]) or a function that
 * returns its result inside `F` ([[KArrow.KFlatMap]]).
 */
sealed trait KArrow[F[_], A, B] extends Product with Serializable

object KArrow {

  /** Pure step: wraps a function `A => B`. */
  final case class KMap[F[_], A, B](f: A => B) extends KArrow[F, A, B]

  /** Effectful step: wraps a function `A => F[B]`. */
  final case class KFlatMap[F[_], A, B](f: A => F[B]) extends KArrow[F, A, B]
}
/**
 * Interpreter for a single [[KArrow]] step.
 *
 * Given a value held in the carrier `G` and an arrow expressed over `F`,
 * an implementation yields the next carrier value as a trampolined
 * (`TailRec`) computation.
 *
 * @tparam F effect the arrows are written against
 * @tparam G carrier the interpreter threads through the steps
 */
trait TransBind[F[_], G[_]] {

  /**
   * Applies `arr` to `what`.
   *
   * @param what current carrier value
   * @param arr  step to apply
   * @return the carrier value after the step, as a `TailRec`
   */
  def apply[A, B](what: G[A], arr: KArrow[F, A, B]): TailRec[G[B]]
}
/** Factory and private interpreter loop for [[TAL]]. */
object TAL {

  /** Builds a list holding the single step `arr`. */
  def apply[F[_], A, B](arr: KArrow[F, A, B]): TAL[F, A, B] =
    new TAL[F, A, B](IndexedSeq(arr))

  /**
   * Feeds `what` through every remaining step of `arrIt`, in order.
   *
   * The steps are stored with wildcard type arguments, so each one is cast
   * back to the shape the interpreter expects; when the iterator is
   * exhausted the carrier value is cast to the result type. The recursive
   * call happens inside `flatMap`, so it is deferred through `TailRec`.
   *
   * @param what      carrier value to push through the next step
   * @param transBind interpreter applied to each step
   * @param arrIt     remaining steps; advanced (mutated) by this method
   */
  private def consume[F[_], G[_], A, B](
    what: G[A],
    transBind: TransBind[F, G],
    arrIt: Iterator[KArrow[F, _, _]]
  ): TailRec[G[B]] =
    if (!arrIt.hasNext) {
      // No steps left: the value is already the final result.
      done(what.asInstanceOf[G[B]])
    } else {
      val step = arrIt.next().asInstanceOf[KArrow[F, A, B]]
      transBind(what, step).flatMap { advanced =>
        consume(advanced, transBind, arrIt)
      }
    }
}
/**
 * A sequence of [[KArrow]] steps that composes end to end from `A` to `B`.
 *
 * This severely cheats at being a type-aligned list: the intermediate types
 * are not tracked, the steps are stored with wildcard type arguments, and
 * the interpreter casts them back, which only works thanks to erasure.
 *
 * @param arrs the steps, in execution order
 */
final class TAL[F[_], A, B](val arrs: IndexedSeq[KArrow[F, _, _]]) extends AnyVal {

  /** Returns a list that runs `arr` before all existing steps. */
  def +:[Z](arr: KArrow[F, Z, A]): TAL[F, Z, B] =
    new TAL[F, Z, B](arr +: arrs)

  /** Returns a list that runs `arr` after all existing steps. */
  def :+[C](arr: KArrow[F, B, C]): TAL[F, A, C] =
    new TAL[F, A, C](arrs :+ arr)

  /**
   * Interprets every step in order, starting from `what`.
   *
   * @param what      initial carrier value
   * @param transBind interpreter applied to each step
   * @return the final carrier value as a `TailRec`
   */
  def run[G[_]](what: G[A], transBind: TransBind[F, G]): TailRec[G[B]] =
    TAL.consume[F, G, A, B](what, transBind, arrs.iterator)
}
/**
 * A probability weight, held as a raw `Double`; no range check is performed.
 *
 * TODO: make this a log-probability (or a rational, etc.) for accuracy.
 *
 * @param value the underlying weight
 */
final class Prob(val value: Double) extends AnyVal
/**
 * A probability distribution over values of type `A`.
 *
 * `map` and `flatMap` do not evaluate anything: they record the function as
 * a [[KArrow]] step, which is interpreted later by `sample` or `support`.
 */
sealed trait Dist[A] extends Product with Serializable {
  import Dist._
  import KArrow._

  /** Draws one value using `random`, as a trampolined computation. */
  protected def sampleTailRec(random: Random): TailRec[A]

  /** Computes the set of possible values, as a trampolined computation. */
  protected def supportTailRec: TailRec[Set[A]]

  /** Draws one value using `random`. */
  final def sample(random: Random): A = sampleTailRec(random).result

  /** Returns the set of values this distribution can produce. */
  final def support: Set[A] = supportTailRec.result

  /** Records a pure transformation of each outcome. */
  final def map[B](f: A => B): Dist[B] = appendArr(KMap(f))

  /** Records a dependent distribution chosen from each outcome. */
  final def flatMap[B](f: A => Dist[B]): Dist[B] = appendArr(KFlatMap(f))

  /**
   * Adds `arr` as the last step. An existing `Bind` is extended in place
   * (its step list grows) rather than being wrapped in another `Bind`.
   */
  private[this] def appendArr[B](arr: KArrow[Dist, A, B]): Dist[B] =
    this match {
      case Bind(source, steps) => Bind(source, steps :+ arr)
      case _ => Bind(this, TAL(arr))
    }
}
/** Constructors and private implementations for [[Dist]]. */
object Dist {
  import KArrow._

  /**
   * Thrown when a distribution is built from, or sampled with, no outcomes.
   *
   * This is a singleton, so every throw site shares the same exception
   * instance.
   */
  case object EmptyException extends Exception("empty distribution")

  /** Distribution that always yields `value`. */
  private final case class Pure[A](value: A) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] = done(value)
    override protected def supportTailRec: TailRec[Set[A]] = done(Set(value))
  }

  // Carrier used while sampling: a single concrete value.
  private[this] type Identity[A] = A

  /** Interprets one step on a single sampled value, drawing from `random`. */
  private[this] final class SampleTransBind(random: Random) extends TransBind[Dist, Identity] {
    override def apply[A, B](
      what: Identity[A],
      arr: KArrow[Dist, A, B]
    ): TailRec[Identity[B]] =
      arr match {
        case KMap(f) => done(f(what))
        case KFlatMap(f) =>
          val d = f(what)
          d.sampleTailRec(random)
      }
  }

  /** Interprets one step on a whole set of possible values. */
  private[this] object SupportTransBind extends TransBind[Dist, Set] {

    /**
     * Applies `arr` to every remaining element of `aIt`, accumulating the
     * outputs in `builder`.
     *
     * @param builder accumulator for the output set; mutated in place
     * @param arr     step applied to each element
     * @param aIt     remaining inputs; advanced (mutated) by this method
     */
    private[this] def subApply[A, B](
      builder: mutable.Builder[B, Set[B]],
      arr: KArrow[Dist, A, B],
      aIt: Iterator[A]
    ): TailRec[mutable.Builder[B, Set[B]]] =
      if (aIt.hasNext) {
        val a = aIt.next()
        arr match {
          case KMap(f) =>
            builder += f(a)
            subApply(builder, arr, aIt)
          case KFlatMap(f) =>
            val d = f(a)
            // The dependent distribution contributes its entire support.
            d.supportTailRec.flatMap { bs =>
              builder ++= bs
              subApply(builder, arr, aIt)
            }
        }
      } else {
        done(builder)
      }

    override def apply[A, B](
      what: Set[A],
      arr: KArrow[Dist, A, B]
    ): TailRec[Set[B]] =
      subApply(Set.newBuilder[B], arr, what.iterator).map(_.result())
  }

  /**
   * A source distribution followed by a list of recorded steps.
   *
   * @param context distribution the steps start from
   * @param tal     steps applied, in order, to values of `context`
   */
  private final case class Bind[Z, A](context: Dist[Z], tal: TAL[Dist, Z, A]) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] =
      context.sampleTailRec(random).flatMap { z =>
        tal.run[Identity](z, new SampleTransBind(random))
      }

    override protected def supportTailRec: TailRec[Set[A]] =
      context.supportTailRec.flatMap { zs =>
        if (zs.isEmpty) {
          // Nothing to push through the steps.
          done(Set.empty[A])
        } else {
          tal.run(zs, SupportTransBind)
        }
      }
  }

  /**
   * Walks `elemIt`, accumulating weights, and returns the first element
   * whose running total exceeds `p`.
   *
   * @param p      threshold the running total must exceed
   * @param s      running total of the weights consumed so far
   * @param last   most recently visited element, if any
   * @param elemIt remaining (element, weight) pairs; advanced by this method
   * @return the selected element; `last` when the iterator runs out first;
   *         `None` only when no element was ever visited
   */
  @tailrec
  private[this] def subCategoricalSample[A](
    p: Double,
    s: Double,
    last: Option[A],
    elemIt: Iterator[(A, Prob)]
  ): Option[A] =
    if (elemIt.hasNext) {
      val (elem, prob) = elemIt.next()
      val t = s + prob.value
      val seen = Some(elem)
      if (t > p) {
        seen
      } else {
        subCategoricalSample(p, t, seen, elemIt)
      }
    } else {
      last
    }

  /**
   * Distribution over the keys of `elems`, weighted by their `Prob` values.
   *
   * The weights do not need to sum to one: the uniform draw is scaled by
   * their total, so each key is selected in proportion to its weight.
   * NOTE(review): weights are assumed to be non-negative and finite; this is
   * not checked.
   */
  private final case class Categorical[A](elems: SortedMap[A, Prob]) extends Dist[A] {
    // Summed in iteration order, the same order subCategoricalSample
    // accumulates in, so the final running total equals this value.
    private[this] val totalMass: Double = elems.valuesIterator.map(_.value).sum

    override protected def sampleTailRec(random: Random): TailRec[A] = {
      // nextDouble() lies in [0, 1); scaling by the total weight makes the
      // threshold range over [0, totalMass) instead of assuming a total of 1.
      val threshold = random.nextDouble() * totalMass
      subCategoricalSample(threshold, 0, None, elems.iterator) match {
        case None => throw EmptyException
        case Some(v) => done(v)
      }
    }

    // Every key is reported, including keys whose weight is zero.
    override protected def supportTailRec: TailRec[Set[A]] =
      done(elems.keySet)
  }

  /** Distribution that picks each position of `elems` with equal chance. */
  private final case class Uniform[A](elems: IndexedSeq[A]) extends Dist[A] {
    override protected def sampleTailRec(random: Random): TailRec[A] =
      done(elems(random.nextInt(elems.size)))

    override protected def supportTailRec: TailRec[Set[A]] =
      done(elems.toSet)
  }

  /** Distribution that always yields `value`. */
  def apply[A](value: A): Dist[A] =
    Pure(value)

  // TODO move to IndexedSeq (weights are normalised when sampling)
  /**
   * Weighted distribution over the keys of `elems`.
   *
   * @throws EmptyException if `elems` is empty
   */
  def categorical[A](elems: SortedMap[A, Prob]): Dist[A] =
    if (elems.isEmpty) {
      throw EmptyException
    } else {
      Categorical(elems)
    }

  /**
   * Equal-chance distribution over the positions of `elems`.
   *
   * @throws EmptyException if `elems` is empty
   */
  def uniform[A](elems: IndexedSeq[A]): Dist[A] =
    if (elems.isEmpty) {
      throw EmptyException
    } else {
      Uniform(elems)
    }
}
/** Small demo: builds a two-stage distribution, then prints a sample and its support. */
object DistMain {
  def main(args: Array[String]): Unit = {
    val seed = 42
    val rng = new Random(seed)

    // First pick a sign, then pick a letter whose choices depend on the sign.
    val letters = Dist.uniform(IndexedSeq(-1, 1)).flatMap { sign =>
      val choices =
        if (sign > 0) Dist.uniform(IndexedSeq("a"))
        else Dist.uniform(IndexedSeq("b", "c"))
      choices.map(letter => letter)
    }

    println(letters.sample(rng))
    println(letters.support)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment