Last active
July 7, 2022 15:33
-
-
Save WojciechMazur/3bfc9cd7cf7924f5565fb3a93a608cb1 to your computer and use it in GitHub Desktop.
Scala Benchmarks sbt plugin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sbt._ | |
import sbt.Keys._ | |
import xsbti.compile.CompileAnalysis | |
import xsbti.VirtualFileRef | |
import java.nio.file.Path | |
import scala.concurrent._ | |
import scala.concurrent.duration._ | |
import scala.concurrent.ExecutionContext.Implicits.global | |
import java.nio.file.Files | |
import java.io._ | |
object ScalaBenchmarksPlugin extends AutoPlugin { | |
/** Activation policy of this plugin: `allRequirements`. */
override def trigger: PluginTrigger = allRequirements

// Scala versions printed for projects on the 2.13 and 3.x lines respectively.
val Scala2: String = "2.13.8"
val Scala3: String = "3.1.3"

// Iteration indices for the warm-up phase and for the measured phase of every benchmark.
val warmupIterations: Range.Inclusive = 1 to 5
val benchIterations: Range.Inclusive = 1 to 10

// When true, `benchmarkScalaAll` first deletes every entry of `benchResultsDirectory`.
val clearPreviousResults: Boolean = false

// Task keys: `benchmarkScalaAll` drives all projects, `benchmarkScala` benchmarks a single one.
val benchmarkScalaAll: TaskKey[Unit] = taskKey[Unit]("")
val benchmarkScala: TaskKey[Unit] = taskKey[Unit]("")

/** Absolute path of the `scala-benchmark-results` directory where CSV reports are written. */
lazy val benchResultsDirectory: Path =
  file("scala-benchmark-results").toPath.toAbsolutePath
/** Global settings of the plugin.
  *
  * Disables aggregation for both benchmark tasks and defines `benchmarkScalaAll`, which:
  *  1. optionally wipes `benchResultsDirectory` (see `clearPreviousResults`),
  *  2. selects the candidate projects (JVM-looking projects that list Scala 2.13 or 3.x),
  *  3. orders them so that a project comes after the candidate projects it depends on,
  *  4. runs `benchmarkScala` in each of them, threading the sbt `State` through the runs.
  */
override def globalSettings: Seq[Setting[_]] = Seq(
  (benchmarkScala / aggregate) := false,
  (benchmarkScalaAll / aggregate) := false,
  benchmarkScalaAll := {
    var cState = state.value
    val extracted = sbt.Project.extract(cState)

    if (clearPreviousResults && Files.exists(benchResultsDirectory)) {
      // Files.list keeps a directory handle open until the stream is closed.
      val entries = Files.list(benchResultsDirectory)
      try entries.forEach(Files.delete(_))
      finally entries.close()
    }

    // Every Scala version a project declares: crossScalaVersions followed by scalaVersion.
    def allScalaVersions(ref: ProjectRef): Seq[String] = {
      cState.getSetting(ref / crossScalaVersions).getOrElse(Nil) ++
        cState.getSetting(ref / scalaVersion)
    }

    // Set expected Scala versions and register candidate projects
    val benchCandidates = extracted.structure.allProjectRefs
      .filter { ref =>
        def compilesWithExpectedVersion = {
          allScalaVersions(ref)
            .flatMap(CrossVersion.partialVersion)
            .collectFirst {
              case (2, 13) => true
              case (3, _)  => true
            }
            .isDefined
        }
        // Heuristic: the project id must not mention native/js/root and its crossVersion
        // must be Binary or Full with an empty prefix.
        // NOTE(review): presumably a non-empty prefix marks Scala.js / Scala Native builds - confirm.
        def isJVM = {
          !ref.project.toLowerCase.contains("native") &&
          !ref.project.toLowerCase.contains("js") &&
          !ref.project.toLowerCase.contains("root") &&
          cState
            .getSetting(ref / crossVersion)
            .exists {
              case cross: Binary => cross.prefix.isEmpty
              case cross: Full   => cross.prefix.isEmpty
              case _             => false
            }
        }
        isJVM && compilesWithExpectedVersion
      }

    // Candidate project -> the candidate projects it depends on.
    val projectDeps = extracted.structure.allProjectPairs.collect {
      case (rp, ref) if benchCandidates.contains(ref) =>
        ref -> rp.dependencies
          .filter(benchCandidates.contains(_))
          .map(_.project)
          .toSet
    }.toMap

    // Repeatedly emits the projects whose dependencies have all been emitted already.
    @annotation.tailrec
    def orderProjects(
        acc: List[ProjectRef],
        done: Set[ProjectRef],
        remaining: Set[ProjectRef]
    ): List[ProjectRef] =
      if (remaining.isEmpty) acc
      else {
        val (next, nextRemaining) =
          remaining.partition { ref => projectDeps(ref).diff(done).isEmpty }
        if (next.isEmpty)
          // Nothing can be scheduled (dependency cycle): emit the rest as-is
          // instead of recursing forever on an unchanged `remaining` set.
          acc ++ nextRemaining.toList
        else
          orderProjects(
            acc = acc ++ next.toList,
            done = done ++ next,
            remaining = nextRemaining
          )
      }

    orderProjects(Nil, Set.empty, projectDeps.map(_._1).toSet)
      .foreach { ref =>
        def isIgnoredInScala2 =
          ref.project.contains("Dotty") || ref.project.endsWith("_3")
        allScalaVersions(ref)
          .flatMap(CrossVersion.partialVersion)
          .distinct
          .collect {
            case (2, 13) if !isIgnoredInScala2 => Scala2
            case (3, _)                        => Scala3
          }
          .foreach { version =>
            println(
              s"Run Scala benchmarks in ${ref.project} with Scala ${version}"
            )
            // `version` is only reported here; the Scala version is NOT switched by this task.
            // Switching it programmatically fails in some projects, so instead run
            //   ++2.13.8;benchmarkScalaAll;++3.1.3;benchmarkScalaAll
            // The disabled approach (requires `import sbt.internal.util.NoPosition`):
            // val setVersion = Seq(
            //   (ref / scalaVersion).transform(_ => version, NoPosition)
            // )
            // cState = extracted.appendWithSession(setVersion, cState)
            // state := cState
            sbt.Project.runTask(ref / benchmarkScala, cState) match {
              // The task result is intentionally ignored: a failing project does not stop the others.
              case Some((newState, _)) => cState = newState
              case None =>
                println(
                  s"Task benchmarkScala is not defined in ${ref.project}, skipping"
                )
            }
          }
      }
  }
)
/** Per-project settings: defines `benchmarkScala`.
  *
  * The task counts the main and test sources of the project and then, unless there are no
  * sources or an `execution-times-<name>-<scalaVersion>.csv` report already exists,
  * benchmarks `Compile / compile`, `Test / compile` and `Test / test`. Each benchmark runs
  * `warmupIterations` followed by `benchIterations`, recording wall-clock time and sampled
  * memory usage; the results are written with `saveResults` to `benchResultsDirectory`.
  */
override def projectSettings: Seq[Setting[_]] = Seq(
  benchmarkScala := {
    var currentState = state.value
    val scalaVersion = Keys.scalaVersion.value
    val thisProject = thisProjectRef.value
    val name = moduleName.value

    // Runs `task` in this project and returns how long it took.
    // Throws RuntimeException when the task is undefined or fails; on success the
    // resulting State replaces `currentState` so later runs build on it.
    def evalTask(task: TaskKey[_]): ExecutionTime = {
      val scopedTask = thisProject / task
      // nanoTime is monotonic, unlike currentTimeMillis which follows wall-clock adjustments.
      val evalStart = System.nanoTime()
      val result = sbt.Project.runTask(scopedTask, currentState)
      val took = ExecutionTime((System.nanoTime() - evalStart) / 1000000L)
      result match {
        case None =>
          throw new RuntimeException(
            s"Failed to run ${scopedTask} - task is not defined"
          )
        case Some((newState, evalResult)) =>
          evalResult.toEither match {
            case Left(error) =>
              throw new RuntimeException(s"Failed to run ${scopedTask} - $error")
            case Right(_) =>
              currentState = newState
              took
          }
      }
    }

    // Executes `onEachRun` for every warm-up and then every measured iteration,
    // summarising each phase with `postProcess`.
    def run[T](
        label: String
    )(onEachRun: () => T)(postProcess: Seq[T] => RunStats): RunResults = {
      def iterate(phase: String, iterations: Range): Seq[T] =
        for (i <- iterations)
          yield {
            println(s"$label :: $phase $i/${iterations.last}")
            onEachRun()
          }
      val warmupResults = iterate("warmup", warmupIterations)
      val results = iterate("measure", benchIterations)
      RunResults(
        postProcess(warmupResults),
        postProcess(results)
      )
    }

    // Counts the lines of one file, always releasing the underlying file handle.
    def countLines(sourceFile: File): Int = {
      val source = scala.io.Source.fromFile(sourceFile)
      try source.getLines().size
      finally source.close()
    }

    def getSourceInfo(sources: Seq[File]) =
      SourcesInfo(
        files = sources.size,
        linesOfCode = sources.map(countLines).sum
      )

    val sourceInfo = getSourceInfo((Compile / sources).value)
    val testSourceInfo = getSourceInfo((Test / sources).value)

    // Splits per-iteration results into the execution times and the memory samples.
    def postProcess(results: Seq[ResultWithMemoryMetrics[ExecutionTime]]) =
      RunStats(
        executionTime = ExecutionTimes(results.map(_.result)),
        memoryUsage = results.map(_.memoryMetrics)
      )

    // Benchmarks `measured`, running `setup` (untimed, no memory sampling) before each
    // iteration. When `sources` is empty every iteration records a zero time and no samples.
    def benchmark(label: String, sources: SourcesInfo)(
        setup: TaskKey[_],
        measured: TaskKey[_]
    ): RunResults =
      run(label) { () =>
        if (sources.files == 0)
          ResultWithMemoryMetrics(ExecutionTime(0), MemoryUsageStats.empty)
        else {
          evalTask(setup)
          withMemoryMetrics {
            evalTask(measured)
          }
        }
      }(postProcess)

    if (sourceInfo.files == 0 && testSourceInfo.files == 0) {
      println(s"No sources in $name, skipping")
    } else if (
      Files.exists(
        benchResultsDirectory
          .resolve(s"execution-times-$name-$scalaVersion.csv")
      )
    ) {
      println(s"Already executed benchmarks for $name")
    } else {
      println(
        s"Files in $name: sources=${sourceInfo.files}, test sources=${testSourceInfo.files}"
      )
      val compileResults =
        benchmark(s"$name/Compile/compile", sourceInfo)(
          Compile / clean,
          Compile / compile
        )
      val compileTestsResults =
        benchmark(s"$name/Test/compile", testSourceInfo)(
          Test / clean,
          Test / compile
        )
      // Tests are compiled up-front so that only their execution is measured.
      val executeTestsResults =
        benchmark(s"$name/Test/test", testSourceInfo)(
          Test / compile,
          Test / test
        )
      val results = ProjectResults(
        projectName = name,
        scalaVersion = scalaVersion,
        compile = compileResults,
        compileTests = compileTestsResults,
        executeTests = executeTestsResults,
        sourceInfo = sourceInfo,
        testSourceInfo = testSourceInfo
      )
      saveResults(results, benchResultsDirectory)
      println(s"Results saved to $benchResultsDirectory")
    }
  }
)
/** Everything measured for one project under one Scala version.
  *
  * @param projectName    name of the benchmarked project
  * @param scalaVersion   Scala version the benchmark ran with
  * @param compile        results of compiling the main sources
  * @param compileTests   results of compiling the test sources
  * @param executeTests   results of executing the tests
  * @param sourceInfo     size of the main sources
  * @param testSourceInfo size of the test sources
  */
case class ProjectResults(
    projectName: String,
    scalaVersion: String,
    compile: RunResults,
    compileTests: RunResults,
    executeTests: RunResults,
    sourceInfo: SourcesInfo,
    testSourceInfo: SourcesInfo
) {

  /** Multi-line, human-readable report of all sections. */
  override def toString(): String = {
    def sizeLines(info: SourcesInfo): Seq[String] = Seq(
      s" files: ${info.files}",
      s" LOC: ${info.linesOfCode}"
    )
    def runLines(run: RunResults): Seq[String] = Seq(
      s" warmup: ${run.warmupResults}",
      s" bench: ${run.results}"
    )
    val report =
      Seq("ProjectResults{", s"project: ${projectName}") ++
        ("sources:" +: sizeLines(sourceInfo)) ++
        ("compile:" +: runLines(compile)) ++
        ("testSources:" +: sizeLines(testSourceInfo)) ++
        ("compileTests:" +: runLines(compileTests)) ++
        ("executeTests:" +: runLines(executeTests)) ++
        Seq("}")
    report.mkString("\n")
  }
}
/** Size of a set of source files.
  *
  * @param files       number of source files
  * @param linesOfCode total number of lines across those files
  */
case class SourcesInfo(
    files: Int,
    linesOfCode: Int
)
/** Statistics of one benchmarked task, kept separately for the two phases.
  *
  * @param warmupResults statistics of the warm-up iterations
  * @param results       statistics of the measured iterations
  */
case class RunResults(warmupResults: RunStats, results: RunStats)
/** Statistics of a sequence of iterations.
  *
  * @param executionTime execution time of every iteration
  * @param memoryUsage   memory samples of every iteration, one entry per iteration
  */
case class RunStats(
    executionTime: ExecutionTimes,
    memoryUsage: Seq[MemoryUsageStats]
) {

  /** All per-iteration memory samples concatenated, heap and off-heap separately. */
  lazy val overallMemoryUsage = {
    def combined(pick: MemoryUsageStats => MemoryMeasurments) =
      memoryUsage.foldLeft(MemoryMeasurments(Nil)) { (acc, stats) =>
        acc + pick(stats)
      }
    MemoryUsageStats(
      combined(_.heapMemoryUsage),
      combined(_.offHeapMemoryUsage)
    )
  }
}
/** A duration in milliseconds.
  *
  * @param timeMs elapsed time, in milliseconds
  */
case class ExecutionTime(
    timeMs: Long
) extends AnyVal
/** Execution times of a sequence of iterations.
  *
  * @param metrics one entry per iteration
  */
case class ExecutionTimes(metrics: Seq[ExecutionTime]) extends AnyVal {

  /** Aggregations (min, max, average, percentile) over the times in milliseconds. */
  def ops = NumericOps(metrics) { time => time.timeMs }
}
/** Aggregations over a sequence of wrapped numeric values.
  *
  * All aggregations return zero for an empty `metrics`.
  *
  * @param metrics  the wrapped values
  * @param selector extracts the numeric value from a wrapper
  * @tparam Wrapper type of the wrapped elements
  * @tparam Value   numeric type the aggregations are computed on
  */
case class NumericOps[Wrapper, Value: Ordering: Numeric](
    metrics: Seq[Wrapper]
)(selector: Wrapper => Value) {
  def numeric = implicitly[Numeric[Value]]

  /** `metrics` in ascending order of the selected value. */
  lazy val sorted = metrics.sortBy(selector).toVector

  lazy val min: Value =
    if (metrics.isEmpty) numeric.zero
    else selector(metrics.minBy(selector))

  lazy val max: Value =
    if (metrics.isEmpty) numeric.zero
    else selector(metrics.maxBy(selector))

  lazy val average: Double =
    if (metrics.isEmpty) 0.0
    else
      numeric.toDouble(metrics.foldLeft(numeric.zero) { case (acc, v) =>
        numeric.plus(acc, selector(v))
      }) / metrics.size

  /** Value found at position `floor(p / 100 * size)` of the sorted metrics.
    *
    * @param p percentile in the range 0 to 100 (inclusive)
    * @return the selected value, or zero when `metrics` is empty
    * @throws IllegalArgumentException if `p` is out of range and `metrics` is non-empty
    */
  def percentile(p: Int): Value =
    if (metrics.isEmpty) numeric.zero
    else {
      require(p >= 0 && p <= 100)
      val n = Math.floor(p / 100.0 * metrics.size).toInt
      // For p == 100 the computed position equals `size`; clamp it to the last element.
      selector(sorted(math.min(n, sorted.size - 1)))
    }
}
/** Memory samples of one measurement, split by memory area.
  *
  * @param heapMemoryUsage    samples of used heap memory
  * @param offHeapMemoryUsage samples of used non-heap memory
  */
case class MemoryUsageStats(
    heapMemoryUsage: MemoryMeasurments,
    offHeapMemoryUsage: MemoryMeasurments
)

object MemoryUsageStats {

