Created
May 18, 2015 04:35
-
-
Save JoshRosen/680ee530655941defcb2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.sql | |
import java.io.File | |
import org.apache.spark.sql.catalyst.expressions.GenericRow | |
import org.apache.spark.{SparkConf, SparkContext} | |
import org.apache.spark.rdd.RDD | |
import org.apache.spark.sql.types.StructType | |
import scala.util.Random | |
/**
 * Micro-benchmark for a Spark SQL `GROUP BY` aggregation.
 *
 * `main` builds a temporary table named "data" filled with pseudo-random rows, runs the
 * aggregation query `numWarmups` times untimed, then runs it `numRepetitions` times while
 * YourKit CPU sampling is active, and prints the mean wall-clock time per measured run.
 *
 * NOTE(review): `com.yourkit.api.Controller` presumably needs the JVM to be started with the
 * YourKit agent attached -- confirm before running.
 */
object AggregationBenchmark {

  val numKeys = 10000
  // Currently one record per key slot; raise the multiplier to generate more rows.
  val numRecords = numKeys * 1
  // Untimed runs executed before measurement starts.
  val numWarmups = 5
  // Timed (and profiled) runs.
  val numRepetitions = 10

  // Both are assigned in `main` before `setup()` is called.
  var sc: SparkContext = _
  var sqlContext: SQLContext = _

  // Not referenced anywhere else in this object; the table rows are built as generic Rows.
  case class Record(key: String, c1: Int, c2: Long, c3: Double)

  /**
   * Generates `numRecords` rows across 100 partitions, wraps them in a DataFrame with the
   * schema (key: string, c1: int, c2: long, c3: double) and registers it as temp table "data".
   *
   * Requires `sc` and `sqlContext` to be initialised.
   *
   * @return the DataFrame backing the "data" temp table
   */
  def setup(): DataFrame = {
    // `sqlContext` is a var; importing implicits needs a stable identifier, hence the local val.
    val sqlContext2 = sqlContext
    import sqlContext2.implicits._

    val rdd: RDD[Row] = sc.parallelize(1 to numRecords, 100).mapPartitions { iter =>
      // NOTE(review): every partition seeds its generator with 42, so all partitions emit the
      // same value sequence and the number of distinct keys is bounded by the partition size,
      // not by `numKeys` -- confirm this is the intended key distribution.
      val rand = new Random(42)
      // A single backing array and row object are overwritten and re-emitted for every element.
      // NOTE(review): this is only safe if the consumer copies or fully processes each row
      // before requesting the next one -- confirm for the Spark version in use.
      val arr = new Array[Any](4)
      val row = new GenericRow(arr)
      iter.map { _ =>
        arr(0) = rand.nextString(8)
        arr(1) = rand.nextInt()
        arr(2) = rand.nextLong()
        arr(3) = rand.nextDouble()
        row
      }
    }
    //rdd.count()

    val schema = StructType(Seq('key.string, 'c1.int, 'c2.long, 'c3.double))
    val df = sqlContext.createDataFrame(rdd, schema)
    df.registerTempTable("data")
    df
  }

  /**
   * Entry point: configures a local SparkContext, prepares the data, warms up, then measures.
   *
   * Prints "Average time: <ms>" where the value is the elapsed wall-clock milliseconds of the
   * measured loop divided by `numRepetitions`. Warm-up runs and profiler start/stop calls are
   * deliberately kept outside the timed window so they do not inflate the average.
   *
   * @param args ignored
   */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("test")
      .set("spark.sql.useSerializer2", "true")
      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
      .set("spark.shuffle.manager", "tungsten-sort")
    // NOTE(review): no setting visible here enables event logging; the directory is created
    // unconditionally, as in the original.
    new File("eventlogs").mkdirs()

    sc = new SparkContext(conf)
    try {
      sqlContext = new SQLContext(sc)
      setup()

      def runQuery(): Unit = {
        sqlContext.sql("SELECT key, sum(c1), sum(c2), sum(c3) from data GROUP BY key").count()
      }

      val controller = new com.yourkit.api.Controller

      // Warm-up: executed before the clock starts.
      (1 to numWarmups).foreach { _ => runQuery() }

      controller.startCPUSampling(null)
      //controller.startAllocationRecording(true, 10, false, 0, true, false)

      // Time only the measured repetitions.
      val startTime = System.currentTimeMillis()
      (1 to numRepetitions).foreach { _ => runQuery() }
      val endTime = System.currentTimeMillis()

      controller.stopCPUProfiling()
      //controller.stopAllocationRecording()

      println("Average time: " + ((endTime - startTime) / numRepetitions.toDouble))
    } finally {
      // Always release the context, even when profiler attachment or a query fails.
      sc.stop()
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment