Last active
May 28, 2018 15:23
-
-
Save anthonynsimon/1f840fead7add16bd3765f33de933c98 to your computer and use it in GitHub Desktop.
Simple Interpreter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Expected output of the example program below:
// original: Power(Multiply(Variable(x),Sum(Variable(y),Multiply(Constant(5.0),Constant(7.0)))),Power(Sum(Constant(1.0),Constant(2.0)),Power(Sum(Constant(0.0),Constant(1.0)),Constant(2.0))))
// optimized: Power(Multiply(Variable(x),Sum(Variable(y),Constant(35.0))),Constant(3.0))
// show: ((x * (y + 35.0)) ^ 3.0)
// evaluates to: 216000.0 for variables: Map(x -> 1.5, y -> 5.0)
/** Abstract syntax tree of the expression language.
  *
  * Extending `Product with Serializable` stops Scala 2 from inferring the awkward
  * `Product with Serializable with Expr` for mixed collections such as
  * `List(Constant(1), Variable("x"))`.
  */
sealed trait Expr extends Product with Serializable

/** A numeric literal. */
final case class Constant(a: Double) extends Expr

/** A named variable whose value is looked up in the bindings supplied at evaluation time. */
final case class Variable(a: String) extends Expr

/** Common parent of the nodes that combine exactly two sub-expressions. */
sealed trait BinaryOp extends Expr

/** Addition: `a + b`. */
final case class Sum(a: Expr, b: Expr) extends BinaryOp

/** Multiplication: `a * b`. */
final case class Multiply(a: Expr, b: Expr) extends BinaryOp

/** Exponentiation: `base ^ exponent`. */
final case class Power(base: Expr, exponent: Expr) extends BinaryOp

/** Raised when an expression references a variable that has no binding. */
final case class MissingVariableError(message: String) extends Exception(message)
/** Constant-folds `program` bottom-up.
  *
  * Both operands of a binary node are simplified first. If both then turn out to be
  * [[Constant]]s, the node is replaced by a single [[Constant]] computed with the same
  * arithmetic that `evaluate` would use. Constants and variables are returned unchanged.
  *
  * @param program  expression to simplify
  * @param maxDepth how many levels below the root may be simplified; subtrees below that
  *                 limit are returned as-is, and a value `<= 0` returns `program` untouched
  * @return the simplified expression
  */
def optimize(program: Expr, maxDepth: Int): Expr =
  if (maxDepth <= 0) program
  else {
    val depth = maxDepth - 1
    // Each operand is simplified exactly once and the node is folded afterwards. The
    // earlier version re-ran `optimize` on the rebuilt node, which re-visited subtrees
    // that were already simplified and made the work exponential in `maxDepth`.
    program match {
      case Sum(a, b) =>
        (optimize(a, depth), optimize(b, depth)) match {
          case (Constant(x), Constant(y)) => Constant(x + y)
          case (x, y)                     => Sum(x, y)
        }
      case Multiply(a, b) =>
        (optimize(a, depth), optimize(b, depth)) match {
          case (Constant(x), Constant(y)) => Constant(x * y)
          case (x, y)                     => Multiply(x, y)
        }
      case Power(base, exponent) =>
        (optimize(base, depth), optimize(exponent, depth)) match {
          case (Constant(x), Constant(y)) => Constant(math.pow(x, y))
          case (x, y)                     => Power(x, y)
        }
      case leaf => leaf
    }
  }
/** Returns the value bound to `name` in the implicit variable bindings.
  *
  * @param name      variable to look up
  * @param variables bindings from variable name to value
  * @throws MissingVariableError if `name` has no binding
  */
def mustResolveVariable(name: String)(implicit variables: Map[String, Double]): Double =
  variables.get(name) match {
    case Some(bound) => bound
    case None        => throw MissingVariableError(s"Missing variable: $name")
  }
/** Computes the numeric value of `program`.
  *
  * Variables are resolved through `mustResolveVariable`. For binary nodes the left
  * operand is evaluated before the right one.
  *
  * @param program   expression to evaluate
  * @param variables bindings for every variable that `program` references
  * @throws MissingVariableError if a referenced variable has no binding
  */
def evaluate(program: Expr)(implicit variables: Map[String, Double]): Double = {
  // Evaluates the left operand, then the right one, and combines the two results.
  def combine(left: Expr, right: Expr)(op: (Double, Double) => Double): Double = {
    val lhs = evaluate(left)
    val rhs = evaluate(right)
    op(lhs, rhs)
  }

  program match {
    case Constant(value)       => value
    case Variable(name)        => mustResolveVariable(name)
    case Sum(l, r)             => combine(l, r)(_ + _)
    case Multiply(l, r)        => combine(l, r)(_ * _)
    case Power(base, exponent) => combine(base, exponent)(math.pow)
  }
}
/** Renders `program` as infix text, wrapping every binary node in parentheses.
  *
  * For example, `Power(Sum(Variable("x"), Constant(1)), Constant(2))` renders as
  * `((x + 1.0) ^ 2.0)`.
  */
def show(program: Expr): String = {
  // Formats one binary node as "(<left> <operator> <right>)".
  def binary(left: Expr, operator: String, right: Expr): String =
    s"(${show(left)} $operator ${show(right)})"

  program match {
    case Constant(value)       => s"${value}"
    case Variable(name)        => s"${name}"
    case Sum(l, r)             => binary(l, "+", r)
    case Multiply(l, r)        => binary(l, "*", r)
    case Power(base, exponent) => binary(base, "^", exponent)
  }
}
// Example expression: (x * (y + 5 * 7)) ^ ((1 + 2) ^ ((0 + 1) ^ 2))
val program = Power(
  Multiply(
    Variable("x"),
    Sum(Variable("y"), Multiply(Constant(5), Constant(7)))
  ),
  Power(
    Sum(Constant(1), Constant(2)),
    Power(Sum(Constant(0), Constant(1)), Constant(2))
  )
)

// Bindings used to evaluate both example programs.
val variables = Map("x" -> 1.5, "y" -> 5.0)

// Constant-fold up to 5 levels deep, then evaluate and pretty-print the folded tree.
val optimized = optimize(program, 5)
val result    = evaluate(optimized)(variables)
val pretty    = show(optimized)

println(s"original: $program")
println(s"optimized: $optimized")
println(s"show: $pretty")
println(s"evaluates to: $result for variables: $variables")
/** Operator syntax for building [[Expr]] trees, e.g. `"x" |* 7 |^ "y" |+ 5`.
  *
  * Caveat (hence "terrible"): Scala assigns operator precedence by an operator's first
  * character, so `|+`, `|*` and `|^` all bind equally and associate left-to-right. The
  * example above therefore builds `(((x * 7) ^ y) + 5)`, not the conventional
  * `x * 7^y + 5`.
  */
object TerribleDslLib {
  // Opt in to implicit conversions explicitly; without this Scala 2 emits a feature
  // warning for every `implicit def` below.
  import scala.language.implicitConversions

  /** Wraps an expression so the `|`-operators can be called on it. */
  class TerribleDsl(a: Expr) {
    def |+(b: Expr): Expr = Sum(a, b)
    def |*(b: Expr): Expr = Multiply(a, b)
    def |^(b: Expr): Expr = Power(a, b)
  }

  /** Lets a bare string stand for the [[Variable]] of that name. */
  implicit def stringToVariable(name: String): Expr = Variable(name)

  /** Lets a number stand for a [[Constant]]; `Int` literals are widened to `Double`. */
  implicit def doubleToConstant(constant: Double): Expr = Constant(constant)

  /** Enables the operators on a string, treating the string as a [[Variable]].
    *
    * The former signature carried an unused type parameter and an unused `A => Expr`
    * evidence parameter. They only added an implicit search with `A` undetermined, so
    * they were removed. Explicit calls such as `toTerribleDsl("x")` still compile.
    */
  implicit def toTerribleDsl(a: String): TerribleDsl = new TerribleDsl(Variable(a))

  /** Enables the operators on any expression. */
  implicit def toTerribleDsl(a: Expr): TerribleDsl = new TerribleDsl(a)
}
import TerribleDslLib._

// The `|`-operators all share one precedence level and associate left-to-right, so
// `"x" |* 7 |^ "y" |+ 5` parses exactly as the parenthesised form below.
val programDsl = (("x" |* 7) |^ "y") |+ 5

val optimizedDsl = optimize(programDsl, 5)
val resultDsl    = evaluate(optimizedDsl)(variables)
val prettyDsl    = show(optimizedDsl)

println(s"original: $programDsl")
println(s"optimized: $optimizedDsl")
println(s"show: $prettyDsl")
println(s"evaluates to: $resultDsl for variables: $variables")
// Expected output of the DSL example above:
// original: Sum(Power(Multiply(Variable(x),Constant(7.0)),Variable(y)),Constant(5.0))
// optimized: Sum(Power(Multiply(Variable(x),Constant(7.0)),Variable(y)),Constant(5.0))
// show: (((x * 7.0) ^ y) + 5.0)
// evaluates to: 127633.15625 for variables: Map(x -> 1.5, y -> 5.0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment