Skip to content

Instantly share code, notes, and snippets.

@WojciechMazur
Last active July 7, 2022 15:33
Show Gist options
  • Save WojciechMazur/3bfc9cd7cf7924f5565fb3a93a608cb1 to your computer and use it in GitHub Desktop.
Save WojciechMazur/3bfc9cd7cf7924f5565fb3a93a608cb1 to your computer and use it in GitHub Desktop.
Scala Benchmarks sbt plugin
import sbt._
import sbt.Keys._
import xsbti.compile.CompileAnalysis
import xsbti.VirtualFileRef
import java.nio.file.Path
import scala.concurrent._
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import java.nio.file.Files
import java.io._
/** sbt plugin that benchmarks compilation and test execution of JVM projects built with
  * Scala 2.13 or Scala 3.
  *
  *  - `benchmarkScalaAll` (global) selects candidate projects, orders them so that a project
  *    runs after the candidates it depends on, and runs `benchmarkScala` for each of them.
  *  - `benchmarkScala` (per project) times `clean` + `compile`, `Test / compile` and
  *    `Test / test` over warmup and measured iterations, samples JVM memory usage meanwhile,
  *    and writes CSV reports into [[benchResultsDirectory]].
  */
object ScalaBenchmarksPlugin extends AutoPlugin {
  override def trigger = allRequirements

  /** Version labels printed for 2.13.x and 3.x projects. They are only reported; the
    * project's own `scalaVersion` is what actually gets used (see `benchmarkScalaAll`).
    */
  val Scala2 = "2.13.8"
  val Scala3 = "3.1.3"
  val warmupIterations = 1 to 5
  val benchIterations = 1 to 10
  /** When true, `benchmarkScalaAll` deletes everything in [[benchResultsDirectory]] first. */
  val clearPreviousResults = false
  val benchmarkScalaAll = taskKey[Unit]("")
  val benchmarkScala = taskKey[Unit]("")
  /** Output directory for all CSV reports, resolved against the current working directory. */
  lazy val benchResultsDirectory =
    file("scala-benchmark-results").toPath().toAbsolutePath()

  override def globalSettings: Seq[Setting[_]] = Seq(
    (benchmarkScala / aggregate) := false,
    (benchmarkScalaAll / aggregate) := false,
    benchmarkScalaAll := {
      var cState = state.value
      val extracted = sbt.Project.extract(cState)

      if (clearPreviousResults && Files.exists(benchResultsDirectory)) {
        // Files.list holds an open directory handle until the stream is closed.
        val previousResults = Files.list(benchResultsDirectory)
        try previousResults.forEach(Files.delete(_))
        finally previousResults.close()
      }

      def allScalaVersions(ref: ProjectRef): Seq[String] =
        cState.getSetting(ref / crossScalaVersions).getOrElse(Nil) ++
          cState.getSetting(ref / scalaVersion)

      /** A candidate is a plain-JVM, non-root project that builds with Scala 2.13 or 3. */
      def isCandidate(ref: ProjectRef): Boolean = {
        val projectName = ref.project.toLowerCase
        def compilesWithExpectedVersion =
          allScalaVersions(ref)
            .flatMap(CrossVersion.partialVersion)
            .exists {
              case (2, 13) | (3, _) => true
              case _                => false
            }
        // Name heuristics exclude Scala.js / Native / aggregate-root projects; a non-empty
        // cross-version prefix also indicates a non-JVM platform.
        def isJVM =
          !projectName.contains("native") &&
            !projectName.contains("js") &&
            !projectName.contains("root") &&
            cState
              .getSetting(ref / crossVersion)
              .exists {
                case cross: Binary => cross.prefix.isEmpty
                case cross: Full   => cross.prefix.isEmpty
                case _             => false
              }
        isJVM && compilesWithExpectedVersion
      }

      val benchCandidates = extracted.structure.allProjectRefs.filter(isCandidate)
      val candidateSet = benchCandidates.toSet

      // Classpath dependencies of every candidate, restricted to other candidates.
      // Each ClasspathDep is unwrapped to its ProjectRef *before* the membership test;
      // testing the ClasspathDep itself against ProjectRefs never matches.
      val projectDeps: Map[ProjectRef, Set[ProjectRef]] =
        extracted.structure.allProjectPairs.collect {
          case (resolved, ref) if candidateSet(ref) =>
            ref -> resolved.dependencies.map(_.project).filter(candidateSet).toSet
        }.toMap

      /** Orders projects so that each one comes after all candidates it depends on. */
      @annotation.tailrec
      def orderProjects(
          acc: List[ProjectRef],
          done: Set[ProjectRef],
          remaining: Set[ProjectRef]
      ): List[ProjectRef] =
        if (remaining.isEmpty) acc
        else {
          val (next, nextRemaining) =
            remaining.partition(ref => projectDeps(ref).subsetOf(done))
          // No progress means an unresolvable dependency chain: append the rest as-is
          // instead of recursing forever.
          if (next.isEmpty) acc ++ nextRemaining.toList
          else
            orderProjects(
              acc = acc ++ next.toList,
              done = done ++ next,
              remaining = nextRemaining
            )
        }

      orderProjects(Nil, Set.empty, projectDeps.keySet).foreach { ref =>
        def isIgnoredInScala2 =
          ref.project.contains("Dotty") || ref.project.endsWith("_3")
        allScalaVersions(ref)
          .flatMap(CrossVersion.partialVersion)
          .collect {
            case (2, 13) if !isIgnoredInScala2 => Scala2
            case (3, _)                        => Scala3
          }
          .distinct // several 3.x minors all map to the same label
          .foreach { version =>
            println(
              s"Run Scala benchmarks in ${ref.project} with Scala ${version}"
            )
            // Fails in some projects, use ++2.13.8;benchmarkScalaAll;++3.1.3;benchmarkScalaAll
            // import sbt.internal.util.NoPosition
            // val setVersion = Seq(
            //   (ref / scalaVersion).transform(_ => version, NoPosition)
            // )
            // cState = extracted.appendWithSession(setVersion, cState)
            // state := cState
            sbt.Project.runTask(ref / benchmarkScala, cState) match {
              case Some((newState, result)) =>
                cState = newState
                // Keep going with the remaining projects, but do not hide the failure.
                result.toEither match {
                  case Left(error) =>
                    println(s"Scala benchmarks failed in ${ref.project} - $error")
                  case Right(_) => ()
                }
              case None =>
                println(s"benchmarkScala is not defined in ${ref.project}, skipping")
            }
          }
      }
    }
  )

  override def projectSettings: Seq[Setting[_]] = Seq(
    benchmarkScala := {
      var currentState = state.value
      val scalaVersion = Keys.scalaVersion.value
      val thisProject = thisProjectRef.value
      val name = moduleName.value
      val sourceInfo = getSourceInfo((Compile / sources).value)
      val testSourceInfo = getSourceInfo((Test / sources).value)

      /** Runs `task` scoped to this project and returns its elapsed time.
        * On success the resulting state is carried over to later evaluations.
        * Throws RuntimeException if the task is undefined or fails.
        */
      def evalTask(task: TaskKey[_]): ExecutionTime = {
        val scopedTask = thisProject / task
        // nanoTime is monotonic; currentTimeMillis can jump with wall-clock adjustments.
        val evalStart = System.nanoTime()
        val result = sbt.Project.runTask(scopedTask, currentState)
        val took = ExecutionTime((System.nanoTime() - evalStart) / 1000000L)
        result match {
          case None =>
            throw new RuntimeException(s"Failed to run ${scopedTask} - task is not defined")
          case Some((newState, evalResult)) =>
            evalResult.toEither match {
              case Left(error) =>
                throw new RuntimeException(s"Failed to run ${scopedTask} - $error")
              case Right(_) =>
                currentState = newState
                took
            }
        }
      }

      /** Runs `prepare` (unmeasured), then `measured` while sampling memory. Returns an
        * empty result when there are no sources for this phase.
        */
      def measure(
          hasSources: Boolean,
          prepare: TaskKey[_],
          measured: TaskKey[_]
      ): ResultWithMemoryMetrics[ExecutionTime] =
        if (!hasSources)
          ResultWithMemoryMetrics(ExecutionTime(0), MemoryUsageStats.empty)
        else {
          evalTask(prepare)
          withMemoryMetrics(evalTask(measured))
        }

      /** Runs `onEachRun` for every warmup iteration, then for every measured iteration. */
      def run[T](
          label: String
      )(onEachRun: () => T)(postProcess: Seq[T] => RunStats): RunResults = {
        def iterate(phase: String, iterations: Range): Seq[T] =
          iterations.map { i =>
            println(s"$label :: $phase $i/${iterations.last}")
            onEachRun()
          }
        val warmupResults = iterate("warmup", warmupIterations)
        val results = iterate("measure", benchIterations)
        RunResults(postProcess(warmupResults), postProcess(results))
      }

      def postProcess(results: Seq[ResultWithMemoryMetrics[ExecutionTime]]): RunStats =
        RunStats(
          executionTime = ExecutionTimes(results.map(_.result)),
          memoryUsage = results.map(_.memoryMetrics)
        )

      val resultsFile =
        benchResultsDirectory.resolve(s"execution-times-$name-$scalaVersion.csv")
      if (sourceInfo.files == 0 && testSourceInfo.files == 0) {
        println(s"No sources in $name, skipping")
      } else if (Files.exists(resultsFile)) {
        println(s"Already executed benchmarks for $name")
      } else {
        println(
          s"Files in $name: sources=${sourceInfo.files}, test sources=${testSourceInfo.files}"
        )
        val compileResults = run(s"$name/Compile/compile") { () =>
          measure(sourceInfo.files > 0, prepare = Compile / clean, measured = Compile / compile)
        }(postProcess)
        val compileTestsResults = run(s"$name/Test/compile") { () =>
          measure(testSourceInfo.files > 0, prepare = Test / clean, measured = Test / compile)
        }(postProcess)
        val executeTestsResults = run(s"$name/Test/test") { () =>
          measure(testSourceInfo.files > 0, prepare = Test / compile, measured = Test / test)
        }(postProcess)
        val results = ProjectResults(
          projectName = name,
          scalaVersion = scalaVersion,
          compile = compileResults,
          compileTests = compileTestsResults,
          executeTests = executeTestsResults,
          sourceInfo = sourceInfo,
          testSourceInfo = testSourceInfo
        )
        saveResults(results, benchResultsDirectory)
        println(s"Results saved to $benchResultsDirectory")
      }
    }
  )

  /** Number of files in `sources` and their total line count. */
  private def getSourceInfo(sources: Seq[File]): SourcesInfo =
    SourcesInfo(files = sources.size, linesOfCode = sources.map(countLines).sum)

  /** Line count of `file`; the underlying handle is always closed. */
  private def countLines(file: File): Int = {
    val source = scala.io.Source.fromFile(file)
    try source.getLines().size
    finally source.close()
  }

  case class ProjectResults(
      projectName: String,
      scalaVersion: String,
      compile: RunResults,
      compileTests: RunResults,
      executeTests: RunResults,
      sourceInfo: SourcesInfo,
      testSourceInfo: SourcesInfo
  ) {
    override def toString(): String =
      s"""ProjectResults{
         |project: ${projectName}
         |sources:
         | files: ${sourceInfo.files}
         | LOC: ${sourceInfo.linesOfCode}
         |compile:
         | warmup: ${compile.warmupResults}
         | bench: ${compile.results}
         |testSources:
         | files: ${testSourceInfo.files}
         | LOC: ${testSourceInfo.linesOfCode}
         |compileTests:
         | warmup: ${compileTests.warmupResults}
         | bench: ${compileTests.results}
         |executeTests:
         | warmup: ${executeTests.warmupResults}
         | bench: ${executeTests.results}
         |}""".stripMargin
  }

  case class SourcesInfo(files: Int, linesOfCode: Int)

  case class RunResults(
      warmupResults: RunStats,
      results: RunStats
  )

  case class RunStats(
      executionTime: ExecutionTimes,
      memoryUsage: Seq[MemoryUsageStats]
  ) {
    /** Samples of all iterations concatenated, separately for heap and non-heap memory. */
    lazy val overallMemoryUsage = MemoryUsageStats(
      MemoryMeasurments(memoryUsage.flatMap(_.heapMemoryUsage.metrics)),
      MemoryMeasurments(memoryUsage.flatMap(_.offHeapMemoryUsage.metrics))
    )
  }

  case class ExecutionTime(timeMs: Long) extends AnyVal

  case class ExecutionTimes(metrics: Seq[ExecutionTime]) extends AnyVal {
    def ops = NumericOps(metrics)(_.timeMs)
  }

  /** Summary statistics over `metrics`, projected through `selector`.
    * Every statistic returns zero (0.0 for `average`) when `metrics` is empty.
    */
  case class NumericOps[Wrapper, Value: Ordering: Numeric](
      metrics: Seq[Wrapper]
  )(selector: Wrapper => Value) {
    def numeric = implicitly[Numeric[Value]]
    lazy val sorted = metrics.sortBy(selector).toVector
    lazy val min: Value =
      if (metrics.isEmpty) numeric.zero
      else selector(metrics.minBy(selector))
    lazy val max: Value =
      if (metrics.isEmpty) numeric.zero
      else selector(metrics.maxBy(selector))
    lazy val average: Double =
      if (metrics.isEmpty) 0.0
      else
        numeric.toDouble(metrics.foldLeft(numeric.zero) { case (acc, v) =>
          numeric.plus(acc, selector(v))
        }) / metrics.size

    /** Value at index floor(p% * size) of the sorted metrics, clamped to the last element
      * so that `percentile(100)` is the maximum instead of an out-of-bounds access.
      *
      * @param p percentile in [0, 100]
      * @throws IllegalArgumentException if `p` is outside [0, 100]
      */
    def percentile(p: Int): Value = {
      require(p >= 0 && p <= 100, s"percentile must be within [0, 100], got $p")
      if (metrics.isEmpty) numeric.zero
      else {
        val n = math.min(math.floor(p / 100.0 * metrics.size).toInt, metrics.size - 1)
        selector(sorted(n))
      }
    }
  }

  case class MemoryUsageStats(
      heapMemoryUsage: MemoryMeasurments,
      offHeapMemoryUsage: MemoryMeasurments
  )

  object MemoryUsageStats {
    val empty = MemoryUsageStats(MemoryMeasurments(Nil), MemoryMeasurments(Nil))
  }

  case class MemoryMeasurments(metrics: Seq[MemoryMeasurment]) extends AnyVal {
    def ops = NumericOps(metrics)(_.memoryUsedMb)
    def +(other: MemoryMeasurments) = MemoryMeasurments(
      metrics ++ other.metrics
    )
  }

  case class MemoryMeasurment(memoryUsedMb: Long) extends AnyVal

  /** Background thread sampling this JVM's used heap / non-heap memory (in MB) every
    * `Interval` while a window opened by `startCollecting()` is active.
    * All sample buffers are guarded by this object's monitor.
    */
  class MemoryMetricsCollector extends Thread {
    import java.lang.management.ManagementFactory
    import collection.mutable.ListBuffer

    val Interval = 100.millis
    private val memoryMX = ManagementFactory.getMemoryMXBean()
    // Appended by the sampler thread, cleared/read by the caller: only touch under `synchronized`.
    private val heapMemoryUsage = ListBuffer.empty[MemoryMeasurment]
    private val offHeapMemoryUsage = ListBuffer.empty[MemoryMeasurment]
    @volatile private var isCollecting = false
    @volatile private var ready = false

    /** Blocks until the sampler thread runs, then opens a fresh measurement window. */
    def startCollecting() = synchronized {
      while (isCollecting || !ready) wait(Interval.toMillis / 4)
      heapMemoryUsage.clear()
      offHeapMemoryUsage.clear()
      isCollecting = true
      notifyAll() // wake the sampler if it is idling
    }

    /** Closes the window and returns an immutable snapshot of the collected samples. */
    def stopCollecting(): MemoryUsageStats = synchronized {
      // toList copies; on Scala 2.12 ListBuffer.toSeq would alias the mutable buffer.
      val metrics = MemoryUsageStats(
        heapMemoryUsage = MemoryMeasurments(heapMemoryUsage.toList),
        offHeapMemoryUsage = MemoryMeasurments(offHeapMemoryUsage.toList)
      )
      isCollecting = false
      notifyAll()
      metrics
    }

    private def toMeasurment(usageBytes: Long) =
      MemoryMeasurment(usageBytes / 1024 / 1024)

    /** Takes one sample if a window is open; returns whether it did. */
    private def sampleIfCollecting(): Boolean = synchronized {
      if (isCollecting) {
        heapMemoryUsage += toMeasurment(memoryMX.getHeapMemoryUsage().getUsed())
        offHeapMemoryUsage += toMeasurment(memoryMX.getNonHeapMemoryUsage().getUsed())
      }
      isCollecting
    }

    /** Waits (instead of busy-spinning) until a window opens or a short timeout elapses. */
    private def awaitCollecting(): Unit = synchronized {
      if (!isCollecting) wait(Interval.toMillis / 4)
    }

    override def run(): Unit = {
      synchronized {
        ready = true
        notifyAll()
      }
      try {
        while (!Thread.interrupted()) {
          if (sampleIfCollecting()) Thread.sleep(Interval.toMillis)
          else awaitCollecting()
        }
      } catch { case _: InterruptedException => () }
    }
  }

  case class ResultWithMemoryMetrics[T](
      result: T,
      memoryMetrics: MemoryUsageStats
  )

  /** Evaluates `fn` while a [[MemoryMetricsCollector]] samples memory usage; the collector
    * thread is interrupted afterwards whether or not `fn` throws.
    */
  def withMemoryMetrics[T](fn: => T): ResultWithMemoryMetrics[T] = {
    val collector = new MemoryMetricsCollector()
    collector.start()
    try {
      collector.startCollecting()
      val result = fn
      val metrics = collector.stopCollecting()
      ResultWithMemoryMetrics(result, metrics)
    } finally collector.interrupt()
  }

  /** Writes the per-project CSV reports for `results` into `targetDirectory`
    * (created if missing): the compile summary, the test-execution summary (only when
    * the project has test sources) and the per-iteration execution times.
    */
  def saveResults(results: ProjectResults, targetDirectory: Path): Unit = {
    if (!Files.exists(targetDirectory)) Files.createDirectories(targetDirectory)

    /** Writes `header` + `values` as `;`-separated lines. When appending to a non-empty
      * file, the header is assumed to be present already and is not repeated.
      */
    def writeOrAppendCSV(
        fileName: String,
        allowAppend: Boolean
    )(header: Product)(values: Seq[Product]): Unit = {
      require(values.forall(_.productArity == header.productArity))
      val separator = ";"
      val file = targetDirectory.resolve(fileName).toFile()
      val append = allowAppend && file.exists() && file.length() > 0
      // Explicit UTF-8: DataOutputStream.writeBytes would drop the high byte of each char.
      val out = new BufferedWriter(
        new OutputStreamWriter(
          new FileOutputStream(file, append),
          java.nio.charset.StandardCharsets.UTF_8
        )
      )
      def writeLine(row: Product): Unit =
        out.write(row.productIterator.mkString("", separator, "\n"))
      try {
        if (!append) writeLine(header)
        values.foreach(writeLine)
      } finally out.close()
    }
    def writeCSV(
        fileName: String
    )(header: Product)(values: Seq[Product]) =
      writeOrAppendCSV(fileName, allowAppend = false)(header)(values)
    def appendCSV(
        fileName: String
    )(header: Product)(values: Seq[Product]) =
      writeOrAppendCSV(fileName, allowAppend = true)(header)(values)

    appendCSV(s"summary-${results.scalaVersion}.csv") {
      (
        "project",
        "source_files",
        "source_lines",
        "min_compilation_time",
        "max_compilation_time",
        "avg_compilation_time",
        "min_memory_usage",
        "max_memory_usage",
        "p90_memory_usage",
        "test_source_files",
        "test_source_lines",
        "tests_min_compilation_time",
        "tests_max_compilation_time",
        "tests_avg_compilation_time",
        "tests_min_memory_usage",
        "tests_max_memory_usage",
        "tests_p90_memory_usage"
      )
    } {
      val compile = results.compile.results
      val compileHeap = compile.overallMemoryUsage.heapMemoryUsage
      val testCompile = results.compileTests.results
      val testsHeap = testCompile.overallMemoryUsage.heapMemoryUsage
      Seq(
        (
          results.projectName,
          results.sourceInfo.files,
          results.sourceInfo.linesOfCode,
          compile.executionTime.ops.min,
          compile.executionTime.ops.max,
          compile.executionTime.ops.average,
          compileHeap.ops.min,
          compileHeap.ops.max,
          compileHeap.ops.percentile(90),
          results.testSourceInfo.files,
          results.testSourceInfo.linesOfCode,
          testCompile.executionTime.ops.min,
          testCompile.executionTime.ops.max,
          testCompile.executionTime.ops.average,
          testsHeap.ops.min,
          testsHeap.ops.max,
          testsHeap.ops.percentile(90)
        )
      )
    }

    if (results.testSourceInfo.files > 0) {
      // Separate file: these rows have a different column layout than the compile summary,
      // and appending them there would mix 9-column rows into a 17-column CSV.
      appendCSV(s"summary-test-execution-${results.scalaVersion}.csv") {
        (
          "project",
          "test_source_files",
          "test_source_lines",
          "tests_min_execute_time",
          "tests_max_execute_time",
          "tests_avg_execute_time",
          "tests_min_execute_memory_usage",
          "tests_max_execute_memory_usage",
          "tests_p90_execute_memory_usage"
        )
      } {
        val test = results.executeTests.results
        val testsHeap = test.overallMemoryUsage.heapMemoryUsage
        Seq(
          (
            results.projectName,
            results.testSourceInfo.files,
            results.testSourceInfo.linesOfCode,
            test.executionTime.ops.min,
            test.executionTime.ops.max,
            test.executionTime.ops.average,
            testsHeap.ops.min,
            testsHeap.ops.max,
            testsHeap.ops.percentile(90)
          )
        )
      }
    }

    writeCSV(
      s"execution-times-${results.projectName}-${results.scalaVersion}.csv"
    ) {
      (
        "iteration",
        "compilation_time_ms",
        "test_compilation_time_ms",
        "test_execution_time_ms"
      )
    } {
      /** One row per iteration of the selected phase (warmup or measured). */
      def extract(selector: RunResults => RunStats, isWarmup: Boolean) = {
        val zipped = selector(results.compile).executionTime.metrics
          .zip(selector(results.compileTests).executionTime.metrics)
          .zip(selector(results.executeTests).executionTime.metrics)
        zipped.zipWithIndex.map { case (((compile, testsCompile), testsExecute), i) =>
          val idx = i + 1
          (
            if (isWarmup) s"warmup-$idx" else idx.toString,
            compile.timeMs,
            testsCompile.timeMs,
            testsExecute.timeMs
          )
        }
      }
      extract(_.warmupResults, isWarmup = true) ++
        extract(_.results, isWarmup = false)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment