Last active
May 8, 2016 23:37
-
-
Save steinybot/23411c5adfe86864697c16d73d0dff9c to your computer and use it in GitHub Desktop.
Class for providing variable substitution in a command at runtime.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.language.experimental.macros | |
import scala.reflect.macros.blackbox | |
/**
 * Macro bundle for case class macros.
 *
 * All trees generated here refer to standard library names through fully qualified `_root_` paths so that the
 * expanded code keeps compiling even when the call site shadows names such as `Seq` or `String` (macro hygiene).
 */
class CaseClassMacrosImpl(val c: blackbox.Context) {

  import c.universe._

  /**
   * Creates a new conversion class that provides evidence for converting a case class to a sequence of `String`.
   *
   * The case class is referenced by its full `Type` (not only its symbol), so generic case classes such as
   * `Box[Int]` are supported as well.
   *
   * @tparam T the type of the case class
   * @return the conversion type
   */
  def fieldConverter[T: c.WeakTypeTag]: c.Expr[CaseClassToSeq[T]] = {
    val classType = caseClassOf[T]
    val traitType = c.weakTypeOf[CaseClassToSeq[T]]
    val name = c.freshName(traitType.typeSymbol.name.toTypeName)
    // `caseClass` refers to the parameter of the generated `toSeq` method below.
    val fields = fieldsOf[T](c.Expr[T](q"caseClass"))
    val result =
      q"""final class $name extends $traitType {
            override def toSeq(caseClass: $classType): _root_.scala.Seq[_root_.scala.Predef.String] = $fields
          }
          new $name
       """
    c.Expr[CaseClassToSeq[T]](result)
  }

  /**
   * Retrieves the fields of a case class as a sequence in the order that they were defined.
   *
   * The case class expression is evaluated exactly once; each field is read through its getter and converted with
   * `toString`.
   *
   * @param caseClass the expression producing the case class instance
   * @tparam T the type of the case class
   * @return the fields of the case class
   */
  def fieldsOf[T: c.WeakTypeTag](caseClass: c.Expr[T]): c.Expr[Seq[String]] = {
    val getters = caseFieldGetters(caseClassOf[T])
    val result = withCachedExpr(caseClass) { cached =>
      val fieldsAsStrings = getters.map(getter => q"$cached.${getter.name}.toString")
      q"_root_.scala.Seq(..$fieldsAsStrings)"
    }
    c.Expr[Seq[String]](result)
  }

  /**
   * Determines the `Type` of the given case class or aborts if it is not a case class.
   *
   * @tparam T the case class type tag
   * @return the weak type of the case class
   */
  def caseClassOf[T: c.WeakTypeTag]: c.Type = {
    val classType = c.weakTypeOf[T]
    val symbol = classType.typeSymbol
    if (!symbol.isClass || !symbol.asClass.isCaseClass)
      c.abort(c.enclosingPosition, s"$symbol is not a case class")
    classType
  }

  /**
   * Finds the getters of each field that is also a parameter in the first parameter list of the primary constructor of
   * a case class.
   *
   * @param classType the type of the case class
   * @return the terms in the same order that they were declared
   */
  def caseFieldGetters(classType: c.Type): List[TermSymbol] = {
    classType.decls.sorted
      .filter(_.isTerm)
      .map(_.asTerm)
      // Case accessors include both the underlying field and its getter; keep only the getter.
      .filter(term => term.isCaseAccessor && term.isGetter)
  }

  /**
   * Assigns the result of evaluating an expression to a value and then provides the term of that value to a
   * function that can then reuse that cached value in constructing another tree.
   *
   * This is part of good macro hygiene: the expression is evaluated exactly once, no matter how many times the
   * function refers to it.
   *
   * @param expr the expression to evaluate
   * @param other the function that uses the value
   * @tparam T the type of the expression
   * @return a block containing the value and result of the function
   */
  def withCachedExpr[T](expr: c.Expr[T])(other: (TermName) => c.Tree): c.Tree = {
    val cached = TermName(c.freshName("cached"))
    q"""
      val $cached = ${expr.tree}
      ${other(cached)}
    """
  }
}
/**
 * Entry points for macros that operate on case classes.
 */
object CaseClassMacros {

  /**
   * Lists the fields of a case class instance as strings.
   *
   * Only the parameters of the first parameter list of the primary constructor are included. Each one is read
   * through its getter and rendered with its `toString` method.
   *
   * @param caseClass the case class instance
   * @tparam T the type of the case class
   * @return the rendered fields, in the order in which they were declared
   */
  def fieldsOf[T](caseClass: T): Seq[String] =
    macro CaseClassMacrosImpl.fieldsOf[T]
}
/**
 * Companion of the `CaseClassToSeq` type class; derives instances for case classes at compile time.
 *
 * Instances are produced through
 * <a href="http://docs.scala-lang.org/overviews/macros/implicits#fundep-materialization">Fundep materialization</a>,
 * so they are found automatically wherever a `CaseClassToSeq[T]` is required implicitly, for example when combined
 * with the Type Class Pattern and/or
 * <a href="http://docs.scala-lang.org/tutorials/FAQ/finding-implicits.html#context-bounds">Context Bounds</a>.
 */
object CaseClassToSeq {

  /**
   * Materializes the conversion from the case class `T` to a sequence of `String`.
   *
   * Intended to supply the evidence required by Context Bounds. Expansion is aborted with a compile error when `T`
   * is not a case class.
   *
   * @tparam T the type of the case class
   * @return the conversion type
   */
  implicit def materializeCaseClassToSeq[T]: CaseClassToSeq[T] =
    macro CaseClassMacrosImpl.fieldConverter[T]
}
/**
 * Type class for turning a case class into a sequence of `String`.
 *
 * Instances materialized by the companion object produce one element per constructor field, in declaration order.
 *
 * @tparam T the type of the case class
 */
trait CaseClassToSeq[T] {

  /**
   * Renders the given case class instance as a sequence of `String`.
   *
   * @param caseClass the case class instance
   * @return the case class as a sequence
   */
  def toSeq(caseClass: T): Seq[String]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Settings for copying a file from the local machine to a remote machine.
copy {
  # The program followed by its arguments, used to copy a single file.
  # When using the Vagrant VM, the port, key, user and host can be read from the output of `vagrant ssh-config`.
  command = [
    """C:\Program Files (x86)\PuTTY\pscp.exe""",
    "-P", "2222",
    "-i", ".vagrant/machines/default/virtualbox/private_key.ppk",
    "-batch",
    "-hostkey", ${agent.test.remote.hostkey},
    "-l", "vagrant",
    "$source",
    "127.0.0.1:/home/vagrant/$target"
  ]
  # The names of the variables that appear in `command`; each occurrence is replaced verbatim at runtime.
  variables {
    # Placeholder for the local path of the file to be copied.
    source = "$source"
    # Placeholder for the target path on the remote, relative to /home/vagrant (see `command`).
    target = "$target"
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * A command that copies a file from the local machine to a remote machine.
 *
 * @param command   the program and its arguments, containing the variable names to be substituted
 * @param variables the names of the variables that appear in `command`
 */
case class CopyCommand(override val command: Seq[String], variables: CopyVariables)
  extends SafeVariableCommand(command, variables)
/**
 * The variables used when copying a file.
 *
 * The same type describes both the variable names (placeholders inside the command) and the values substituted
 * for them.
 *
 * @param source the local path of the file, or the placeholder for it
 * @param target the remote path of the file, or the placeholder for it
 */
case class CopyVariables(source: String, target: String)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.sys.process._

val configs = ConfigFactory.load("command")
val copyCmd = configs.get[CopyCommand]("copy").value
// The configured command already prefixes the target with "/home/vagrant/", so the target must be relative to that
// directory; passing "~/test.txt" would produce the literal remote path "/home/vagrant/~/test.txt".
val variables = CopyVariables("""C:\Users\me\test.txt""", "test.txt")
val cmd = copyCmd.substitute(variables)
// Run the process with explicit method syntax (postfix `cmd !` requires scala.language.postfixOps) and check the
// exit code instead of silently discarding it.
val exitCode = cmd.!
require(exitCode == 0, s"Copy command failed with exit code $exitCode")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.annotation.tailrec | |
/**
 * This represents a command which contains named variables to be substituted at runtime.
 *
 * The command is a sequence of strings where each string is a command or argument. Each command or argument may
 * contain zero, one or more variables.
 *
 * Variable substitution works by replacing a variable name with another value. There is no special escaping since the
 * caller is free to choose whatever name of the variable that is guaranteed to be unambiguous. When one name is a
 * prefix of another, the shorter name wins at a given position. Variable names must not be empty.
 *
 * Determining where variables are and what they need to be replaced with is done when the command is created.
 *
 * This is a generic implementation so the number of variables permitted is not fixed. However this means that the
 * caller is responsible for providing the correct number of values for substitution.
 *
 * Initialisation is worst case the number of commands multiplied by the number of variable names. However this is
 * intended to be created once and used multiple times so that cost is amortised. Substitution converts the values
 * once and is then linear time, proportional to the number of commands and arguments plus the number of variables.
 *
 * This is similar to Scala's string interpolation and [[StringContext]] however it works for variable strings as
 * opposed to string literals.
 *
 * @tparam T type of variables that are accepted
 */
trait VariableCommand[T] {

  /**
   * The command that contains variables.
   */
  val command: Seq[String]

  /**
   * The names of each variable to be replaced.
   */
  val names: T

  /**
   * Converts either the variable names or the variable values into a sequence of `String`, in variable order.
   *
   * This is called while this trait is being initialised (to read `names`), so implementations must not rely on
   * state that is only initialised in the body of a subclass.
   *
   * @param variables the names or values to convert
   * @return the variables as a sequence, in the same order as the names
   */
  implicit def toSeq(variables: T): Seq[String]

  /** Builds one argument of the command from the variable values (indexed like the names). */
  private type VariableSubstitution = (Seq[String]) => String

  private val namesList: List[String] = toSeq(names).toList

  // An empty name matches without consuming any characters, which would make the scan in createSubstitution never
  // advance (an infinite loop), so reject it up front.
  require(namesList.forall(_.nonEmpty), "Variable names must not be empty")

  // Must be declared after namesList: trait vals are initialised in declaration order.
  private val substitutions: Seq[VariableSubstitution] = createSubstitutions

  /**
   * Substitutes the variable names in the command with the values.
   *
   * The order of the values must match the order of the variable names.
   *
   * @param values the values to use as substitutions
   * @return the command with the substituted values
   * @throws IllegalArgumentException if the number of values does not match the number of variable names
   */
  def substitute(values: T): Seq[String] = {
    // Convert (and, where the implementation validates, check) the values exactly once rather than once per
    // argument, and use an indexed sequence so each lookup by variable index is constant time.
    val valueSeq: Seq[String] = toSeq(values).toIndexedSeq
    substitutions.map(build => build(valueSeq))
  }

  private def createSubstitutions: Seq[VariableSubstitution] = command.map(createSubstitution)

  private def createSubstitution(arg: String): VariableSubstitution = {
    // Recursively go through each position in the argument and check to see if there is a variable at that position.
    // If there is then accumulate the part between where the last match ended and where the current match starts,
    // followed by the index of the variable (to be used for substitution with the values later). Then continue
    // searching from where the variable ends.
    // Once we reach the end then accumulate any remaining characters and reverse (since we accumulated by prepending to
    // the list) and then convert it to a substitution rule.
    @tailrec
    def loop(currentPos: Int, lastMatchEnd: Int, accum: List[Either[String, Int]]): VariableSubstitution = {
      if (currentPos < arg.length) {
        checkVariable(currentPos, arg) match {
          case Some((nextPos, varIndex)) =>
            val nextAccum =
              if (currentPos > lastMatchEnd) Right(varIndex) :: Left(arg.substring(lastMatchEnd, currentPos)) :: accum
              else Right(varIndex) :: accum
            loop(nextPos, nextPos, nextAccum)
          case None => loop(currentPos + 1, lastMatchEnd, accum)
        }
      } else {
        val finalAccum = if (lastMatchEnd < arg.length) Left(arg.substring(lastMatchEnd)) :: accum else accum
        mergeChoices(finalAccum.reverse)
      }
    }

    loop(0, 0, Nil)
  }

  private def checkVariable(startPos: Int, arg: String): Option[(Int, Int)] = {
    // Recursively check each position from the start until we find:
    // - a complete match, or
    // - there are no more names left that could be a match, or
    // - we have gone past the end of the argument
    // If a match is found then return both the position after the match (which becomes the next position to search
    // from) and also the index of the variable (this is used for substitution later on).
    @tailrec
    def loop(currentPos: Int, depth: Int, possibleNames: List[(String, Int)]): Option[(Int, Int)] = {
      // If the depth is the length of the term then we have already matched every character.
      possibleNames.find(_._1.length == depth) match {
        case Some((_, index)) => Some((currentPos, index))
        case None =>
          if (possibleNames.isEmpty || currentPos >= arg.length) None
          else {
            val c = arg.charAt(currentPos)
            // Safe: no remaining name has length == depth, so every one has a character at `depth`.
            val matching = possibleNames.filter(_._1.charAt(depth) == c)
            loop(currentPos + 1, depth + 1, matching)
          }
      }
    }

    loop(startPos, 0, namesList.zipWithIndex)
  }

  private def mergeChoices(choices: List[Either[String, Int]]): VariableSubstitution = {
    // Now we have all the "choices" where a choice is either the part of the argument to be copied verbatim (the
    // left) or the index of the variable to be substituted (the right).
    // Create a function which, given the variable values (that have the same indices as the names), will apply each
    // choice to build up the resulting argument.
    @tailrec
    def loop(remaining: List[Either[String, Int]], builder: StringBuilder)(values: Seq[String]): String = {
      remaining match {
        case Left(verbatim) :: tail =>
          builder.append(verbatim)
          loop(tail, builder)(values)
        case Right(index) :: tail =>
          builder.append(values(index))
          loop(tail, builder)(values)
        case Nil => builder.toString
      }
    }

    choices match {
      // Slight optimisation (both now and for substitutions) for arguments that contain no variables at all.
      case Left(verbatim) :: Nil => _ => verbatim
      // Each invocation must use its own builder.
      case _ => values => loop(choices, new StringBuilder)(values)
    }
  }
}
/**
 * A [[VariableCommand]] whose names and values are plain sequences of `String`.
 *
 * Nothing is checked at compile time; instead, every conversion verifies at runtime that the number of values
 * equals the number of variable names.
 *
 * @param command the command and its arguments, containing the variable names
 * @param names   the variable names, in the order the values will be supplied
 */
class UnsafeVariableCommand(val command: Seq[String], val names: String*) extends VariableCommand[Seq[String]] {
  override implicit def toSeq(variables: Seq[String]): Seq[String] = {
    val given = variables.length
    val expected = names.length
    require(given == expected,
      s"The number of variables ($given) must be equal to the number of variable names ($expected)")
    variables
  }
}
/**
 * A [[VariableCommand]] whose names and values are instances of a case class `T`, so the number and order of the
 * variables is fixed by the type and checked at compile time.
 *
 * The conversion from `T` to a sequence of `String` comes from the implicit [[CaseClassToSeq]] evidence.
 *
 * @param command the command and its arguments, containing the variable names
 * @param names   the case class instance holding the variable names
 * @tparam T the type of variables
 */
class SafeVariableCommand[T: CaseClassToSeq](val command: Seq[String], val names: T) extends VariableCommand[T] {
  // The VariableCommand initialiser calls toSeq before this class body has run, so everything toSeq relies on is
  // initialised lazily instead of eagerly.

  /** The evidence supplied through the context bound. */
  private lazy val converter: CaseClassToSeq[T] = implicitly[CaseClassToSeq[T]]

  /** The variable names rendered as a sequence of `String`. */
  protected lazy val nameSeq: Seq[String] = converter.toSeq(names)

  override implicit def toSeq(variables: T): Seq[String] =
    converter.toSeq(variables)
}
Modified to include a type-safe version and an example of how it can be used.
The safe version is now truly type safe: there is no need to implement a `toSeq` method, because a macro generates it. Macros to the rescue!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I haven't tested this extensively, so if you find any bugs or have improvements, please feel free to comment.