Last active
July 19, 2018 19:14
-
-
Save felipecrv/2782a4c54408edbd13ad3a57dfc8d2ec to your computer and use it in GitHub Desktop.
Sampling and Shuffling sequences in Scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package m79.random | |
import java.util.Random | |
import scala.collection.mutable | |
import scala.collection.mutable.ArrayBuffer | |
/**
 * Random sampling and shuffling helpers for in-memory sequences and iterators.
 *
 * Every method draws its randomness exclusively from the `java.util.Random` passed in, so the
 * results are reproducible for a given seed and sequence of calls.
 */
object Sampling {

  /**
   * Randomly sample up to k items from a sequence.
   *
   * Delegates to [[smallSample]] when `k` is less than half of the population and to
   * [[reservoirSample]] otherwise.
   *
   * @param sequence population to sample from
   * @param k        maximum number of items to return; values `<= 0` yield an empty sample
   * @param rng      source of randomness
   * @return Sample of up to k distinct positions of `sequence`. The order of the items is not
   *         guaranteed to be random; use [[shuffledSample]] when it matters.
   */
  def sample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    if (k < n / 2) {
      smallSample(sequence, k, rng)
    } else {
      // Clamp k to n: the result is the same, but the reservoir is not preallocated with
      // more slots than the population can ever fill (e.g. when k is Int.MaxValue).
      reservoirSample(sequence.iterator, math.min(k, n), rng)
    }
  }

  /**
   * Randomly sample up to k items from a sequence and guarantee that the sample is shuffled.
   *
   * Neither sampling strategy returns its items in a uniformly random order
   * ([[smallSample]] yields them in the iteration order of a hash set of indices and
   * [[reservoirSample]] keeps untouched slots in their original position), so the sample is
   * always passed through [[shuffle]] before being returned.
   *
   * @param sequence population to sample from
   * @param k        maximum number of items to return; values `<= 0` yield an empty sample
   * @param rng      source of randomness
   * @return Shuffled sample of up to k items.
   */
  def shuffledSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val result = sample(sequence, k, rng)
    shuffle(result, rng)
    result
  }

  /**
   * Randomly sample k items from an iterator.
   *
   * Algorithm R -- See https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_R
   * Vitter, Jeffrey S. "Random sampling with a reservoir", 1985.
   *
   * The iterator is always consumed to the end, and one random number is drawn for every
   * element after the first k (even when `k <= 0`).
   *
   * @param iterator population to sample from; fully consumed by this call
   * @param k        reservoir size; values `<= 0` yield an empty sample
   * @param rng      source of randomness
   * @return Sample of up to k items. Not properly shuffled.
   */
  def reservoirSample[T](iterator: Iterator[T], k: Int, rng: Random): ArrayBuffer[T] = {
    // Guard the capacity hint so that a negative k can never reach the array allocation.
    val reservoir = new ArrayBuffer[T](math.max(k, 0))
    var seen = 0

    // Phase 1: the first k elements fill the reservoir unconditionally.
    while (seen < k && iterator.hasNext) {
      reservoir += iterator.next()
      seen += 1
    }

    // Phase 2: the i-th element (1-based) replaces a random slot with probability k / i.
    while (iterator.hasNext) {
      val elem = iterator.next()
      seen += 1
      val slot = rng.nextInt(seen)
      if (slot < k) {
        reservoir(slot) = elem
      }
    }
    reservoir
  }

  /**
   * Randomly sample k items from a sequence of length n.
   * This method is preferable if k is much smaller than n: indices are drawn by rejection
   * sampling, so the number of draws needed grows as k approaches n.
   *
   * @param sequence population to sample from
   * @param k        maximum number of items to return; values `<= 0` yield an empty sample
   * @param rng      source of randomness
   * @return Sample of `min(k, n)` items taken from distinct positions. The items come out in
   *         the iteration order of the internal set of chosen indices, which is NOT a
   *         uniformly random permutation; use [[shuffledSample]] when the order matters.
   */
  def smallSample[T](sequence: IndexedSeq[T], k: Int, rng: Random): ArrayBuffer[T] = {
    val n = sequence.length
    val target = math.min(n, k)
    val chosenIndices = mutable.Set[Int]()

    // Draw indices until `target` distinct ones were collected; `add` returns false for
    // an index that was already chosen, in which case the draw is simply repeated.
    var picked = 0
    while (picked < target) {
      if (chosenIndices.add(rng.nextInt(n))) {
        picked += 1
      }
    }

    val result = new ArrayBuffer[T](math.max(target, 0))
    chosenIndices.foreach(idx => result += sequence(idx))
    result
  }

  /**
   * In-place Knuth-Fisher-Yates shuffle.
   *
   * See https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
   *
   * @param arr buffer to permute in place; buffers with fewer than two items are left untouched
   * @param rng source of randomness
   */
  def shuffle[T](arr: ArrayBuffer[T], rng: Random): Unit = {
    // Walk from the last position down to 1, swapping each position with a random
    // position at or before it.
    var i = arr.length - 1
    while (i >= 1) {
      val r = rng.nextInt(i + 1)
      val tmp = arr(i)
      arr(i) = arr(r)
      arr(r) = tmp
      i -= 1
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package m79.random | |
import java.util.Random | |
import m79.random.Sampling.{reservoirSample, shuffle, smallSample} | |
import org.scalatest.FunSpec | |
import scala.collection.mutable | |
import scala.collection.mutable.ArrayBuffer | |
/**
 * Behavioural tests for [[Sampling]].
 *
 * All tests share one seeded `Random`, so the pinned expected values below depend on every
 * preceding call that consumes `rng`; keep the order and number of calls stable when editing.
 */
class SamplingSpec extends FunSpec {
  val rng = new Random(12)
  val empty = IndexedSeq.empty[Int]
  val small = IndexedSeq(2, 3, 4)
  val population = IndexedSeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

  describe("Sampling") {
    describe("reservoirSample") {
      it("samples on empty populations") {
        assert(reservoirSample(empty.iterator, -5, rng) == empty)
        assert(reservoirSample(empty.iterator, 0, rng) == empty)
        assert(reservoirSample(empty.iterator, 5, rng) == empty)
      }
      it("samples on small populations") {
        assert(reservoirSample(small.iterator, -5, rng) == empty)
        assert(reservoirSample(small.iterator, 0, rng) == empty)
        assert(reservoirSample(small.iterator, 5, rng) == Seq(2, 3, 4))
      }
      it("samples on bigger populations") {
        assert(reservoirSample(population.iterator, -5, rng) == empty)
        assert(reservoirSample(population.iterator, 0, rng) == empty)
        assert(reservoirSample(population.iterator, 5, rng) == Seq(0, 9, 2, 7, 4))
        assert(reservoirSample(population.iterator, 10, rng) == population)
      }
    }

    describe("smallSample") {
      // The order in which smallSample emits its items comes from the iteration order of a
      // hash set, so these assertions compare contents only and never the order.
      it("samples on empty populations") {
        assert(smallSample(empty, -5, rng) == empty)
        assert(smallSample(empty, 0, rng) == empty)
        assert(smallSample(empty, 5, rng) == empty)
      }
      it("samples on small populations") {
        assert(smallSample(small, -5, rng) == empty)
        assert(smallSample(small, 0, rng) == empty)
        assert(smallSample(small, 5, rng).sorted == Seq(2, 3, 4))
      }
      it("samples on bigger populations") {
        assert(smallSample(population, -5, rng) == empty)
        assert(smallSample(population, 0, rng) == empty)
        val five = smallSample(population, 5, rng)
        assert(five.size == 5)
        assert(five.toSet == Set(2, 5, 7, 8, 9))
        assert(smallSample(population, 10, rng).sorted == population)
      }
    }

    describe("shuffle") {
      it("creates an uniform distribution of samples") {
        val numbers = ArrayBuffer(0, 1, 2, 3)
        val freq = mutable.Map.empty[Int, Int]
        val N = 1000000
        // Exactly N shuffles, so that the frequencies below are normalised correctly.
        for (_ <- 0 until N) {
          shuffle(numbers, rng)
          // Encode the permutation as a number with one decimal digit per position.
          val x = numbers(0) * 1 + numbers(1) * 10 + numbers(2) * 100 + numbers(3) * 1000
          freq.update(x, freq.getOrElse(x, 0) + 1)
        }
        // Every one of the 4! permutations has to show up ...
        assert(freq.size == 24)
        // ... with a relative frequency close to 1/24. The symmetric tolerance is about five
        // standard deviations of a single frequency at this N.
        val expected = 1.0 / 24
        freq.values.foreach { count =>
          assert(math.abs(count.toDouble / N - expected) < 0.001)
        }
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Implementations of Reservoir Sampling and Fisher-Yates Shuffle in Scala.