  /** Stats holding no samples at all. */
  val empty = MemoryUsageStats(
    heapMemoryUsage = MemoryMeasurments(Nil),
    offHeapMemoryUsage = MemoryMeasurments(Nil)
  )
}
/** A sequence of memory samples.
  *
  * @param metrics the individual samples
  */
case class MemoryMeasurments(metrics: Seq[MemoryMeasurment]) extends AnyVal {

  /** Aggregations (min, max, average, percentile) over the samples in megabytes. */
  def ops = NumericOps(metrics) { sample => sample.memoryUsedMb }

  /** Concatenation: the samples of this instance followed by those of `other`. */
  def +(other: MemoryMeasurments) = {
    val joined = metrics ++ other.metrics
    MemoryMeasurments(joined)
  }
}
/** A single memory sample.
  *
  * @param memoryUsedMb used memory, in megabytes
  */
case class MemoryMeasurment(
    memoryUsedMb: Long
) extends AnyVal
/** Background thread that samples used heap and non-heap memory every `Interval`.
  *
  * Usage: `start()` the thread, call `startCollecting()`, run the measured code, call
  * `stopCollecting()` to obtain the samples and finally `interrupt()` the thread.
  *
  * All access to the sample buffers happens while holding this object's monitor, and the
  * thread sleeps on that monitor while no collection is active instead of spinning, so an
  * idle collector does not consume CPU time of the benchmarked code.
  */
class MemoryMetricsCollector extends Thread {
  import java.lang.management.ManagementFactory
  import collection.mutable.ListBuffer

  /** Delay between two consecutive samples. */
  val Interval = 100.millis

  private lazy val memoryMX = ManagementFactory.getMemoryMXBean()
  private lazy val heapMemoryUsage = ListBuffer.empty[MemoryMeasurment]
  private lazy val offHeapMemoryUsage = ListBuffer.empty[MemoryMeasurment]
  @volatile private var isCollecting = false
  @volatile private var ready = false

  /** Starts a new collection, discarding samples of any previous one.
    *
    * Blocks until the thread is running and no other collection is active.
    */
  def startCollecting() = synchronized {
    while (isCollecting || !ready) wait(Interval.toMillis / 4)
    heapMemoryUsage.clear()
    offHeapMemoryUsage.clear()
    isCollecting = true
    // Wake up the sampling thread, which waits for a collection to begin.
    notifyAll()
  }

  /** Ends the active collection.
    *
    * @return the samples gathered since the matching `startCollecting()`
    */
  def stopCollecting(): MemoryUsageStats = synchronized {
    val metrics = MemoryUsageStats(
      heapMemoryUsage = MemoryMeasurments(heapMemoryUsage.toSeq),
      offHeapMemoryUsage = MemoryMeasurments(offHeapMemoryUsage.toSeq)
    )
    isCollecting = false
    notifyAll()
    metrics
  }

  /** Sampling loop; runs until the thread is interrupted. */
  override def run(): Unit = {
    def toMeasurment(usageBytes: Long) =
      MemoryMeasurment(usageBytes / 1024 / 1024)

    synchronized {
      ready = true
      notifyAll()
    }
    try {
      while (!Thread.interrupted()) {
        synchronized {
          // Idle without burning CPU until a collection is started.
          while (!isCollecting) wait()
          heapMemoryUsage += toMeasurment(
            memoryMX.getHeapMemoryUsage().getUsed()
          )
          offHeapMemoryUsage += toMeasurment(
            memoryMX.getNonHeapMemoryUsage().getUsed()
          )
        }
        Thread.sleep(Interval.toMillis)
      }
    } catch {
      // Interruption while waiting or sleeping is the regular way to stop the thread.
      case _: InterruptedException => ()
    }
  }
}
/** A computed value paired with the memory samples taken while it was computed.
  *
  * @param result        the computed value
  * @param memoryMetrics memory samples gathered during the computation
  * @tparam T type of the computed value
  */
case class ResultWithMemoryMetrics[T](result: T, memoryMetrics: MemoryUsageStats)
/** Evaluates `fn` while a fresh [[MemoryMetricsCollector]] samples memory usage.
  *
  * The collector thread is interrupted afterwards, whether or not `fn` throws.
  *
  * @param fn the computation to observe
  * @return the value of `fn` together with the memory samples gathered meanwhile
  */
def withMemoryMetrics[T](fn: => T): ResultWithMemoryMetrics[T] = {
  val sampler = new MemoryMetricsCollector()
  sampler.start()
  try {
    sampler.startCollecting()
    val outcome = fn
    val samples = sampler.stopCollecting()
    ResultWithMemoryMetrics(result = outcome, memoryMetrics = samples)
  } finally {
    sampler.interrupt()
  }
}
/** Writes the benchmark results of one project as semicolon-separated CSV files.
  *
  * Files written into `targetDirectory` (created when missing):
  *  - `summary-<scalaVersion>.csv`: one appended row per project with compilation statistics,
  *  - `tests-summary-<scalaVersion>.csv`: one appended row per project with test execution
  *    statistics, only for projects that have test sources,
  *  - `execution-times-<project>-<scalaVersion>.csv`: times of every iteration; overwritten.
  *
  * Text is encoded as UTF-8. Values are written unescaped.
  *
  * @param results         the results to persist
  * @param targetDirectory directory receiving the CSV files
  */
def saveResults(results: ProjectResults, targetDirectory: Path): Unit = {
  if (!Files.exists(targetDirectory)) Files.createDirectories(targetDirectory)

  // Writes `values` to `fileName`. The header is emitted only when the file is created or
  // overwritten; appending requires `allowAppend` and an existing, non-empty file.
  def writeOrAppendCSV(
      fileName: String,
      allowAppend: Boolean
  )(header: Product)(values: Seq[Product]): Unit = {
    require(values.forall(_.productArity == header.productArity))
    val separator = ";"
    val file = targetDirectory.resolve(fileName).toFile()
    val append = allowAppend && file.exists() && file.length() > 0
    // An explicit charset keeps non-ASCII text intact; DataOutputStream.writeBytes would
    // keep only the low byte of every character.
    val out = new BufferedWriter(
      new OutputStreamWriter(
        new FileOutputStream(file, append),
        java.nio.charset.StandardCharsets.UTF_8
      )
    )
    def writeLine(v: Product) =
      out.write(v.productIterator.mkString("", separator, "\n"))
    try {
      if (!append) writeLine(header)
      values.foreach(writeLine)
    } finally out.close()
  }

  def writeCSV(
      fileName: String
  )(header: Product)(values: Seq[Product]) =
    writeOrAppendCSV(fileName, allowAppend = false)(header)(values)

  def appendCSV(
      fileName: String
  )(header: Product)(values: Seq[Product]) =
    writeOrAppendCSV(fileName, allowAppend = true)(header)(values)

  // Compilation summary: main sources followed by test sources.
  appendCSV(s"summary-${results.scalaVersion}.csv") {
    (
      "project",
      "source_files",
      "source_lines",
      "min_compilation_time",
      "max_compilation_time",
      "avg_compilation_time",
      "min_memory_usage",
      "max_memory_usage",
      "p90_memory_usage",
      "test_source_files",
      "test_source_lines",
      "tests_min_compilation_time",
      "tests_max_compilation_time",
      "tests_avg_compilation_time",
      "tests_min_memory_usage",
      "tests_max_memory_usage",
      "tests_p90_memory_usage"
    )
  } {
    val compile = results.compile.results
    val compileHeap = compile.overallMemoryUsage.heapMemoryUsage
    val testCompile = results.compileTests.results
    val testsHeap = testCompile.overallMemoryUsage.heapMemoryUsage
    Seq(
      (
        results.projectName,
        results.sourceInfo.files,
        results.sourceInfo.linesOfCode,
        compile.executionTime.ops.min,
        compile.executionTime.ops.max,
        compile.executionTime.ops.average,
        compileHeap.ops.min,
        compileHeap.ops.max,
        compileHeap.ops.percentile(90),
        results.testSourceInfo.files,
        results.testSourceInfo.linesOfCode,
        testCompile.executionTime.ops.min,
        testCompile.executionTime.ops.max,
        testCompile.executionTime.ops.average,
        testsHeap.ops.min,
        testsHeap.ops.max,
        testsHeap.ops.percentile(90)
      )
    )
  }

  if (results.testSourceInfo.files > 0) {
    // Test execution summary. It has its own columns, so it goes to its own file:
    // appending it to the compilation summary would mix 9-column rows into a 17-column
    // CSV and this header would never be written.
    appendCSV(s"tests-summary-${results.scalaVersion}.csv") {
      (
        "project",
        "test_source_files",
        "test_source_lines",
        "tests_min_execute_time",
        "tests_max_execute_time",
        "tests_avg_execute_time",
        "tests_min_execute_memory_usage",
        "tests_max_execute_memory_usage",
        "tests_p90_execute_memory_usage"
      )
    } {
      val test = results.executeTests.results
      val testsHeap = test.overallMemoryUsage.heapMemoryUsage
      Seq(
        (
          results.projectName,
          results.testSourceInfo.files,
          results.testSourceInfo.linesOfCode,
          test.executionTime.ops.min,
          test.executionTime.ops.max,
          test.executionTime.ops.average,
          testsHeap.ops.min,
          testsHeap.ops.max,
          testsHeap.ops.percentile(90)
        )
      )
    }
  }

  // Per-iteration times: warm-up iterations first (labelled `warmup-N`), then measured ones.
  writeCSV(
    s"execution-times-${results.projectName}-${results.scalaVersion}.csv"
  ) {
    (
      "iteration",
      "compilation_time_ms",
      "test_compilation_time_ms",
      "test_execution_time_ms"
    )
  } {
    def extract(selector: RunResults => RunStats, isWarmup: Boolean) = {
      val zipped = selector(results.compile).executionTime.metrics
        .zip(selector(results.compileTests).executionTime.metrics)
        .zip(selector(results.executeTests).executionTime.metrics)
      zipped.zipWithIndex.map {
        case (((compile, testsCompile), testsExecute), i) =>
          val idx = i + 1
          (
            if (isWarmup) s"warmup-$idx" else idx.toString,
            compile.timeMs,
            testsCompile.timeMs,
            testsExecute.timeMs
          )
      }
    }
    extract(_.warmupResults, isWarmup = true) ++
      extract(_.results, isWarmup = false)
  }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment