Created
August 12, 2016 04:06
-
-
Save redwrasse/a91cb9fd519741ae083bf7f229d3727c to your computer and use it in GitHub Desktop.
Distributed median binning with spark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Distributed median binning
 *
 * See
 * "Fast Computation of the Median by Successive Binning"
 * https://www.stat.cmu.edu/~ryantibs/papers/median.pdf
 *
 * This code currently only works for an odd number of elements
 * See https://github.com/goodsoldiersvejk/medianbinning
 */
import org.apache.spark | |
.{SparkContext} | |
import org.apache.spark.rdd | |
.{RDD} | |
import org.apache.spark | |
.{Partitioner} | |
import org.slf4j | |
.{Logger, LoggerFactory} | |
object MedianBinning {
  val logger = LoggerFactory.getLogger(getClass)

  /**
   * Partitions integer keys into `numBins` equal-width bins spanning
   * [minValue, maxValue].
   *
   * The top edge (key == maxValue, or any rounding overflow) is clamped into
   * the last bin so every in-range key maps to a valid partition index.
   *
   * NOTE(review): requires maxValue > minValue — otherwise binSize is 0 and
   * the division is undefined; `findMedian` guards that case before
   * constructing this partitioner.
   */
  class BinPartitioner(numBins: Int, minValue: Int,
                       maxValue: Int) extends Partitioner {
    // width of one bin; > 0 as long as maxValue > minValue
    val binSize = (maxValue - minValue) * 1.0 / numBins
    def getPartition(key: Any): Int = key match {
      case k: Int =>
        ((k - minValue) * 1.0 / binSize).toInt match {
          // clamp: k == maxValue lands exactly on numBins; fold into last bin
          case n if n >= numBins => numBins - 1
          case n => n
        }
      // non-Int keys never occur here (findMedian only bins Int keys)
      case _ => 0
    }
    def numPartitions: Int = numBins
  }

  /**
   * Demo entry point: computes the median of a small sample (9 elements,
   * odd count as the algorithm requires) using 3 bins and prints it next
   * to the known true median.
   */
  def main(args: Array[String]) = {
    val sc = new SparkContext()
    val sampleRdd = sc.parallelize(
      List(6, 1, 1, 4, 12, 7, 8, 25, 8))
    val numBins = 3
    val trueMedian = 7
    val calculatedMedian = findMedian(sampleRdd, numBins)
    println("NUMBER OF BINS: " + numBins)
    println("TRUE MEDIAN: " + trueMedian)
    println("CALCULATED MEDIAN: " + calculatedMedian)
    sc.stop()
  }

  /**
   * Finds the median of `rdd` by successive binning: repeatedly range-
   * partitions the surviving elements into `numBins` equal-width bins and
   * keeps only the bin containing the median rank, until one element (or
   * one distinct value) remains.
   *
   * Only correct for an odd number of elements (the median rank is
   * totalCt / 2 + 1, 1-based).
   *
   * @param rdd     non-empty input values
   * @param numBins number of bins per pass (>= 1)
   * @return the median value
   * @throws IllegalArgumentException if `rdd` is empty or `numBins < 1`
   */
  def findMedian(rdd: RDD[Int], numBins: Int): Int = {
    val totalCt: Long = rdd.count
    require(totalCt > 0, "findMedian requires a non-empty RDD")
    require(numBins >= 1, "numBins must be positive")
    // 1-based rank of the median; assumes an odd element count
    val halfCt = totalCt / 2 + 1

    /**
     * Bins `currentRdd` over [minValue, maxValue], locates the bin holding
     * the median rank, and returns (rdd restricted to that bin, count of
     * elements strictly left of that bin including `leftCount`).
     *
     * Precondition (enforced by the caller): minValue < maxValue.
     */
    def findMedianBin(currentRdd: RDD[Int], leftCount: Long,
                      minValue: Int, maxValue: Int): (RDD[Int], Long) = {
      val binned = currentRdd.map(e => (e, e))
        .partitionBy(new BinPartitioner(numBins, minValue, maxValue))
      // (bin index, element count) per partition. Sorted by bin index so the
      // cumulative sum below walks bins in ascending value order regardless
      // of the order collect happens to return them in.
      val binCounts: Array[(Int, Long)] = binned
        .mapPartitionsWithIndex((i, it) => Iterator((i, it.size.toLong)),
          preservesPartitioning = true)
        .collect
        .sortBy(_._1)
      // accumulate bin counts until the median rank is covered
      var i = 0
      var sm = leftCount
      while (sm < halfCt) {
        sm += binCounts(i)._2
        i += 1
      }
      val medianBin = i - 1
      val newLeftCt = sm - binCounts(medianBin)._2
      // keep only the partition holding the median; other partitions are
      // dropped in place (cheaper than tagging every element and filtering)
      val newRdd: RDD[Int] = binned.mapPartitionsWithIndex((idx, it) =>
        if (idx == medianBin) it.map(_._1) else Iterator.empty)
      (newRdd, newLeftCt)
    }

    var leftCt: Long = 0L
    var currentRdd = rdd
    var currentCt: Long = totalCt
    var result: Option[Int] = None
    while (result.isEmpty) {
      if (currentCt <= 1L) {
        // a single element left: that's the median
        result = Some(currentRdd.first)
      } else {
        val (mn, mx) = (currentRdd.min, currentRdd.max)
        if (mn == mx) {
          // All surviving elements are equal. Binning here would be
          // degenerate (binSize = 0 => invalid partition index / infinite
          // loop in the original code), but the median is simply that value.
          result = Some(mn)
        } else {
          val (nextRdd, nextLeft) = findMedianBin(currentRdd, leftCt, mn, mx)
          currentRdd = nextRdd
          leftCt = nextLeft
          currentCt = currentRdd.count
        }
      }
    }
    result.get
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment