package org.apache.spark.sql

import scala.util.Random

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
 * This benchmark inserts records into an aggregation hash map.
 *
 * Each record has the following schema:
 *
 *   (10 character string, int, long, double)
 *
 * For the hash map entries, the grouping key is the first string column and the values are the
 * remaining numeric columns.
 *
 * Because the keys are completely random, they should almost all be distinct, so this is a
 * good way to stress-test how the maps perform at large sizes.
 */
object AggregationHashMapBenchmark {

  private val rowArr = new Array[Any](4)
  private val row = new GenericRow(rowArr)
  private val rand = new Random(42)

  /**
   * Generates a new random row. For efficiency, each call returns the same mutable row.
   */
  def randomRow(): Row = {
    rowArr(0) = UTF8String(rand.nextString(10))
    rowArr(1) = rand.nextInt()
    rowArr(2) = rand.nextLong()
    rowArr(3) = rand.nextDouble()
    row
  }
  /**
   * Generates the grouping projection for a given row (in this case, a new row containing only
   * the string column).
   */
  def groupProjection(row: Row): Row = {
    new GenericRow(Array[Any](row.get(0)))
  }
  /** The schema of the map values (aggregation buffers). */
  val aggregationBufferTypes = IntegerType :: LongType :: DoubleType :: Nil

  /** Generates a new, empty aggregation buffer. */
  def emptyAggregationBuffer(): MutableRow = {
    new SpecificMutableRow(aggregationBufferTypes)
  }
  /**
   * Updates an aggregation buffer by adding the given row's values to it.
   */
  def updateAggregationBuffer(buffer: MutableRow, row: Row): Unit = {
    buffer.setInt(0, buffer.getInt(0) + row.getInt(1))
    buffer.setLong(1, buffer.getLong(1) + row.getLong(2))
    buffer.setDouble(2, buffer.getDouble(2) + row.getDouble(3))
  }
  /**
   * Run the benchmark using a hashmap implementation based on Java objects.
   */
  def benchmarkJavaObjects(numKeys: Int): Unit = {
    val map = new java.util.HashMap[Row, MutableRow]()
    var i = 0
    while (i < numKeys) {
      i += 1
      val currentRow = randomRow()
      val currentGroup = groupProjection(currentRow)
      var currentBuffer = map.get(currentGroup)
      if (currentBuffer == null) {
        currentBuffer = emptyAggregationBuffer()
        map.put(currentGroup, currentBuffer)
      }
      updateAggregationBuffer(currentBuffer, currentRow)
    }
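    // Iterate over every entry so that scanning the map's contents is part of the measured work.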
    val mapIter = map.entrySet().iterator()
    while (mapIter.hasNext) {
      mapIter.next()
    }
  }
  /**
   * Run the benchmark using a hashmap implementation that uses managed memory.
   */
  def benchmarkManagedMemory(numKeys: Int, useHeap: Boolean): Unit = {
    val allocator = if (useHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(allocator))
    val map = new UnsafeFixedWidthAggregationMap(
      emptyAggregationBuffer(),
      StructType(Seq(
        StructField("a", IntegerType),
        StructField("b", LongType),
        StructField("c", DoubleType))),
      StructType(Seq(
        StructField("d", StringType))),
      memoryManager,
      1024, // initial map size
      false) // disable perf. metrics
    var i = 0
    while (i < numKeys) {
      i += 1
      val currentRow = randomRow()
      val currentGroup = groupProjection(currentRow)
      val currentBuffer = map.getAggregationBuffer(currentGroup)
      updateAggregationBuffer(currentBuffer, currentRow)
    }
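    // As above, iterate over every entry so that the scan is part of the measured work before
    // the map's memory is freed.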
    val mapIter = map.iterator()
    while (mapIter.hasNext) {
      mapIter.next()
    }
    map.free()
  }
  def runBenchmark(numKeys: Int, useUnsafe: Boolean, useHeap: Boolean): Unit = {
    // Re-seed the shared RNG so that every mode and iteration inserts an identical key sequence.
    rand.setSeed(42)
    if (useUnsafe) {
      benchmarkManagedMemory(numKeys, useHeap)
    } else {
      benchmarkJavaObjects(numKeys)
    }
  }
  def main(args: Array[String]): Unit = {
    val NUM_WARMUPS = 2
    val NUM_ITERATIONS = 10
    println("mode,numKeys,operationsPerSecond")
    for (
      numKeys <- 0 to (1000 * 1000 * 30) by 250000;
      mode <- Seq("unsafe-offheap", "unsafe-heap", "java-objects")
    ) {
      val useUnsafe = Set("unsafe-heap", "unsafe-offheap").contains(mode)
      val useHeap = mode == "unsafe-heap"
      for (_ <- 1 to NUM_WARMUPS) {
        runBenchmark(numKeys, useUnsafe, useHeap)
      }
      for (_ <- 1 to NUM_ITERATIONS) {
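        // Request a full GC before the timed region so that garbage left over from earlier runs
        // is less likely to be collected while the clock is running.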
        System.gc()
        val startTime = System.currentTimeMillis()
        runBenchmark(numKeys, useUnsafe, useHeap)
        val endTime = System.currentTimeMillis()
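        // Throughput is reported as keys processed per second (elapsed time is in milliseconds,
        // hence the factor of 1000).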
        println(s"$mode,$numKeys,${(1.0 * numKeys) / (endTime - startTime) * 1000}")
      }
    }
  }
}
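The sbt build definition below (presumably the gist's build.sbt) compiles the benchmark against a local Spark checkout referenced by SPARK_HOME: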
scalaVersion := "2.10.4"

val SPARK_HOME = sys.env.getOrElse("SPARK_HOME", "/Users/joshrosen/Documents/spark")

lazy val sparkSql = ProjectRef(file(SPARK_HOME), "sql")
lazy val sparkCore = ProjectRef(file(SPARK_HOME), "core")

lazy val root = (project in file(".")).
  settings(
    aggregate in update := false,
    fork in run := true,
    // libraryDependencies += "yourkit" % "yourkit-api" % "version" from "file:///Applications/YourKit_Java_Profiler_2013_build_13088.app/lib/yjp-controller-api-redist.jar",
    // javaOptions in run += "-agentpath:/Applications/YourKit_Java_Profiler_2013_build_13088.app/bin/mac/libyjpagent.jnilib=onexit=snapshot",
    javaOptions in run ++= Seq(
      "-Xmx1250M",
      "-Xms1250M"
    )
  ).dependsOn(sparkSql, sparkCore)
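With SPARK_HOME pointing at a locally built Spark checkout, `sbt run` launches the benchmark in a forked JVM and prints the CSV results to the console (a sketch of the intended workflow; the exact Spark build steps are not part of the gist).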