Skip to content

Instantly share code, notes, and snippets.

@kalexmills
Last active November 9, 2019 16:54
Show Gist options
  • Save kalexmills/de5aba23031678c1880b3489190e6990 to your computer and use it in GitHub Desktop.
Save kalexmills/de5aba23031678c1880b3489190e6990 to your computer and use it in GitHub Desktop.
Efficient draws from probability distributions in Scala -- using cats
import cats.data.Reader
import scala.reflect.ClassTag
/**
 * Distribution[A] maps a probability (a PFloat in [0, 1]) to a value of type A.
 *
 * It is a cats `Reader`, so `dist.run(p)` returns the A selected by `p`. Feeding it random
 * PFloats (e.g. `PFloat.random()`) draws samples from the distribution.
 */
type Distribution[A] = Reader[PFloat, A]
/**
 * PFloat is a Float describing a probability.
 *
 * Values built through the companion's `from`, its implicit conversions, or `random` always
 * lie in the closed interval [0, 1].
 *
 * NOTE(review): under Scala 2 defaults the synthetic `apply`/`copy` stay public despite the
 * private constructor, so `PFloat(x)` can bypass clamping; prefer `PFloat.from(x)`.
 */
case class PFloat private(value: Float) extends AnyVal
object PFloat {
  import scala.language.implicitConversions

  /**
   * Creates a PFloat from `float`, clamping values outside [0, 1] (including the
   * infinities) to the nearest endpoint. Negative zero is normalized to 0.
   *
   * @param float the raw probability value
   * @return the clamped PFloat
   * @throws IllegalArgumentException if `float` is NaN, which has no meaningful clamp
   */
  def from(float: Float): PFloat = {
    require(!float.isNaN, "probability must not be NaN")
    // max(0f, -0f) is +0f, so the clamp also normalizes negative zero.
    PFloat(math.max(0f, math.min(1f, float)))
  }

  /** Converts a Double probability to a PFloat (narrowing to Float, then clamping). */
  implicit def fromDouble(d: Double): PFloat = from(d.toFloat)

  /** Converts a Float probability to a PFloat, clamping to [0, 1]. */
  implicit def fromFloat(f: Float): PFloat = from(f)

  /**
   * Draws a PFloat uniformly from [0, 1).
   *
   * ThreadLocalRandom.nextFloat samples the Float range directly. Narrowing Math.random()'s
   * Double instead could round values just below 1 up to exactly 1.0f.
   */
  def random(): PFloat = PFloat(java.util.concurrent.ThreadLocalRandom.current().nextFloat())
}
/**
 * PrefixSumDistribution builds a Distribution in O(n) time from a sequence of
 * (weight, value) tuples. After building, each sample takes O(log n) time.
 *
 * Weights are relative: they are normalized by their total, so they need not sum to 1.
 * Each value owns a sub-interval of [0, 1] whose width is proportional to its weight.
 * A query p selects the first value whose cumulative probability is >= p, so a boundary
 * point belongs to the lower interval.
 */
object PrefixSumDistribution {

  /**
   * Builds a distribution over the values of `seq`, weighted by the paired numbers.
   *
   * @param seq (weight, value) pairs. Weights must be finite and non-negative, and at least
   *            one must be positive. Entries with zero weight are never sampled.
   * @tparam A  the sampled value type
   * @tparam N  any Numeric weight type (Int, Long, Float, Double, BigDecimal, ...)
   * @return a Distribution mapping a PFloat p to the value whose interval contains p
   * @throws IllegalArgumentException if a weight is negative, NaN or infinite, if no weight
   *         is positive, or if the total of the weights overflows a Double
   */
  def from[A: ClassTag, N: Numeric](seq: Seq[(N, A)]): Distribution[A] = {
    val num = implicitly[Numeric[N]]
    // Work in Double: avoids Int/Long overflow while summing and keeps bin edges precise.
    val weighted: Seq[(Double, A)] = seq.map { case (w, a) => (num.toDouble(w), a) }
    weighted.foreach { case (w, _) =>
      // `w >= 0` is false for NaN, so NaN is rejected too.
      require(w >= 0 && w < Double.PositiveInfinity, s"weights must be finite and non-negative, got $w")
    }
    // Zero-weight entries would create empty bins that a binary search could still land on.
    val positive = weighted.filter(_._1 > 0)
    require(positive.nonEmpty, "at least one weight must be positive")

    val values: Array[A] = positive.map(_._2).toArray
    // scanLeft keeps construction O(n); `tail` drops the 0.0 seed.
    val runningTotals: Array[Double] = positive.map(_._1).scanLeft(0.0)(_ + _).tail.toArray
    val total = runningTotals.last
    require(total < Double.PositiveInfinity, "sum of weights overflows Double")
    // total / total == 1.0 exactly, so the last edge is exactly 1.0.
    val cumulative: Array[Double] = runningTotals.map(_ / total)

    Reader((pfloat: PFloat) => {
      val idx = lowerBound(cumulative, pfloat.value.toDouble)
      // Guard against a PFloat above 1 built through the public case-class apply.
      values(math.min(idx, values.length - 1))
    })
  }

  /** Returns the first index i with sorted(i) >= key, or sorted.length if there is none. */
  private def lowerBound(sorted: Array[Double], key: Double): Int = {
    var lo = 0
    var hi = sorted.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (sorted(mid) < key) lo = mid + 1 else hi = mid
    }
    lo
  }
}
// Usage: weights are relative; they are normalized by their total.
val d = PrefixSumDistribution.from(List((1.0, 'a'), (6.0, 'b'), (1.0, 'c')))
d.run(PFloat.random()) // samples 'a' with probability 1/8, 'b' with probability 3/4, 'c' with probability 1/8
import org.scalatest.{FlatSpecLike, Matchers}
/** Behavioral checks for distributions built by PrefixSumDistribution. */
class DistributionSpec extends FlatSpecLike with Matchers {

  /** Runs `dist` at each probability in `cases`, in order, and checks the sampled value. */
  private def expectSamples[A](dist: Distribution[A])(cases: (Float, A)*): Unit =
    cases.foreach { case (probability, expected) =>
      dist.run(PFloat.from(probability)) should be (expected)
    }

  behavior of "Distribution"

  it should "be creatable from a list" in {
    val halves: Distribution[Int] = PrefixSumDistribution.from(List((0.5, 5), (0.5, 10)))
    expectSamples(halves)(
      0.2f -> 5,
      0.6f -> 10,
      0.5f -> 5, // a boundary point belongs to the lower bin
      0f -> 5,
      1f -> 10
    )
  }

  it should "work on lists of size 1" in {
    val single: Distribution[Int] = PrefixSumDistribution.from(List((0.5, 5)))
    expectSamples(single)(0.0f -> 5, 0.5f -> 5, 1.0f -> 5)
  }

  it should "work on large lists" in {
    val nine: Distribution[Int] = PrefixSumDistribution.from((1 to 9).toList.map(i => (2f, i)))
    expectSamples(nine)(
      0.0f -> 1,
      0.1f -> 1,
      0.2f -> 2,
      0.3f -> 3,
      0.4f -> 4,
      0.5f -> 5,
      0.6f -> 6,
      0.7f -> 7,
      0.8f -> 8,
      0.9f -> 9,
      1.0f -> 9
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment