Skip to content

Instantly share code, notes, and snippets.

@RobColeman
Last active January 18, 2018 17:21
Show Gist options
  • Save RobColeman/c4c948f6365dc788a09d to your computer and use it in GitHub Desktop.
Save RobColeman/c4c948f6365dc788a09d to your computer and use it in GitHub Desktop.
A fast, serializable Scala implementation of t-digest, thrown together for a Spark project.
import java.nio.ByteBuffer
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.commons.math3.distribution.ExponentialDistribution
import scala.collection._
import scala.collection.generic.CanBuildFrom
import scala.util.Random
/**
 * A t-digest cluster: the running (weighted) mean of the points merged into it and their total
 * weight.
 *
 * Ordering, equality and hashing are all defined on `mean` alone, so two centroids with the same
 * mean but different counts are "equal"; the digest relies on this to find a centroid by mean.
 * All three use `java.lang.Double.compare` semantics so they agree with each other, including
 * for -0.0 and NaN.
 *
 * NOTE: `mean` is mutable. Changing it while the centroid is stored in a sorted set is only safe
 * if its position relative to its neighbours does not change.
 */
case class Centroid(var mean: Double, var count: Long) extends Ordered[Centroid] with Serializable {

  /**
   * Merges `weight` observations at value `x` into this centroid (incremental weighted mean).
   *
   * If the resulting count is zero the mean is left unchanged instead of becoming NaN (0 / 0).
   */
  def update(x: Double, weight: Long): Unit = {
    val newCount = this.count + weight
    if (newCount != 0) {
      this.mean += weight * (x - this.mean) / newCount
    }
    this.count = newCount
  }

  /** Orders centroids by mean. */
  def compare(that: Centroid): Int = java.lang.Double.compare(this.mean, that.mean)

  /** Two centroids are equal iff `compare` says their means are equal (counts are ignored). */
  override def equals(o: Any): Boolean = o match {
    case that: Centroid => java.lang.Double.compare(that.mean, this.mean) == 0
    case _ => false
  }

  /** Hash of the mean; consistent with `equals`. */
  override def hashCode: Int = java.lang.Double.hashCode(this.mean)
}
/**
 * A mergeable, serializable t-digest: a sorted set of `Centroid`s approximating the distribution
 * of the values passed to `update`, supporting `cdf`, `invCDF` and `trimmedMean`.
 *
 * @param delta     compression parameter; a centroid at quantile q may hold at most
 *                  round(4 * n * delta * q * (1 - q)) weight (see `threshold`)
 * @param k         only used by the commented-out automatic compression trigger
 *                  (`size > k / delta`) in `update`
 * @param n         total weight inserted so far
 * @param centroids centroids ordered by mean (via `Centroid.compare`)
 */
case class TDigest(delta: Double = 0.01,
                   k: Int = 25,
                   var n: Long = 0,
                   // TODO: speedups with a better data structure
                   val centroids: mutable.TreeSet[Centroid] = mutable.TreeSet[Centroid]()) extends Serializable {

  /** Number of centroids currently held. */
  def size: Int = centroids.size

  /**
   * Merges two digests into a new one using this digest's `delta` and `k`.
   *
   * Each centroid is copied before insertion, so the merged digest never shares (and later
   * mutates) the inputs' centroids. Each copy's count is added to the new digest's `n`, so the
   * result's `n` equals the total count of the merged centroids.
   */
  def ++(other: TDigest): TDigest = {
    val bothCentroids: Seq[Centroid] = Random.shuffle(other.centroids.toSeq ++ this.centroids.toSeq)
    val newDigest = TDigest(this.delta, this.k)
    bothCentroids.foreach { c => newDigest.addCentroid(c.copy(), increment = true) }
    //newDigest.compress
    newDigest
  }

  /**
   * Inserts `c`, or, if a centroid with the same mean is already stored, adds `c.count` to that
   * stored centroid instead (`c` itself is then not inserted).
   *
   * @param increment when true, also add `c.count` to the total weight `n`
   */
  def addCentroid(c: Centroid, increment: Boolean = false): Unit = {
    if (increment) this.n += c.count
    if (this.centroids.contains(c)) this.updateCentroid(c, c.mean, c.count)
    else this.centroids += c
  }

  /**
   * Fraction of the total weight `n` held by centroids whose mean is <= `centroid.mean`
   * (the centroid itself included).
   */
  def computeCentroidQuantile(centroid: Centroid): Double = {
    // Sum over an iterator: mapping the TreeSet directly would build a *set* of counts and
    // silently collapse centroids that happen to have the same count.
    this.centroids.iterator.filter(c => c.mean <= centroid.mean).map(_.count).sum / this.n.toDouble
  }

  /** Merges `weight` at value `x` into the stored centroid equal to `c` (same mean), if any. */
  def updateCentroid(c: Centroid, x: Double, weight: Long): Unit =
    this.centroids.find(_ == c).foreach(_.update(x, weight))

  /**
   * The nearest centroids around `x`: the last one with mean < x and the first one with
   * mean >= x (so zero, one or two centroids).
   */
  // maybe protected
  def findClosestCentroids(x: Double): mutable.TreeSet[Centroid] = {
    val (below, above) = this.centroids.partition { c => c.mean < x }
    val out: mutable.TreeSet[Centroid] = mutable.TreeSet[Centroid]()
    (below.lastOption ++ above.headOption).foreach { c => out.add(c) }
    out
  }

  /** Maximum weight a centroid at quantile `q` may hold: round(4 * n * delta * q * (1 - q)). */
  def threshold(q: Double): Long = Math.round(4 * this.n * this.delta * q * (1 - q))

  /**
   * Rebuilds the digest by re-inserting every existing centroid, in random order, as a weighted
   * point through `update`.
   */
  // check this
  def compress: Unit = {
    val oldCentroids: Seq[Centroid] = Random.shuffle(this.centroids.toSeq)
    this.centroids.clear()
    // `update` adds each weight back to `n`; reset it so the total is not counted twice.
    this.n = 0
    oldCentroids.foreach { c => this.update(c.mean, c.count) }
  }

  /**
   * Inserts `weight` observations of value `x`.
   *
   * The weight is first offered, in random order, to the (up to two) nearest centroids; each
   * absorbs it only if its count plus the remaining weight stays within `threshold`. Only the
   * weight that was not absorbed becomes a new centroid at `x`.
   */
  def update(x: Double, weight: Long = 1): Unit = {
    this.n += weight
    if (this.centroids.isEmpty) {
      this.addCentroid(Centroid(x, weight))
    } else {
      var remaining: Long = weight
      Random.shuffle(this.findClosestCentroids(x).toSeq).foreach { c =>
        val q: Double = this.computeCentroidQuantile(c)
        val limit: Long = this.threshold(q)
        val deltaW: Long = math.min(limit - c.count, remaining)
        if ((remaining > 0) && ((c.count + remaining) <= limit)) {
          this.updateCentroid(c, x, deltaW)
          remaining -= deltaW
        }
      }
      if (remaining > 0) {
        // Only the leftover weight is new mass; using the full `weight` here would over-count.
        this.addCentroid(Centroid(x, remaining))
      }
      /*
      if (this.size > (this.k / this.delta)) {
        this.compress
      }
      */
    }
  }

  /** Inserts every value in `X`, each with the same `weight`. */
  def batchUpdate(X: Seq[Double], weight: Long = 1): Unit = {
    X.foreach(x => this.update(x, weight))
    //this.compress
  }

  /**
   * Centroids in ascending mean order, each paired with the fraction of `n` held by the
   * centroids strictly before it.
   */
  private def cumulativeMass: Seq[(Double, Centroid)] = {
    // Convert to a Seq first: mapping the TreeSet itself would produce a sorted *set* of
    // fractions, dropping duplicates and breaking the pairing with the centroids.
    val ordered: Seq[Centroid] = this.centroids.toSeq
    ordered.map(_.count / this.n.toDouble).scanLeft(0.0)(_ + _).zip(ordered)
  }

  /**
   * Approximate quantile function: the value below which a fraction `p` of the weight lies,
   * linearly interpolated between neighbouring centroid means.
   *
   * @return the estimate, or -1.0 when there are no centroids (or `p` is NaN)
   */
  def invCDF(p: Double): Double = {
    val cumProb: Seq[(Double, Centroid)] = this.cumulativeMass
    // First centroid with more than p of the mass before it, and last with at most p before it.
    val above: Option[(Double, Centroid)] = cumProb.find { _._1 > p }
    val below: Option[(Double, Centroid)] = cumProb.reverse.find { _._1 <= p }
    (below, above) match {
      case (None, None) => -1.0 // raise error here: no centroids (or p is NaN)
      case (None, Some(aC)) => aC._2.mean
      case (Some(bC), None) => bC._2.mean
      case (Some(bC), Some(aC)) =>
        // linear interpolation between means
        val deltaX: Double = aC._2.mean - bC._2.mean
        val deltaP = (p - bC._1) / (aC._1 - bC._1)
        bC._2.mean + (deltaP * deltaX)
    }
  }

  /**
   * Approximate CDF at `x`, interpolating the cumulative mass linearly between the nearest
   * centroid means below and above `x`.
   *
   * @return the estimate in [0, 1]; 1.0 when no centroid mean exceeds `x`; -1.0 when there are
   *         no centroids or `x` is NaN
   */
  def cdf(x: Double): Double = {
    val cumCount: Seq[(Double, Centroid)] = this.cumulativeMass
    val above: Option[(Double, Centroid)] = cumCount.find { _._2.mean > x }
    val below: Option[(Double, Centroid)] = cumCount.reverse.find { _._2.mean < x }
    (below, above) match {
      case _ if cumCount.isEmpty || x.isNaN => -1.0 // raise error here: no centroids / NaN input
      case (_, None) => 1.0 // every centroid mean is <= x, so all of the mass lies at or below x
      case (None, Some(aC)) => aC._1
      case (Some(bC), Some(aC)) =>
        // piece-wise uniform distribution between the two neighbouring means
        val deltaX: Double = (x - bC._2.mean) / (aC._2.mean - bC._2.mean)
        val deltaP = aC._1 - bC._1
        bC._1 + (deltaP * deltaX)
    }
  }

  /**
   * Weighted mean of the centroids whose mean lies strictly between `x0` and `x1`.
   *
   * @return the mean of that window, or 0.0 if no centroid falls inside it
   */
  def trimmedMean(x0: Double, x1: Double): Double = {
    // Work on a Seq so equal counts / contributions are not collapsed by set semantics.
    val within: Seq[Centroid] = this.centroids.toSeq.filter { c => c.mean > x0 && c.mean < x1 }
    val s = within.map(_.count.toDouble).sum
    within.map { c => c.mean * (c.count / s) }.sum
  }
}
/**
 * Local smoke test: builds digests over two exponential samples, then prints CDF / inverse-CDF
 * estimates for each digest and for their merge.
 */
object TDigestAppCustom {
  def main(arg: Array[String]): Unit = {
    val appName: String = "TDigest-Test"
    val conf: SparkConf = new SparkConf().setAppName(appName).setMaster("local[16]")
    // NOTE(review): `sc` is not used by the computations below; presumably it is kept to exercise
    // the classes under a Spark runtime. TODO confirm.
    val sc: SparkContext = new SparkContext(conf)
    try {
      // commons-math3's ExponentialDistribution(mean): for these, cdf(mean) should be close to
      // 1 - 1/e (~0.632) and invCDF(0.5) close to mean * ln 2.
      val trueDist0: ExponentialDistribution = new ExponentialDistribution(15)
      val trueDist1: ExponentialDistribution = new ExponentialDistribution(30)
      val data0: immutable.Seq[Double] = immutable.Seq.fill(10000)(trueDist0.sample())
      val data1: immutable.Seq[Double] = immutable.Seq.fill(10000)(trueDist1.sample())

      val TD0: TDigest = new TDigest()
      val TD1: TDigest = new TDigest()
      TD0.batchUpdate(data0, 1)
      TD1.batchUpdate(data1, 1)
      println(TD0.cdf(15.0))
      println(TD1.cdf(30.0))
      println(TD0.invCDF(0.50))
      println(TD1.invCDF(0.50))

      val bothTD: TDigest = TD0 ++ TD1
      println(bothTD.cdf(15.0))
      println(bothTD.cdf(30.0))
      println(bothTD.invCDF(0.50))
    } finally {
      // Always release the SparkContext, even if the demo throws.
      sc.stop()
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment