Created
September 21, 2017 20:51
-
-
Save pchiusano/71dd8c7c35057f6f453ea1fc2974debf to your computer and use it in GitHub Desktop.
Code for Scala World 2017 talk on eliminating interpreter overhead via partial evaluation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package scalaworld.interpreters | |
/*
This file shows a simple language, an interpreter, and two
partial evaluators for that language, along with a profiling suite.
*/
/**
 * Syntax tree of a tiny register language. A program denotes a
 * `Vector[Double] => Vector[Double]`: it reads and writes numbered registers.
 *
 * Sealed so the compiler can check that matches over `Expr` are exhaustive.
 */
sealed trait Expr

object Expr {

  /** Set register `d` (for "destination") equal to `n`. */
  case class Num(d: Int, n: Double) extends Expr

  /** Set register `d` equal to register `i` + register `j`. */
  case class Plus(d: Int, i: Int, j: Int) extends Expr

  /** Decrement register `d`. */
  case class Decr(d: Int) extends Expr

  /** Set register `d` equal to register `i`. */
  case class Copy(d: Int, i: Int) extends Expr

  /** Run the instructions in `es` in sequence. An empty list is a no-op. */
  case class Block(es: List[Expr]) extends Expr

  /** Execute `p` repeatedly until the `haltIf0` register is 0. */
  case class Loop(haltIf0: Int, p: Expr) extends Expr

  /** Some syntax - variadic `Block.apply`. */
  object Block { def apply(es: Expr*): Expr = Block(es.toList) }

  /**
   * A simple interpreter for `Expr`. For efficiency, this mutates an `Array[Double]`
   * rather than transforming a `Vector[Double]`. Straightforward but inefficient:
   * the syntax tree is re-inspected every time an instruction is executed.
   *
   * @param e the program to run
   * @param m the register file, mutated in place; must be large enough for
   *          every register index that `e` mentions
   */
  def interpret(e: Expr, m: Array[Double]): Unit = e match {
    case Num(d, n)        => m(d) = n
    case Decr(d)          => m(d) = m(d) - 1.0
    case Plus(d, i, j)    => m(d) = m(i) + m(j)
    case Copy(d, i)       => m(d) = m(i)
    case Loop(haltIf0, p) => interpretLoop(haltIf0, p, m)
    case Block(es)        => interpretBlock(es, m)
  }

  /**
   * Interprets `p` until register `haltIf0` holds 0.
   *
   * Notice that we have interpreter overhead _on each execution of the loop body_.
   */
  def interpretLoop(haltIf0: Int, p: Expr, m: Array[Double]): Unit =
    while (m(haltIf0) != 0.0) interpret(p, m)

  /**
   * Interprets the instructions of a block in order. A nested `Block` at the
   * head is flattened into the remaining instructions so that the recursion
   * stays a tail call.
   */
  @annotation.tailrec
  def interpretBlock(es: List[Expr], m: Array[Double]): Unit = es match {
    case Nil                => ()
    case Block(es0) :: rest => interpretBlock(es0 ++ rest, m)
    case e :: rest          => interpret(e, m); interpretBlock(rest, m)
  }

  /**
   * Here's a simple partial evaluator. We curry the `interpret` function,
   * but do all inspection of the syntax tree _before_ returning our
   * compiled form, an `Array[Double] => Unit`.
   *
   * `partialEval` could be called `compile` - we are producing a compiled
   * form with no interpreter overhead, as in the Futamura projections.
   */
  def partialEval(e: Expr): Array[Double] => Unit = e match {
    case Num(d, n)     => m => m(d) = n
    case Decr(d)       => m => m(d) = m(d) - 1.0
    // case Plus(d, i, j) if d == j => m => m(d) += m(i)
    case Plus(d, i, j) => m => m(d) = m(i) + m(j)
    case Copy(d, i)    => m => m(d) = m(i)
    case Loop(haltIf0, p) =>
      val compiledBody = partialEval(p) // very important, we compile the body once!
      m => while (m(haltIf0) != 0.0) compiledBody(m) // ... and then execute it multiple times
    case Block(es)     => partialEvalBlock(es)
  }

  /**
   * Compiles a sequence of instructions into a single function that runs them
   * in order. An empty sequence compiles to a no-op, matching `interpretBlock`.
   */
  def partialEvalBlock(ps: List[Expr]): Array[Double] => Unit = ps match {
    case Nil      => _ => ()
    case e :: Nil => partialEval(e)
    case p :: ps2 =>
      val cp = partialEval(p)
      val cps = partialEvalBlock(ps2)
      m => { cp(m); cps(m) }
  }

  // Performance of this approach is highly dependent on choice of compiled form.
  // An `Array[Double] => Unit` may require computing array offsets and doing array
  // bounds checks. To improve performance, we can move to a function that just
  // takes a mutable record of `Double` values:

  /** A fixed register file with exactly four mutable registers, `r0` to `r3`. */
  case class Machine(var r0: Double, var r1: Double, var r2: Double, var r3: Double)

  // Our compiled form will be `Machine => Unit` for this second partial evaluator.
  // Not being able to just use array offsets requires a bit more code than before.
  object Machine {

    /**
     * Returns a reader for register `i`.
     * Only indices 0 to 3 exist; any other index throws a `MatchError`.
     */
    def get(i: Int): Machine => Double = i match {
      case 0 => _.r0
      case 1 => _.r1
      case 2 => _.r2
      case 3 => _.r3
    }

    // experimented with this, doesn't seem to make a difference
    //abstract class R { def apply(m: Machine): Double }
    //def get(i: Int): R = i match {
    //  case 0 => new R { def apply(m: Machine) = m.r0 }
    //  case 1 => new R { def apply(m: Machine) = m.r1 }
    //  case 2 => new R { def apply(m: Machine) = m.r2 }
    //  case 3 => new R { def apply(m: Machine) = m.r3 }
    //}
  }

  /**
   * Second partial evaluator: compiles `e` to a `Machine => Unit`.
   *
   * Every destination register is resolved to a direct field write at compile
   * time. Register indices outside 0 to 3 throw a `MatchError` while compiling.
   */
  def partialEval2(e: Expr): Machine => Unit = e match {
    case Num(d, n) => d match {
      case 0 => m => m.r0 = n
      case 1 => m => m.r1 = n
      case 2 => m => m.r2 = n
      case 3 => m => m.r3 = n
    }
    case Decr(d) => d match {
      case 0 => m => m.r0 -= 1.0
      case 1 => m => m.r1 -= 1.0
      case 2 => m => m.r2 -= 1.0
      case 3 => m => m.r3 -= 1.0
    }
    // can make a difference, suggests Machine.get(i) isn't reliably inlined by JIT
    case Plus(1, 0, 1) => m => m.r1 += m.r0
    case Plus(d, i, j) if d == j =>
      val ci = Machine.get(i)
      d match {
        case 0 => m => m.r0 += ci(m)
        case 1 => m => m.r1 += ci(m)
        case 2 => m => m.r2 += ci(m)
        case 3 => m => m.r3 += ci(m)
      }
    case Plus(d, i, j) =>
      val ci = Machine.get(i)
      val cj = Machine.get(j)
      d match {
        case 0 => m => m.r0 = ci(m) + cj(m)
        case 1 => m => m.r1 = ci(m) + cj(m)
        case 2 => m => m.r2 = ci(m) + cj(m)
        case 3 => m => m.r3 = ci(m) + cj(m)
      }
    case Copy(d, i) =>
      val ci = Machine.get(i)
      d match {
        case 0 => m => m.r0 = ci(m)
        case 1 => m => m.r1 = ci(m)
        case 2 => m => m.r2 = ci(m)
        case 3 => m => m.r3 = ci(m)
      }
    // also can make a difference, suggests Machine.get compiled form isn't reliably inlined by JIT
    case Loop(0, p) =>
      val compiledBody = partialEval2(p)
      m => while (m.r0 != 0.0) compiledBody(m)
    case Loop(haltIf0, p) =>
      val cHaltIf0 = Machine.get(haltIf0)
      val compiledBody = partialEval2(p)
      m => while (cHaltIf0(m) != 0.0) compiledBody(m)
    case Block(es) => partialEvalBlock2(es)
  }

  /**
   * `Machine` counterpart of `partialEvalBlock`: compiles a sequence of
   * instructions into one function. An empty sequence compiles to a no-op.
   */
  def partialEvalBlock2(ps: List[Expr]): Machine => Unit = ps match {
    case Nil      => _ => ()
    case e :: Nil => partialEval2(e)
    case p :: ps2 =>
      val cp = partialEval2(p)
      val cps = partialEvalBlock2(ps2)
      m => { cp(m); cps(m) }
  }
}
/**
 * Demo entry point. It builds two register programs (Fibonacci and "sum 0 to n"),
 * prints the first ten results of `sumN` for each execution strategy as a sanity
 * check, then benchmarks the strategies against each other.
 */
object Ex extends App {
  import Expr._
  import quickprofile.QuickProfile.{suite,profile}

  // Loop count used by every benchmark. `math.random` lies in [0, 1), so its
  // floor is always 0.0 and N is always exactly 1e6.
  // NOTE(review): presumably the random call only keeps N from being a
  // compile-time constant - confirm.
  def N = 1e6 + math.random.floor

  // Shared register file for the array-based sanity checks.
  val m = Array.fill(4)(0.0)

  // Fibonacci: expects `n` in register 0, puts result in register 1
  val fib = Block(      // var n = <fn param>
    Num(1, 0.0),        // var f1 = 0
    Num(2, 1.0),        // var f2 = 1
    Loop(0, Block(      // while (n != 0) {
      Plus(3, 1, 2),    //   val tmp = f1 + f2
      Copy(1, 2),       //   f1 = f2
      Copy(2, 3),       //   f2 = tmp
      Decr(0)           //   n -= 1
    ))                  // }
  )

  /** Hand-written Scala equivalent of the `fib` program. */
  @annotation.tailrec
  def fib(n: Double, f0: Double, f1: Double): Double =
    if (n != 0) fib(n - 1.0, f1, f0 + f1)
    else f0

  // Sums up the numbers 0 to `n`.
  // Expects `n` in register 0, puts result in register 1.
  val sumN = Block(
    Num(1, 0.0),
    Loop(0, Block(Plus(1, 0, 1), Decr(0)))
  )

  /** Hand-written Scala equivalent of the `sumN` program. */
  @annotation.tailrec
  def sumN(n: Double, acc: Double): Double =
    if (n != 0.0) sumN(n - 1.0, acc + n)
    else acc

  /** Prints `label`, then the results of `run` for inputs 0 to 9 on one line. */
  private def sanityCheck(label: String)(run: Double => Long): Unit = {
    println(label)
    println((0 until 10).map(i => run(i.toDouble)).mkString(" "))
  }

  // Sanity check - let's make sure all implementations produce the same results
  sanityCheck("interpreted") { n =>
    m(0) = n
    interpret(sumN, m)
    m(1).toLong
  }

  sanityCheck("partially-evaluated") {
    val compiled = partialEval(sumN) // compiled once, before anything is printed
    n => {
      m(0) = n
      compiled(m)
      m(1).toLong
    }
  }

  sanityCheck("partially-evaluated (2)") {
    val machine = Machine(0.0, 0.0, 0.0, 0.0)
    val compiled = partialEval2(sumN)
    n => {
      machine.r0 = n
      compiled(machine)
      machine.r1.toLong
    }
  }

  sanityCheck("native") { n => sumN(n, 0.0).toLong }

  // Okay, now run the profiling suite
  suite(
    {
      val compiled = partialEval2(sumN)
      val machine = Machine(0.0, 0.0, 0.0, 0.0)
      profile("partially-evaluated (2)", 0.03) {
        machine.r0 = N
        compiled(machine)
        machine.r1.toLong
      }
    },
    {
      val regs = Array.fill(4)(0.0)
      profile("interpreted", 0.03) {
        regs(0) = N
        interpret(sumN, regs)
        regs(1).toLong
      }
    },
    profile("Scala", 0.03) { sumN(N, 0.0).toLong },
    {
      val compiled = partialEval(sumN)
      val regs = Array.fill(4)(0.0)
      profile("partially-evaluated", 0.03) {
        regs(0) = N
        compiled(regs)
        regs(1).toLong
      }
    }
  )
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package quickprofile | |
/**
 * Simple-to-use, fast, and relatively accurate benchmarking functions.
 *
 * `profile` runs an individual benchmark and reports performance.
 * `suite` runs a collection of benchmarks and reports relative performance.
 *
 * Unlike JMH, we do not require picking an arbitrary number of warmup iterations
 * or recorded iterations (often chosen to be either too small, yielding
 * inaccurate or wildly varying results, or too big, leading benchmarks to take
 * forever and not be run as part of normal development).
 *
 * See `profile` docs for more details on the methodology.
 *
 * Example usage: {{{
 *   import QuickProfile.{suite, profile}
 *
 *   suite(
 *     profile("loop1") {
 *       var n = 1e6 + math.random
 *       while (n > 0.0) n -= 1
 *       n.toLong
 *     },
 *     profile("loop2") {
 *       var n = 1e6 + math.random
 *       (0 until 1000000).foreach { _ => n -= 1.0 }
 *       n.toLong
 *     },
 *     { // setup for the benchmark, won't be measured
 *       val nums = Vector.range(0, 1000000)
 *       // okay, start measuring
 *       profile("loop3") {
 *         var n = 1e6 + math.random
 *         nums.foreach { _ => n -= 1.0 }
 *         n.toLong
 *       }
 *     }
 *   )
 * }}}
 *
 * Which produces output like: {{{
 *   - loop1: 1.475 milliseconds (4.3% deviation, N=68, K = 0)
 *   - loop2: 1.069 milliseconds (4.3% deviation, N=224, K = 0)
 *   - loop3: 14.776 milliseconds (3.0% deviation, N=26, K = 0)
 *   1.0             loop2
 *   1.37            loop1
 *   13.81           loop3
 * }}}
 */
object QuickProfile {

  /**
   * Run an action repeatedly, capturing profiling info, until the % deviation of
   * timing info is less than `threshold` (closer to 0). Idea is to discover the
   * steady state of performance that occurs when most hot spots have been JIT'd
   * without needing to pick an arbitrary number of warmup iterations and trials
   * (which are often either too small, leading to inaccurate results, or too big,
   * leading to profiling taking way too long).
   *
   * This function increases N--the number of times `action` is invoked per
   * iteration--until each iteration takes at least 100ms (so limited granularity
   * of System.nanoTime is no longer an issue). Once reaching this point,
   * it then gradually increases N exponentially (N = N * 1.2) until percent
   * deviation drops below `threshold`. This all tends to happen pretty quickly.
   *
   * The `action` must return a `Long`, preferably unique for each execution, and the
   * sum of these numbers is threaded through the profiling computation to prevent the
   * JVM from doing any heroic optimizations that would eliminate executions of `action`.
   *
   * One caveat: JVM optimizations are not totally deterministic, so running the same
   * benchmark with a fresh JVM may reach a different steady state (though if performance
   * is highly sensitive to this, it could be a good idea to find a different way of
   * expressing your program such that performance is not as fragile). It's usually
   * obvious from a handful of runs of a benchmark whether any nondeterminism of JVM
   * optimizations is relevant for performance, but for maximum accuracy it can be a good
   * idea to average results from multiple JVM runs.
   *
   * @param label     name printed in progress and result lines
   * @param threshold maximum accepted ratio of standard deviation to mean
   *                  (0.05 means 5%)
   * @param action    the code under test, evaluated once per invocation
   * @return `label` paired with the mean time of one invocation, in nanoseconds
   */
  def profile(label: String, threshold: Double = 0.05)(action: => Long): (String, Double) = {
    val targetNanos = 1e8     // 1e8 nanos is 100ms
    val samplesPerRound = 10
    var N = 16L
    var sample = 1e9          // mean nanoseconds per invocation from the last round
    var K = 0L                // running sum of `action` results, keeps the calls alive
    var percentDeviation = Double.PositiveInfinity
    var converged = false
    while (!converged) {
      // try to increase N to get at least 100ms sample -
      if (sample * N < targetNanos) {
        // do linear interpolation to guess N that will hit 100ms exactly;
        // a zero sample (timer too coarse) would divide by zero, so just double N
        val N2 = if (sample > 0.0) (targetNanos / sample).toLong else N * 2
        if ((N.toDouble / N2.toDouble - 1.0).abs < .15)
          // we're close enough, stop interpolating and just grow N exponentially
          N = (N.toDouble * 1.2).toLong
        else
          // not close enough, so use the linear interpolation
          N = N2
      }
      // otherwise increase N gradually to decrease variance
      else N = (N.toDouble * 1.2).toLong
      print(s"\r * $label: ${formatNanos(sample)}, N=$N, deviation=$percentDeviation%, target deviation: ${threshold*100}% ")
      val samples = Vector.fill(samplesPerRound) {
        var i = 0L // Long, so it cannot overflow before reaching N
        val startTime = System.nanoTime
        // note - we sum the `Long` values returned from each `action`, to ensure
        // `action` cannot be optimized away
        while (i < N) { K += action; i += 1 }
        val stopTime = System.nanoTime
        print(" ")
        System.gc() // try to minimize variance due to GC timing
        // floating-point division: integer division would truncate to whole
        // nanoseconds and report 0 for very cheap actions
        (stopTime - startTime).toDouble / N
      }
      val mean = samples.sum / samplesPerRound
      val variance = samples.map(x => math.pow(x - mean, 2)).sum / samplesPerRound
      val stddev = math.sqrt(variance)
      val v = stddev / mean
      percentDeviation = (v * 1000).toInt.toDouble / 10
      // a NaN `v` (mean of 0) compares false here, so we keep sampling with larger N
      if (v <= threshold) converged = true
      sample = mean
    }
    println("\r - "+label + ": " + formatNanos(sample) + s" ($percentDeviation% deviation, N=$N, K = ${K.toString.take(3)}) ")
    (label, sample)
  }

  /** Truncates `n` toward zero to three decimal places (thousandths). */
  def roundToThousands(n: Double): Double = (n * 1000).toLong / 1000.0

  /** Truncates `n` toward zero to two decimal places (hundredths). */
  def roundToHundreds(n: Double): Double = (n * 100).toLong / 100.0

  // def formatNanos(nanos: Double) = nanos.toString

  /** Renders a duration in nanoseconds using the largest unit it exceeds. */
  def formatNanos(nanos: Double): String = {
    if (nanos > 1e9) roundToThousands(nanos/1e9).toString + " seconds"
    else if (nanos > 1e6) roundToThousands(nanos/1e6).toString + " milliseconds"
    else if (nanos > 1e3) roundToThousands(nanos/1e3).toString + " microseconds"
    else nanos.toString + " nanoseconds"
  }

  /**
   * Prints the given benchmark results, fastest first, each prefixed with its
   * time relative to the fastest one. Prints nothing when given no results.
   *
   * @param s `(label, nanoseconds)` pairs, typically produced by `profile`
   */
  def suite(s: (String,Double)*): Unit = {
    val tests = s.toList.sortBy(_._2)
    tests.headOption.foreach { case (_, minNanos) =>
      // min * x = y
      // x = y / min
      tests.foreach { case (label, nanos) =>
        val x = roundToHundreds(nanos / minNanos)
        println(x.toString.padTo(16, ' ') + label)
      }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment