Skip to content

Instantly share code, notes, and snippets.

@felipecrv
Last active July 19, 2018 19:14
Show Gist options
  • Save felipecrv/2782a4c54408edbd13ad3a57dfc8d2ec to your computer and use it in GitHub Desktop.
Save felipecrv/2782a4c54408edbd13ad3a57dfc8d2ec to your computer and use it in GitHub Desktop.
Sampling and Shuffling sequences in Scala
package m79.random
import java.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
 * Random sampling and shuffling helpers for in-memory sequences and iterators.
 *
 * All randomness comes from the `java.util.Random` passed by the caller, so results are
 * reproducible for a given seed and call order.
 */
object Sampling {

  /**
   * Largest capacity reserved up front for a reservoir whose input size is unknown.
   * The buffer still grows on demand; the cap only stops a very large `k` from
   * allocating `k` slots before a single element has been read.
   */
  private val MaxPresizedCapacity = 1 << 12

  /**
   * Randomly sample up to k items from a sequence.
   *
   * Delegates to [[smallSample]] when k is less than half the sequence length and to
   * [[reservoirSample]] otherwise.
   *
   * @param sequence population to draw from
   * @param k        maximum number of items to draw; non-positive values give an empty sample
   * @param rng      source of randomness
   * @return Sample of up to k items. The order of the items is not guaranteed to be random.
   */
  def sample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    if (k < n / 2) {
      smallSample(sequence, k, rng)
    } else {
      // The population size is known here, so reserve exactly what can be filled.
      fillReservoir(sequence.iterator, k, math.min(k, n), rng)
    }
  }

  /**
   * Randomly sample up to k items from a sequence and guarantee that the sample is shuffled.
   *
   * Neither sampling strategy returns its items in a uniformly random order, so the
   * sample is always passed through [[shuffle]] before it is returned.
   *
   * @param sequence population to draw from
   * @param k        maximum number of items to draw; non-positive values give an empty sample
   * @param rng      source of randomness
   * @return Shuffled sample of up to k items.
   */
  def shuffledSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val result = sample(sequence, k, rng)
    shuffle(result, rng)
    result
  }

  /**
   * Randomly sample k items from an iterator.
   *
   * Algorithm R -- See https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R
   * Vitter, Jeffrey S. "Random sampling with a reservoir", 1985.
   *
   * The iterator is always consumed to the end, and one random number is drawn for
   * every element after the first k (also when k is zero or negative).
   *
   * @param iterator source of items; fully consumed by this call
   * @param k        maximum number of items to keep; non-positive values give an empty sample
   * @param rng      source of randomness
   * @return Sample of up to k items. Not properly shuffled.
   */
  def reservoirSample[T](iterator: Iterator[T], k: Int, rng: Random): ArrayBuffer[T] =
    fillReservoir(iterator, k, math.min(k, MaxPresizedCapacity), rng)

  /** Algorithm R over `iterator`, with an explicit initial capacity for the reservoir. */
  private def fillReservoir[T](
      iterator: Iterator[T],
      k: Int,
      capacityHint: Int,
      rng: Random): ArrayBuffer[T] = {
    val reservoir = new ArrayBuffer[T](math.max(capacityHint, 0))
    var seen = 0
    // Phase 1: the first k elements go straight into the reservoir.
    while (seen < k && iterator.hasNext) {
      reservoir += iterator.next()
      seen += 1
    }
    // Phase 2: the i-th element (1-based) replaces a random slot with probability k / i.
    while (iterator.hasNext) {
      seen += 1
      val elem = iterator.next()
      val slot = rng.nextInt(seen)
      if (slot < k) {
        reservoir(slot) = elem
      }
    }
    reservoir
  }

  /**
   * Randomly sample k items from a sequence of length n.
   *
   * Indices are drawn uniformly and draws that hit an already chosen index are discarded,
   * so this method is preferable when k is much smaller than n.
   *
   * @param sequence population to draw from
   * @param k        maximum number of items to draw; non-positive values give an empty sample
   * @param rng      source of randomness
   * @return Sample of up to k items taken at distinct positions. The items come out in the
   *         iteration order of the set of chosen indices, not in the order they were drawn,
   *         so the result is not shuffled.
   */
  def smallSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    // Never ask for more distinct indices than exist.
    val target = math.min(k, n)
    val chosenIndices = mutable.Set[Int]()
    while (chosenIndices.size < target) {
      // Adding an index that is already present leaves the set unchanged.
      chosenIndices += rng.nextInt(n)
    }
    val sample = new ArrayBuffer[T](math.max(target, 0))
    chosenIndices.foreach(index => sample += sequence(index))
    sample
  }

  /**
   * In-place Knuth-Fisher-Yates shuffle.
   *
   * See https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
   *
   * @param arr buffer to permute; modified in place
   * @param rng source of randomness
   */
  def shuffle[T](arr: ArrayBuffer[T], rng: Random): Unit = {
    var i = arr.length - 1
    while (i >= 1) {
      // Swap position i with a uniformly chosen position in [0, i].
      val j = rng.nextInt(i + 1)
      val tmp = arr(i)
      arr(i) = arr(j)
      arr(j) = tmp
      i -= 1
    }
  }
}
package m79.random
import java.util.Random
import m79.random.Sampling.{reservoirSample, shuffle, smallSample}
import org.scalatest.FunSpec
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
/**
 * Behavioural tests for the `Sampling` helpers.
 *
 * Every test draws from the same seeded generator, so the pinned expected values depend
 * on the order in which the tests run and on how many random numbers each call consumes.
 */
class SamplingSpec extends FunSpec {
  val rng = new Random(12)
  val empty = IndexedSeq.empty[Int]
  val small = IndexedSeq(2, 3, 4)
  val population = IndexedSeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

  describe("Sampling") {
    describe("reservoirSample") {
      it("samples on empty populations") {
        assert(reservoirSample(empty.iterator, -5, rng) == empty)
        assert(reservoirSample(empty.iterator, 0, rng) == empty)
        assert(reservoirSample(empty.iterator, 5, rng) == empty)
      }
      it("samples on small populations") {
        assert(reservoirSample(small.iterator, -5, rng) == empty)
        assert(reservoirSample(small.iterator, 0, rng) == empty)
        assert(reservoirSample(small.iterator, 5, rng) == Seq(2, 3, 4))
      }
      it("samples on bigger populations") {
        assert(reservoirSample(population.iterator, -5, rng) == empty)
        assert(reservoirSample(population.iterator, 0, rng) == empty)
        assert(reservoirSample(population.iterator, 5, rng) == Seq(0, 9, 2, 7, 4))
        assert(reservoirSample(population.iterator, 10, rng) == population)
      }
    }

    describe("smallSample") {
      // smallSample returns items in the iteration order of a mutable set, which is an
      // implementation detail of the collections library. These tests therefore check
      // which items were chosen, not the order they come back in.
      it("samples on empty populations") {
        assert(smallSample(empty, -5, rng) == empty)
        assert(smallSample(empty, 0, rng) == empty)
        assert(smallSample(empty, 5, rng) == empty)
      }
      it("samples on small populations") {
        assert(smallSample(small, -5, rng) == empty)
        assert(smallSample(small, 0, rng) == empty)
        assert(smallSample(small, 5, rng).sorted == small)
      }
      it("samples on bigger populations") {
        assert(smallSample(population, -5, rng) == empty)
        assert(smallSample(population, 0, rng) == empty)
        val five = smallSample(population, 5, rng)
        assert(five.length == 5)
        assert(five.toSet == Set(2, 5, 7, 8, 9))
        assert(smallSample(population, 10, rng).sorted == population)
      }
    }

    describe("shuffle") {
      it("creates an uniform distribution of samples") {
        val numbers = ArrayBuffer(0, 1, 2, 3)
        val freq = mutable.Map.empty[Int, Int]
        val N = 1000000
        // Exactly N trials, matching the divisor used for the frequencies below.
        for (_ <- 0 until N) {
          shuffle(numbers, rng)
          // Encode the current permutation as one integer so it can be used as a map key.
          val x = numbers(0) * 1 + numbers(1) * 10 + numbers(2) * 100 + numbers(3) * 1000
          freq.update(x, freq.getOrElse(x, 0) + 1)
        }
        // Four items have 4! = 24 permutations, so each should occur with frequency ~1/24.
        val p = freq.values.map(count => count.toDouble / N)
        p.foreach(f => assert(f > 0.04 && f < 0.042))
      }
    }
  }
}
@felipecrv
Copy link
Author

Implementations of Reservoir Sampling and Fisher-Yates Shuffle in Scala.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment