package org.apache.spark.sql

import scala.util.Random

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
 * This benchmark inserts records into an aggregation hash map.
 *
 * Each record has the following schema:
 *
 *   (10 character string, int, long, double)
 *
 * For the hash map entries, the grouping key is the first string column and the values are the
 * remaining numeric columns.
 *
 * Because the keys are completely random, they should pretty much all be distinct, so this is a
 * good way to stress-test how the maps perform at large sizes.
 */
object AggregationHashMapBenchmark {

  private val rowArr = new Array[Any](4)
  private val row = new GenericRow(rowArr)
  private val rand = new Random(42)

  /**
   * Generates a new random row. For efficiency, each call returns the same mutable row, so
   * callers must not retain references to it across calls.
   */
  def randomRow(): Row = {
    rowArr(0) = UTF8String(rand.nextString(10))
    rowArr(1) = rand.nextInt()
    rowArr(2) = rand.nextLong()
    rowArr(3) = rand.nextDouble()
    row
  }
  /**
   * Generates the grouping projection for a given row (in this case, a new row containing only
   * the string column).
   */
  def groupProjection(row: Row): Row = {
    new GenericRow(Array[Any](row.get(0)))
  }

  /** The schema of the map values (aggregation buffers). */
  val aggregationBufferTypes = IntegerType :: LongType :: DoubleType :: Nil

  /** Generates a new, empty aggregation buffer. */
  def emptyAggregationBuffer(): MutableRow = {
    new SpecificMutableRow(aggregationBufferTypes)
  }
  /**
   * Updates an aggregation buffer by adding the given row's values to it.
   */
  def updateAggregationBuffer(buffer: MutableRow, row: Row): Unit = {
    buffer.setInt(0, buffer.getInt(0) + row.getInt(1))
    buffer.setLong(1, buffer.getLong(1) + row.getLong(2))
    buffer.setDouble(2, buffer.getDouble(2) + row.getDouble(3))
  }
  /**
   * Run the benchmark using a hash map implementation based on Java objects.
   */
  def benchmarkJavaObjects(numKeys: Int): Unit = {
    val map = new java.util.HashMap[Row, MutableRow]()
    var i = 0
    while (i < numKeys) {
      i += 1
      val currentRow = randomRow()
      val currentGroup = groupProjection(currentRow)
      var currentBuffer = map.get(currentGroup)
      if (currentBuffer == null) {
        currentBuffer = emptyAggregationBuffer()
        map.put(currentGroup, currentBuffer)
      }
      updateAggregationBuffer(currentBuffer, currentRow)
    }
    // Drain the iterator so that a full map traversal is part of the timed work.
    val mapIter = map.entrySet().iterator()
    while (mapIter.hasNext) {
      mapIter.next()
    }
  }
  /**
   * Run the benchmark using a hash map implementation that uses managed memory.
   */
  def benchmarkManagedMemory(numKeys: Int, useHeap: Boolean): Unit = {
    val allocator = if (useHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(allocator))
    val map = new UnsafeFixedWidthAggregationMap(
      emptyAggregationBuffer(),
      StructType(Seq(        // aggregation buffer schema
        StructField("a", IntegerType),
        StructField("b", LongType),
        StructField("c", DoubleType))),
      StructType(Seq(        // grouping key schema
        StructField("d", StringType))),
      memoryManager,
      1024,  // initial map size
      false) // disable perf. metrics
    var i = 0
    while (i < numKeys) {
      i += 1
      val currentRow = randomRow()
      val currentGroup = groupProjection(currentRow)
      val currentBuffer = map.getAggregationBuffer(currentGroup)
      updateAggregationBuffer(currentBuffer, currentRow)
    }
    // Drain the iterator so that a full map traversal is part of the timed work.
    val mapIter = map.iterator()
    while (mapIter.hasNext) {
      mapIter.next()
    }
    // Release the map's managed memory back to the allocator.
    map.free()
  }
  def runBenchmark(numKeys: Int, useUnsafe: Boolean, useHeap: Boolean): Unit = {
    // Reset the seed so every run inserts an identical sequence of random rows.
    rand.setSeed(42)
    if (useUnsafe) {
      benchmarkManagedMemory(numKeys, useHeap)
    } else {
      benchmarkJavaObjects(numKeys)
    }
  }
  def main(args: Array[String]): Unit = {
    val NUM_WARMUPS = 2
    val NUM_ITERATIONS = 10
    println("mode,numKeys,operationsPerSecond")
    for (
      numKeys <- 0 to (1000 * 1000 * 30) by 250000;
      mode <- Seq("unsafe-offheap", "unsafe-heap", "java-objects")
    ) {
      val useUnsafe = Set("unsafe-heap", "unsafe-offheap").contains(mode)
      val useHeap = mode == "unsafe-heap"
      for (_ <- 1 to NUM_WARMUPS) {
        runBenchmark(numKeys, useUnsafe, useHeap)
      }
      for (_ <- 1 to NUM_ITERATIONS) {
        // Encourage a GC between timed runs so collections triggered by earlier
        // iterations are less likely to pollute this measurement.
        System.gc()
        val startTime = System.currentTimeMillis()
        runBenchmark(numKeys, useUnsafe, useHeap)
        val endTime = System.currentTimeMillis()
        println(s"$mode,$numKeys,${(1.0 * numKeys) / (endTime - startTime) * 1000}")
      }
    }
  }
}
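
For a quick single data point rather than the full sweep, something along these lines should work. This is just a sketch: the SingleRun object is not part of the gist, and it assumes it is compiled in the same org.apache.spark.sql package so that it can see the benchmark object.

package org.apache.spark.sql

// Hypothetical one-off driver: time a single unsafe-offheap run over 5M keys.
object SingleRun {
  def main(args: Array[String]): Unit = {
    val numKeys = 5 * 1000 * 1000
    val start = System.currentTimeMillis()
    AggregationHashMapBenchmark.runBenchmark(numKeys, useUnsafe = true, useHeap = false)
    val elapsedMs = System.currentTimeMillis() - start
    println(s"$numKeys keys in $elapsedMs ms (${numKeys.toDouble / elapsedMs * 1000} ops/sec)")
  }
}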
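The accompanying sbt build definition: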
scalaVersion := "2.10.4"

val SPARK_HOME = sys.env.getOrElse("SPARK_HOME", "/Users/joshrosen/Documents/spark")

lazy val sparkSql = ProjectRef(file(SPARK_HOME), "sql")
lazy val sparkCore = ProjectRef(file(SPARK_HOME), "core")

lazy val root = (project in file(".")).
  settings(
    aggregate in update := false,
    fork in run := true,
    // libraryDependencies += "yourkit" % "yourkit-api" % "version" from "file:///Applications/YourKit_Java_Profiler_2013_build_13088.app/lib/yjp-controller-api-redist.jar",
    // javaOptions in run += "-agentpath:/Applications/YourKit_Java_Profiler_2013_build_13088.app/bin/mac/libyjpagent.jnilib=onexit=snapshot",
    javaOptions in run ++= Seq(
      "-Xmx1250M",
      "-Xms1250M"
    )
  ).dependsOn(sparkSql, sparkCore)
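
With this build definition, running the benchmark against a local Spark checkout should just be a matter of pointing SPARK_HOME at the checkout (it defaults to /Users/joshrosen/Documents/spark) and invoking `sbt run`. Because `fork in run := true`, the -Xmx1250M/-Xms1250M settings apply to the forked benchmark JVM rather than to sbt itself; the commented-out lines show how the YourKit profiler could be attached if it is installed.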