Last active
November 9, 2019 16:54
-
-
Save kalexmills/de5aba23031678c1880b3489190e6990 to your computer and use it in GitHub Desktop.
Efficient draws from probability distributions in Scala -- using cats
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cats.data.Reader | |
import scala.reflect.ClassTag | |
/**
 * A Distribution[A] is a cats `Reader` from a [[PFloat]] to an `A`:
 * running it with a probability yields a value of type `A`.
 */
type Distribution[A] = Reader[PFloat, A]
/**
 * PFloat wraps a Float describing a probability. Instances created through
 * [[PFloat.from]] (or the implicit conversions, which delegate to it) always
 * hold a value in the closed interval [0, 1].
 *
 * NOTE(review): the synthetic `apply` performs no range check; prefer `from`
 * when the input is not already known to be a valid probability.
 */
case class PFloat private(value: Float) extends AnyVal

object PFloat {
  import scala.language.implicitConversions

  /**
   * Creates a PFloat from `float`. Values outside the interval [0, 1] are
   * clamped to the nearest endpoint of the interval.
   *
   * @param float the candidate probability
   * @return a PFloat whose value lies in [0, 1]
   * @throws IllegalArgumentException if `float` is NaN, which has no
   *         meaningful position relative to the interval
   */
  def from(float: Float): PFloat = {
    // NaN fails every ordered comparison, so it must be rejected explicitly
    // instead of falling through the range checks below.
    require(!float.isNaN, "a probability must not be NaN")
    if (float < 0f) PFloat(0f)
    else if (float > 1f) PFloat(1f)
    else PFloat(float)
  }

  /** Converts a Double to a PFloat by narrowing to Float and clamping via [[from]]. */
  implicit def fromDouble(d: Double): PFloat = from(d.toFloat)

  /** Converts a Float to a PFloat by clamping via [[from]]. */
  implicit def fromFloat(f: Float): PFloat = from(f)

  /**
   * Draws a PFloat from `Math.random()`.
   *
   * The Double is narrowed to Float explicitly (the constructor takes a Float)
   * and routed through [[from]], so a draw that rounds up to 1.0f is still valid.
   */
  def random(): PFloat = from(Math.random().toFloat) // TODO: think critically regarding loss of precision & implications for uniform randomness
}
/**
 * PrefixSumDistribution builds a Distribution in O(n) time from a sequence of
 * (weight, value) tuples. After building, each sample is answered with a binary
 * search over the normalised prefix sums, i.e. in O(log n) time.
 */
object PrefixSumDistribution {
  /**
   * Builds a distribution that returns each value with probability
   * proportional to its weight.
   *
   * The prefix sums are computed eagerly, so the O(n) cost and any input
   * validation failure occur here rather than on the first sample.
   *
   * @param seq (weight, value) pairs; weights must be non-negative and have a
   *            strictly positive total
   * @tparam A  type of the sampled values
   * @tparam N  fractional weight type
   * @return a Distribution mapping a PFloat to the first value whose
   *         normalised cumulative weight is >= the PFloat's value
   * @throws IllegalArgumentException if `seq` is empty, contains a negative
   *         weight, or its total weight is not strictly positive
   */
  def from[A: ClassTag, N: Fractional: ClassTag](seq: Seq[(N, A)]): Distribution[A] = {
    val FN = implicitly[Fractional[N]]
    require(seq.nonEmpty, "cannot build a distribution from an empty sequence")
    require(
      seq.forall { case (weight, _) => FN.gteq(weight, FN.zero) },
      "weights must be non-negative"
    )

    val values: Array[A] = seq.map(_._2).toArray

    // scanLeft yields the running totals in a single pass; the leading seed
    // (zero) is dropped so that cumulative(i) is the sum of weights 0..i.
    val cumulative: Vector[N] =
      seq.iterator
        .map(_._1)
        .scanLeft(FN.zero)((acc, weight) => FN.plus(acc, weight))
        .drop(1)
        .toVector

    // Taking the total from the scan itself makes the last normalised prefix
    // sum total / total.
    val total: N = cumulative.last
    require(FN.gt(total, FN.zero), "total weight must be strictly positive")

    val prefixSums: Array[Float] =
      cumulative.iterator.map(sum => FN.toFloat(FN.div(sum, total))).toArray

    val lastIndex = values.length - 1

    Reader((pfloat: PFloat) => {
      val found = java.util.Arrays.binarySearch(prefixSums, pfloat.value)
      // A negative result encodes -(insertionPoint) - 1. When several prefix
      // sums are equal, binarySearch may return any of the matching indices.
      val insertion = if (found < 0) -(found + 1) else found
      // Guard against a probe larger than every prefix sum, whose insertion
      // point would fall one past the end of `values`.
      values(math.min(insertion, lastIndex))
    })
  }
}
// usage (weights must have a Fractional instance, so Double literals are used rather than Int):
val d = PrefixSumDistribution.from(List((1.0, 'a'), (6.0, 'b'), (1.0, 'c')))
d.run(PFloat.random()) // samples 'a' with probability 1/8, 'b' with probability 3/4, 'c' with probability 1/8
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.scalatest.{FlatSpecLike, Matchers} | |
/** Checks which value a prefix-sum distribution returns for chosen probabilities. */
class DistributionSpec extends FlatSpecLike with Matchers {

  /** Runs `dist` at every listed probability and compares against the paired expectation. */
  private def expectSamples(dist: Distribution[Int])(cases: (Float, Int)*): Unit =
    cases.foreach { case (probability, expected) =>
      dist.run(PFloat(probability)) shouldBe expected
    }

  "Distribution" should "be creatable from a list" in {
    val halves: Distribution[Int] = PrefixSumDistribution.from(List((0.5, 5), (0.5, 10)))
    expectSamples(halves)(
      0.2f -> 5,
      0.6f -> 10,
      0.5f -> 5,
      0f -> 5,
      1f -> 10
    )
  }

  it should "work on lists of size 1" in {
    val single: Distribution[Int] = PrefixSumDistribution.from(List((0.5, 5)))
    expectSamples(single)(
      0.0f -> 5,
      0.5f -> 5,
      1.0f -> 5
    )
  }

  it should "work on large lists" in {
    // Nine equally weighted entries carrying the values 1 through 9.
    val uniform: Distribution[Int] =
      PrefixSumDistribution.from((1 to 9).map(value => (2f, value)).toList)
    expectSamples(uniform)(
      0.0f -> 1,
      0.1f -> 1,
      0.2f -> 2,
      0.3f -> 3,
      0.4f -> 4,
      0.5f -> 5,
      0.6f -> 6,
      0.7f -> 7,
      0.8f -> 8,
      0.9f -> 9,
      1.0f -> 9
    )
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment