Skip to content

Instantly share code, notes, and snippets.

@akshaal
Created August 18, 2012 17:59
Show Gist options
  • Save akshaal/3388753 to your computer and use it in GitHub Desktop.
Save akshaal/3388753 to your computer and use it in GitHub Desktop.
Scala 2.10: annotated fields macro
/** Akshaal, 2012. http://akshaal.info */
import language.experimental.macros
import scala.reflect.macros.Context
import scala.annotation.Annotation
/**
* Macros for traversing over annotated elements.
*
* See http://www.akshaal.info/2012/08/scala-210-annotated-fields-macro.html for more information.
*/
object annotated {
  /**
   * An object of this class represents an annotated field.
   * @tparam I type of class the field belongs to
   * @tparam A type of annotation arguments (TupleX or None)
   * @param name name of the field
   * @param get function that returns field value of an instance given as argument to the function
   * @param args arguments of the annotation found on the field, packed into a TupleN
   *             (or `None` when the annotation carries no arguments)
   */
  case class Field[I <: AnyRef, A <: Product](name : String, get : I => Any, args : A)

  /**
   * List of fields belonging to the given type.
   * @tparam I Owner of fields
   * @tparam A type of annotation arguments (TupleX or None)
   */
  type Fields[I <: AnyRef, A <: Product] = List[Field[I, A]]

  /**
   * Macro which inspects class 'I' and returns a list of fields annotated with annotation 'Ann'.
   * @tparam Ann search for field with this annotation
   * @tparam Args type of arguments in the annotation (TupleX or None)
   * @tparam I type of class to scan for annotated fields
   * @return one [[Field]] per member of `I` that carries an `Ann` annotation
   */
  def fields[Ann <: Annotation, Args <: Product, I <: AnyRef] : Fields[I, Args] = macro fieldsImpl[Ann, Args, I]

  /**
   * Implementation of the [[fields]] macro.
   *
   * Builds, at compile time, an expression of the shape
   * `Field(name, (x : ITT) => x.name, args) :: ... :: Nil` with one element per member of
   * `ITT` annotated with `AnnTT`.
   *
   * @param c macro context
   * @return expression that evaluates to the list of annotated fields at runtime
   */
  def fieldsImpl[AnnTT <: Annotation : c.AbsTypeTag,
                 Args <: Product : c.AbsTypeTag,
                 ITT <: AnyRef : c.AbsTypeTag](c : Context) : c.Expr[Fields[ITT, Args]] = {
    import c.universe._

    // Materialize types
    val instanceT = implicitly[c.AbsTypeTag[ITT]].tpe
    val annT = implicitly[c.AbsTypeTag[AnnTT]].tpe

    // Fold the given expressions into a new expression that builds a List at runtime.
    // foldLeft prepends every element, so the runtime list is in the REVERSE order of `exprs`.
    def foldIntoListExpr[T : c.AbsTypeTag](exprs : Iterable[c.Expr[T]]) : c.Expr[List[T]] =
      exprs.foldLeft(reify { Nil : List[T] }) {
        (accumExpr, expr) =>
          reify { expr.splice :: accumExpr.splice }
      }

    // Expression for the annotation arguments: `scala.None` when there are none,
    // otherwise `scala.TupleN.apply(arg1, ..., argN)`.
    def argsExpr(argTrees : List[Tree]) : c.Expr[Args] =
      if (argTrees.isEmpty) {
        c.Expr[Args](Select(Ident(newTermName("scala")), newTermName("None")))
      } else {
        val tupleConstTree = Select(Select(Ident(newTermName("scala")),
                                           newTermName(s"Tuple${argTrees.size}")),
                                    newTermName("apply"))
        c.Expr[Args](Apply(tupleConstTree, argTrees))
      }

    // Expression `(x : ITT) => (x.<name> : scala.Any)`.
    // The ascription type is the path `scala.Any` (a Select), not the projection `scala#Any`.
    def getterExpr(name : String) : c.Expr[ITT => Any] = {
      val paramTree = ValDef(Modifiers(), newTermName("x"), TypeTree(instanceT), EmptyTree)
      val bodyTree = Typed(Select(Ident(newTermName("x")), newTermName(name)),
                           Select(Ident(newTermName("scala")), newTypeName("Any")))
      c.Expr[ITT => Any](Function(List(paramTree), bodyTree))
    }

    // Pair each member of ITT with the first AnnTT annotation found on it, and turn every such
    // pair into an expression that instantiates Field at runtime. (The original author found
    // Symbol.hasAnnotation unreliable here, hence the manual lookup.) Annotation types are
    // compared with =:= (type equivalence) rather than ==, which can reject equivalent types
    // that are represented differently, e.g. reached through a different prefix.
    val fieldExprs =
      for {
        member <- instanceT.members
        ann <- member.getAnnotations find (_.atp =:= annT)
      } yield {
        // The matched symbol's name carries a trailing space (presumably the compiler's
        // local-field suffix that distinguishes the field from its getter;
        // NOTE(review): confirm), so trim it to obtain the accessor name.
        val name = member.name.toString.trim
        val nameExpr = c literal name
        val getExpr = getterExpr(name)
        val annArgsExpr = argsExpr(ann.args)
        reify {
          Field[ITT, Args](name = nameExpr.splice, get = getExpr.splice, args = annArgsExpr.splice)
        }
      }

    // Build `f1 :: f2 :: ... :: Nil`; note foldIntoListExpr reverses the order of fieldExprs.
    foldIntoListExpr(fieldExprs)
  }
}
// LocalWords: args TupleX Tuple argTrees
/** Akshaal, 2012. http://akshaal.info */
import org.specs2._
class AnnotatedSpec extends Specification with matcher.ScalaCheckMatchers {
  /** Acceptance-style fragments: one combined example and one ScalaCheck property. */
  def is =
    "This is a specification for macros in the application." ^
    "fields macro should work" ! (recordTest.example1 and personTest.example) ^
    "fields macro should work with arbitrary data" ! recordTest.example2

  /** Renders `Person` values as text using the fields tagged with `Pretty`. */
  object personTest {
    type FormatFun = Any => Any
    type PrettyArgs = (Option[String], FormatFun)

    /** Marks a field for pretty printing, with an optional alias and a value formatter. */
    class Pretty(aka : Option[String] = None, format : FormatFun = identity) extends annotation.StaticAnnotation

    /**
     * Renders `x` as one "Label (aka alias): value" line per annotated field, joined by newlines.
     * The label is derived from the camelCase field name, e.g. "firstName" becomes "First name".
     */
    def pp[X <: AnyRef](fields : annotated.Fields[X, PrettyArgs])(x : X) : String = {
      val lines = fields map { field =>
        val (alias, formatter) = field.args
        val label = field.name.replaceAll("([A-Z][a-z]+)", " $1").toLowerCase.capitalize
        val aliasSuffix = alias.map(" (aka " + _ + ")").getOrElse("")
        s"$label$aliasSuffix: ${formatter(field.get(x))}"
      }
      lines.mkString("\n")
    }

    case class Person(
      id : Int,
      @Pretty(aka = Some("nickname")) name : String,
      @Pretty firstName : String,
      @Pretty(None, format = _.toString.toUpperCase) secondName : String,
      @Pretty(None, format = { case x : Option[_] => x getOrElse "" }) twitter : Option[String])

    val personPrettyFields = annotated.fields[Pretty, PrettyArgs, Person]

    val ppPerson = (person : Person) => pp(personPrettyFields)(person)

    val person1 = Person(1, "akshaal", "Evgeny", "Chukreev", Some("https://twitter.com/Akshaal"))
    val person2 = Person(2, "BillGates", "Bill", "Gates", Some("https://twitter.com/BillGates"))
    val persons = List(person1, person2)

    def example = {
      val expected = List(
        List(
          "Name (aka nickname): akshaal",
          "First name: Evgeny",
          "Second name: CHUKREEV",
          "Twitter: https://twitter.com/Akshaal").mkString("\n"),
        List(
          "Name (aka nickname): BillGates",
          "First name: Bill",
          "Second name: GATES",
          "Twitter: https://twitter.com/BillGates").mkString("\n"))
      persons.map(ppPerson) must_== expected
    }
  }

  /** Checks the macro against a record whose annotation has a defaulted argument. */
  object recordTest {
    class Attr(title : String, priority : Int = 0) extends annotation.StaticAnnotation
    case class Record(val id : Int,
                      @Attr("Name", 1) val name : String,
                      @Attr("Weight") val weight : Long)

    val annotatedRecordFields = annotated.fields[Attr, (String, Int), Record]

    /** Formats one annotated field of `record` as "name: value (title, priority)". */
    def fieldInfo2Str(record : Record)(field : annotated.Field[Record, (String, Int)]) : String =
      field.args match {
        case (title, priority) => s"${field.name}: ${field get record} ($title, $priority)"
      }

    /** All annotated fields of `record`, rendered and collected into a set (order-insensitive). */
    private def renderedFields(record : Record) : Set[String] =
      annotatedRecordFields.map(fieldInfo2Str(record)).toSet

    def example1 =
      renderedFields(Record(18, "abc", 4)) must_== Set(
        "name: abc (Name, 1)",
        "weight: 4 (Weight, 0)"
      )

    def example2 = check { (l : Long, s : String) =>
      renderedFields(Record(18, s, l)) must_== Set(
        s"name: $s (Name, 1)",
        s"weight: $l (Weight, 0)"
      )
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment