Last active
August 11, 2020 17:15
-
-
Save emesday/87e877bd21711dcf1fb8e4a2deed032d to your computer and use it in GitHub Desktop.
Reservoir Sampling for Scala Spark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.reflect.ClassTag
import scala.util.Random
/**
 * Mutable, fixed-capacity reservoir sampler.
 *
 * Elements are fed one at a time with `+=`; two samplers of the same capacity
 * (for example, built independently over separate partitions of a data set) can
 * be combined with `++=`. At any point `result()` returns at most `size`
 * elements drawn from everything seen so far.
 *
 * Instances are mutable and contain no synchronization.
 *
 * @param size maximum number of elements kept in the sample; must be non-negative
 * @param seed seed for the internal random number generator; defaults to a random seed
 * @tparam T element type (a `ClassTag` is required to allocate the backing array)
 */
class Reservoir[T: ClassTag](
    private val size: Int,
    private val seed: Long = Random.nextLong()) extends Serializable {

  // Fail with a descriptive message instead of a NegativeArraySizeException below.
  require(size >= 0, s"Reservoir size must be non-negative, got $size")

  private val rand = new Random(seed)
  // Backing storage; only the first `getSize` slots hold valid elements.
  private val reservoir = new Array[T](size)
  // Total number of elements observed so far (not the number retained).
  private var count = 0L

  /**
   * Offers one element to the sample.
   *
   * The first `size` elements are stored directly. Every later element (the
   * `count`-th overall) replaces a uniformly chosen slot with probability
   * `size / count`.
   *
   * @param elem the element to offer
   * @return this sampler, to allow chaining
   */
  def +=(elem: T): this.type = {
    count += 1
    if (count <= size) {
      reservoir((count - 1).toInt) = elem
    } else {
      // Uniform index in [0, count); it selects a slot only when it lands in [0, size).
      val replacementIndex = (rand.nextDouble() * count).toLong
      if (replacementIndex < size) {
        reservoir(replacementIndex.toInt) = elem
      }
    }
    this
  }

  /**
   * Merges another sampler of the same capacity into this one.
   *
   * If everything seen by both samplers fits in the reservoir, the other
   * sampler's elements are appended after this one's. Otherwise a new sample of
   * `size` elements is drawn slot by slot: each slot is taken from this sampler
   * with probability proportional to the number of not-yet-consumed elements it
   * represents, and from `that` otherwise. Decrementing the remaining counts
   * after every pick makes the split between the two sides follow the
   * without-replacement (hypergeometric) distribution of a sample drawn from the
   * combined stream.
   *
   * `that` is only read, never modified.
   *
   * @param that the sampler to merge in; must have the same `size`
   * @return this sampler, to allow chaining
   * @throws IllegalArgumentException if the two samplers have different sizes
   */
  def ++=(that: Reservoir[T]): this.type = {
    require(
      this.size == that.size,
      s"Cannot merge reservoirs of different sizes: ${this.size} vs ${that.size}")
    if (this.count + that.count <= size) {
      // Nothing has been discarded on either side and the union fits: just append.
      System.arraycopy(
        that.reservoir, 0, this.reservoir, this.count.toInt, that.count.toInt)
    } else {
      // Visit each side's retained elements in random order so that taking a
      // prefix of either side is itself a uniform sub-sample of that side.
      val thisOrder = rand.shuffle(Vector.range(0, this.getSize)).iterator
      val thatOrder = rand.shuffle(Vector.range(0, that.getSize)).iterator
      // Number of stream elements on each side not yet accounted for by a pick.
      var thisLeft = this.count
      var thatLeft = that.count
      // Build into a fresh array because slots of `this.reservoir` are still being read.
      val merged = Array.fill[T](size) {
        // The hasNext guards are defensive; the remaining counts already reach
        // zero exactly when a side that kept all of its elements is exhausted.
        val fromThis =
          !thatOrder.hasNext ||
            (thisOrder.hasNext && rand.nextDouble() * (thisLeft + thatLeft) < thisLeft)
        if (fromThis) {
          thisLeft -= 1
          this.reservoir(thisOrder.next())
        } else {
          thatLeft -= 1
          that.reservoir(thatOrder.next())
        }
      }
      System.arraycopy(merged, 0, reservoir, 0, merged.length)
    }
    count += that.count
    this
  }

  /** Total number of elements offered to this sampler, including merged ones. */
  def getCount: Long = count

  /** Number of elements currently held in the sample: `min(size, count)`. */
  def getSize: Int = math.min(size, count).toInt

  /** Returns a copy of the current sample (at most `size` elements). */
  def result(): Array[T] = reservoir.take(getSize)
}
Author
emesday
commented
Mar 12, 2019
•
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment