Skip to content

Instantly share code, notes, and snippets.

@kevinwright
Created March 21, 2013 15:29
Show Gist options
  • Select an option

  • Save kevinwright/5213925 to your computer and use it in GitHub Desktop.

Select an option

Save kevinwright/5213925 to your computer and use it in GitHub Desktop.
Universal lifting macro - courtesy of Dmitry Grigoriev
/**
 * "Universal lifting" macro: turns a runtime value into a tree that rebuilds that value.
 *
 * `objectToExpr(c1, o)` is itself a macro. Its expansion (produced by `objectToExprMacro`) is code
 * that, when it runs, walks the value `o` and uses the universe of `c1` to build a `Tree` which,
 * when spliced, reconstructs an equivalent value of type `O`. Because it needs a `Context`, the
 * call is presumably meant to sit inside another macro implementation (see the NOTE on the final
 * generated expression about the identifier `c`).
 *
 * Which types can be lifted is decided statically from `O` (see `typesToHandle`):
 *  - `Boolean`, `String`, `Int` (emitted as literals);
 *  - `scala.xml.NodeSeq` (`xml.Text` and `xml.Elem` nodes are rebuilt node by node);
 *  - `Option`, `List`, `Map`, recursively including their type arguments;
 *  - non-abstract case classes with exactly one parameter list, recursively including field types;
 *  - objects (module classes);
 *  - types for which `MacroHelper.findAndEvaluatePolymorphicAnnotation` yields subtypes
 *    (the subtypes are handled instead of the annotated type itself).
 * Anything else aborts macro expansion with "Unsupported type ...".
 *
 * Instances are de-duplicated by reference (`eq`): each non-literal instance is emitted once as
 * `val xN = ...` and later occurrences refer to `xN`. Literals, `None`, objects and empty `Map`s
 * are emitted inline instead.
 */
object ObjectToExpr {
/**
 * Entry point; expanded at compile time by `objectToExprMacro`.
 *
 * @param c1 the macro `Context` whose universe is used to build the result tree
 * @param o  the value to lift; it is walked when the expanded code runs
 * @return   an expression whose tree rebuilds `o`
 */
def objectToExpr[O <: AnyRef](c1: Context, o: O): c1.Expr[O] = macro objectToExprMacro[O]
/**
 * Macro implementation of `objectToExpr`.
 *
 * Works in two phases:
 *  1. At expansion time, collect every type reachable from `O` that must be handled (`typesToHandle`).
 *  2. Emit a block that defines a local `recursive(o: Any): Tree` with one flat `match` covering
 *     those types, calls it on the argument, and finally wraps the accumulated tree in an `Expr`.
 *
 * The emitted tree is wrapped as `c.Expr[Nothing]`; the `Nothing` is only a placeholder type argument.
 *
 * @throws Exception      at expansion time if a reachable type is not supported
 * @throws AssertionError at expansion time if a case class has more than one parameter list
 */
def objectToExprMacro[O <: AnyRef : c.WeakTypeTag](c: Context)(c1: c.Expr[Context], o: c.Expr[O]) = c.Expr[Nothing] {
import c.universe._
// println("*** objectToExprMacro() started")
// Project helper (not shown in this file); used below for @Polymorphic lookup and for building
// dotted-name, throw and string-concatenation trees.
val mh = new MacroHelper[c.type](c)
// Types emitted directly as literals. Other primitives (Long, Double, ...) are not in this list.
def isConstant(t: Type) = t =:= typeOf[Boolean] || t =:= typeOf[String] || t =:= typeOf[Int]
// Concrete case classes (not abstract, not objects): rebuilt by calling their constructor.
def isNonAbstractCaseClass(t: Type) = {
val s = t.typeSymbol
s.isClass && !s.isModuleClass && !s.asClass.isAbstractClass && s.asClass.isCaseClass
}
// First we recursively collect all types we'll have to handle.
// Then we'll generate a recursive method with a flat match that handles them.
val typesToHandle = {
val buf = new mutable.HashSet[Type]()
// Guards against infinite recursion when handling <refinement> types:
// remembers polymorphic symbols whose subtypes have already been expanded.
val handledPolymorphic = new mutable.HashSet[Symbol]()
// Adds t (and, transitively, the types it contains) to buf; throws for unsupported types.
def recursive(t: Type) {
if (buf.contains(t) || handledPolymorphic.contains(t.typeSymbol)) { return }
if (isConstant(t)) { buf += t; return }
// Only exactly NodeSeq is recognized here (=:=); the generation phase below tests with <:<.
if (t =:= typeOf[xml.NodeSeq]) { buf += t; return }
if (
t <:< c.weakTypeOf[Option[_]] ||
t <:< c.weakTypeOf[List[_]] ||
t <:< c.weakTypeOf[Map[_, _]]
) {
buf += t
// Recursively collect collection item types (including map keys).
// NOTE(review): non-exhaustive match - a collection type that is not a TypeRef would throw
// MatchError here; confirm that cannot happen.
val typeParams = t match { case tr: TypeRef => tr.args }
for (t2 <- typeParams) { recursive(t2) }
return
}
if (isNonAbstractCaseClass(t)) {
buf += t
// Recursively collect the types of all the case class's fields.
val pp = t.declaration(nme.CONSTRUCTOR).asMethod.paramss
assert(pp.size == 1, "Type " + t.typeSymbol.fullName + " must have one parameter list")
pp.head.foreach(p => recursive(p.typeSignature))
return
}
// Don't add an abstract polymorphic class to typesToHandle, only its subclasses.
// Presumably Some(subtypes) when t's class carries the @Polymorphic annotation, None otherwise -
// confirm in MacroHelper.
val subtypes_? = mh.findAndEvaluatePolymorphicAnnotation(t.typeSymbol)
if (!subtypes_?.isEmpty) {
// Recursively collect the polymorphic class's subtypes.
handledPolymorphic += t.typeSymbol
subtypes_?.get.foreach(recursive)
return
}
if (t.typeSymbol.isModuleClass) {
buf += t
return
}
// If we pass List(A1, A2) to the macro, where A1 and A2 are case objects sharing a common
// @Polymorphic abstract class A, then we get List[A with Product with Serializable] and would fail
// with "Unsupported type blablabla.<refinement>".
// To handle this, walk up to the closest @Polymorphic ancestor among the base classes.
if (t.typeSymbol.isClass) {
val poly_? = t.typeSymbol.asClass.baseClasses.find(s0 => !mh.findAndEvaluatePolymorphicAnnotation(s0).isEmpty)
if (!poly_?.isEmpty) {
handledPolymorphic += poly_?.get
mh.findAndEvaluatePolymorphicAnnotation(poly_?.get).get.foreach(recursive)
return
}
}
// Any other type is rejected at macro-expansion time.
throw new Exception("Unsupported type " + t.typeSymbol.fullName)
}
recursive(weakTypeOf[O])
// Keep the erased form of each collected type (the generated `case x: T` patterns are runtime
// type tests). Sorting by full name keeps the order of the generated cases deterministic.
buf.map(_.erasure).toList.sortBy(_.typeSymbol.fullName)
}
// println("*** objectToExprMacro() typesToHandle = " + typesToHandle.map(_.typeSymbol.name.decoded).sorted)
// "Meta" helpers: each returns a tree for an expression that, when the generated code runs,
// itself constructs a tree. For example, _Apply(f, a) is the tree of the expression `Apply(f, a)`,
// and _newTermName("n") is the tree of `newTermName("n")`.
def _Apply(fun: Tree, args: Tree) = Apply(Select(Ident("Apply"), "apply"), List(fun, args))
def _Constant(value: Tree) = Apply(Select(Ident("Constant"), "apply"), List(value))
def _Ident(name: Tree) = Apply(Select(Ident("Ident"), "apply"), List(name))
def _List(xs: Tree*) = Apply(Select(Ident("List"), "apply"), xs.toList)
def _Literal(value: Tree) = Apply(Select(Ident("Literal"), "apply"), List(value))
def _Select(qualifier: Tree, name: Tree) = Apply(Select(Ident("Select"), "apply"), List(qualifier, name))
def _newTermName(name: String) = Apply(Ident(newTermName("newTermName")), List(Literal(Constant(name))))
// Kept as a separate val; otherwise IntelliJ IDEA mistakenly highlights the whole thing as erroneous.
// The statements of this reify block are extracted below and spliced at the top of the generated
// block, so that the generated recursive() can see c2, handledInstances, handleInstance and getResultTree.
val anyrefTrees: List[Tree] = reify {
// Evaluate the Context argument once; its universe supplies the tree constructors used below.
val c2 = c1.splice
import c2.universe._
// Support multiple references to the same instance.
// Does NOT bind an empty Map (`Map.apply()`) to a val, because type inference is lame: if we write
// `val xN = Map.apply()` and then pass this xN as an argument, we get a compiler error because
// Map[Nothing, Nothing] does not conform to the expected Map[..., ...].
// Each entry is a 4-tuple. We use a List of tuples, not a Map[AnyRef, Tuple3[...]],
// because we must compare instances by identity, not equality:
// 1st value - the instance;
// 2nd value - its index, used for sorting;
// 3rd value - the tree to return: Ident("xN");
// 4th value - the tree that gives "val xN = ....".
val handledInstances = new collection.mutable.ListBuffer[(AnyRef, Int, Tree, Tree)]
// Returns the tree referring to x. The first time x is seen (by reference), evaluates gen, binds
// the result to a fresh "val xN" and returns Ident(xN). Later calls return that Ident without
// evaluating gen again. The lookup is a linear scan, so total cost grows quadratically with the
// number of distinct instances.
def handleInstance[T <: AnyRef](x: T, gen: => Tree): Tree = {
val found_? = handledInstances.find(_._1 eq x) // search by identity, not equality
if (!found_?.isEmpty) {
val ret = found_?.get._3
ret
} else {
// First we call the generator, which performs recursive processing (and fills handledInstances),
// and only after that do we read handledInstances.size.
val genDone = gen
val newIndex = handledInstances.size
// If genDone is `Map.apply()` (an empty Map), don't generate "val xN"; return the tree inline.
genDone match {
case Apply(Select(Ident(_map), _apply), Nil)
if _map.toString == "Map" && _apply.toString == "apply" => { genDone }
case _ => {
val ret = Ident("x" + newIndex)
handledInstances += ((x, newIndex, ret, ValDef(Modifiers(), newTermName("x" + newIndex), TypeTree(), genDone)))
ret
}
}
}
}
// Let's put a bit more inside reify:
// Builds { val x0 = ...; ...; val xK = ...; xK }. The root value's val has the highest index
// because its generator finishes last.
// NOTE(review): the assert fails when the root value itself produced no val - e.g. a String root,
// an object, None or an empty Map, which are all emitted inline - confirm callers never pass these.
def getResultTree() = {
assert(handledInstances.size > 0, "handledInstances.size > 0")
Block(
handledInstances.toList.sortBy(_._2).map(_._4),
Ident("x" + (handledInstances.size - 1))
)
}
null // last expression to ignore
// reify produces a Block here: keep its statements and drop the trailing null.
// Any other tree shape would throw MatchError at expansion time.
}.tree match {
case Block(stats, expr) => stats // get block contents, need these visible in recursive().
}
// Introduced as a separate val to work around buggy IntelliJ IDEA highlighting.
// Generates `recursive(<o>)`, splicing the argument expression as is.
val callRecursive: Tree = Apply(Ident("recursive"), List(o.tree))
// Generates `handleInstance(x, <gen>)`; gen is by-name, so it runs only on the first sight of x.
def callHandle(gen: Tree) = Apply(Ident("handleInstance"), List(Ident("x"), gen))
// Generated code, roughly:
//   {
//     val c2 = <c1>; import c2.universe._; <handledInstances, handleInstance, getResultTree>
//     def recursive(o: Any): Tree = o match { <cases for each handled type>; case x => throw ... }
//     recursive(<o>)
//     c.Expr[O](getResultTree().asInstanceOf[c.universe.Tree])
//   }
Block(
anyrefTrees :::
List(
DefDef(
Modifiers(),
newTermName("recursive"),
List(),
List(List(ValDef(Modifiers(Flag.PARAM), newTermName("o"), TypeTree(typeOf[Any]), EmptyTree))),
Ident(newTypeName("Tree")),
Match(
Ident("o"),
typesToHandle.flatMap(t => {
// Pattern: x @ (_: T)
val bindX = Bind(newTermName("x"), Typed(Ident("_"), TypeTree(t)))
def caseDef(tr: Tree) = CaseDef(bindX, EmptyTree, tr)
if (isConstant(t)) {
// Literal(Constant(x)) - emitted inline, not registered in handledInstances.
caseDef(_Literal(_Constant(Ident("x")))) :: Nil
} else if (t <:< typeOf[xml.NodeSeq]) {
// Three cases, in this order: xml.Text, xml.Elem, then any other NodeSeq.
CaseDef(
Bind(newTermName("x"), Typed(Ident("_"), TypeTree(typeOf[xml.Text]))),
EmptyTree,
callHandle(_Apply(
_Select(_Ident(_newTermName("xml")), _newTermName("Text")),
_List(_Literal(_Constant(Select(Ident("x"), "data"))))
))
) ::
CaseDef(
Bind(newTermName("x"), Typed(Ident("_"), TypeTree(typeOf[xml.Elem]))),
EmptyTree,
// generate: Apply(
//   Select(Ident(xml), Elem),
//   List(null, label, xml.Null, xml.TopScope, minimizeEmpty) ::: x.child.toList.map(v => recursive(v))
// )
// Because ::: is right-associative (`a ::: b` is `b.:::(a)`), we call it on its 2nd operand.
// NOTE(review): the element's prefix, attributes and namespace scope are not preserved;
// they are regenerated as null, xml.Null and xml.TopScope.
callHandle(_Apply(
_Select(_Ident(_newTermName("xml")), _newTermName("Elem")),
Apply(
Select(
Apply(
Select(Select(Select(Ident("x"), "child"), "toList"), "map"),
List(Function(
List(ValDef(Modifiers(Flag.PARAM), newTermName("v"), TypeTree(), EmptyTree)),
Apply(Ident("recursive"), List(Ident("v")))
))
),
"$colon$colon$colon"
),
List(_List(
_Literal(_Constant(Literal(Constant(null)))),
_Literal(_Constant(Select(Ident("x"), "label"))),
_Select(_Ident(_newTermName("xml")), _newTermName("Null")),
_Select(_Ident(_newTermName("xml")), _newTermName("TopScope")),
_Literal(_Constant(Select(Ident("x"), "minimizeEmpty")))
))
)
))
) ::
CaseDef(
Bind(newTermName("x"), Typed(Ident("_"), TypeTree(typeOf[xml.NodeSeq]))),
EmptyTree,
// multiple siblings on top level, generate:
// Apply(
//   Select(Select(Ident("xml"), "NodeSeq"), "fromSeq"),
//   List(Apply(
//     Select(Ident("Seq"), "apply"),
//     x.toList.map(v => recursive(v))
//   ))
// )
// NOTE(review): a Node that is neither xml.Text nor xml.Elem (e.g. a comment node) may land
// in this case; if x.toList yields just x itself, recursive(x) re-enters this case without
// bound - confirm.
callHandle(_Apply(
_Select(_Select(_Ident(_newTermName("xml")), _newTermName("NodeSeq")), _newTermName("fromSeq")),
_List(_Apply(
_Select(_Ident(_newTermName("Seq")), _newTermName("apply")),
Apply(
Select(Select(Ident("x"), "toList"), "map"),
List(Function(
List(ValDef(Modifiers(Flag.PARAM), newTermName("v"), TypeTree(), EmptyTree)),
Apply(Ident("recursive"), List(Ident("v")))
))
)
))
))
) :: Nil
} else if (t <:< weakTypeOf[Option[_]]) {
// `case None => Ident(None)` (inline), then `case x: Some[_] => Some.apply(recursive(x.get))`.
CaseDef(
Ident("None"),
EmptyTree,
_Ident(_newTermName("None"))
) ::
CaseDef(
Bind(newTermName("x"), Typed(Ident("_"), TypeTree(weakTypeOf[Some[_]]))),
EmptyTree,
callHandle(_Apply(
_Select(_Ident(_newTermName("Some")), _newTermName("apply")),
_List(Apply(Ident("recursive"), List(Select(Ident("x"), "get"))))
))
) :: Nil
// NOTE(review): types were collected with <:< (List/Map subtypes included) but are matched here
// with =:=; a collected subtype (e.g. `::` or a concrete Map class) would reach the
// "Unsupported type" throw below - confirm.
} else if (t =:= weakTypeOf[List[_]]) {
// generate: Apply(Select(Ident("List"), "apply"), x.map(v => recursive(v)))
caseDef(
callHandle(_Apply(
_Select(_Ident(_newTermName("List")), _newTermName("apply")),
Apply(Select(Ident("x"), "map"), List(
Function(
List(ValDef(Modifiers(Flag.PARAM), newTermName("v"), TypeTree(), EmptyTree)),
Apply(Ident("recursive"), List(Ident("v")))
)
))
))
) :: Nil
} else if (t =:= weakTypeOf[Map[_, _]]) {
// generate: Apply(Select(Ident("Map"), "apply"),
//   x.toList.map(v => Apply(Select(Ident("Tuple2"), "apply"), List(recursive(v._1), recursive(v._2)))))
// For an empty map this yields `Map.apply()`, which handleInstance emits inline.
caseDef(
callHandle(_Apply(
_Select(_Ident(_newTermName("Map")), _newTermName("apply")),
Apply(Select(Select(Ident("x"), "toList"), "map"), List(
Function(
List(ValDef(Modifiers(Flag.PARAM), newTermName("v"), TypeTree(), EmptyTree)),
_Apply(
_Select(_Ident(_newTermName("Tuple2")), _newTermName("apply")),
_List(
Apply(Ident("recursive"), List(Select(Ident("v"), "_1"))),
Apply(Ident("recursive"), List(Select(Ident("v"), "_2")))
)
)
)
))
))
) :: Nil
} else if (isNonAbstractCaseClass(t)) {
// For non-abstract case classes,
// generate: { New(TypeTree(typeOf[T]), List(List(recursive(x.arg1), recursive(x.arg2), ...))) }
// Fields are read through their constructor-parameter names.
val s = t.typeSymbol.asClass
val pp = t.declaration(nme.CONSTRUCTOR).asMethod.paramss
assert(pp.size == 1, "Class must have one parameter list: " + s.fullName)
caseDef(
callHandle(Apply(Ident("New"), List(
Apply(Ident("TypeTree"), List(TypeApply(Ident("typeOf"), List(Ident(t.typeSymbol))))),
_List(_List(pp.head.map(p =>
Apply(Ident("recursive"), List(Select(Ident("x"), p.name)))
): _*))
)))
) :: Nil
} else if (t.typeSymbol.isModuleClass) {
// Objects: the pattern is built by mh.treeDottedName (presumably the dotted path `a.b.Obj`),
// and the result is a Select/Ident chain naming the object; emitted inline, not as a val.
val s = t.typeSymbol
CaseDef(
mh.treeDottedName(s.fullName),
EmptyTree, {
// Same as mh.treeDottedName, but built with the '_'-prefixed meta helpers.
def impl(ss: List[String]): Tree = {
if (ss.size == 1) _Ident(_newTermName(ss.head))
else _Select(impl(ss.tail), _newTermName(ss.head))
}
impl(s.fullName.split('.').toList.reverse)
}
) :: Nil
} else {
throw new Exception("Unsupported type: " + t.typeSymbol.fullName)
}
}) :::
// Catch-all: any value matching none of the cases above makes the generated code throw
// "Unsupported type: <runtime class name>".
CaseDef(
Bind(newTermName("x"), Ident("_")),
EmptyTree,
mh.treeThrow(mh.treeConcatStrings(
Literal(Constant("Unsupported type: ")),
Apply(Select(Apply(Select(Ident("x"), "getClass"), Nil), "getCanonicalName"), Nil)
))
) :: Nil
)
),
callRecursive
),
// Final expression: c.Expr[O](getResultTree().asInstanceOf[c.universe.Tree]).
// NOTE(review): this refers to an identifier named `c` at the expansion site, not to the `c1`
// argument; it only works if the caller's Context is in scope as `c` (and is presumably the same
// context passed as c1) - confirm.
Apply(
TypeApply(Select(Ident("c"), "Expr"), List(TypeTree(c.weakTypeOf[O]))),
// Have to generate asInstanceOf, otherwise it cannot cast c2.Block to c.Tree.
List(TypeApply(
Select(Apply(Ident("getResultTree"), List()), "asInstanceOf"),
List(Select(Select(Ident(newTermName("c")), newTermName("universe")), newTypeName("Tree")))
))
)
)
}
}
@dimgel
Copy link

dimgel commented Mar 21, 2013

Original discussion (with scaladoc & usage example) is here: https://groups.google.com/forum/#!topic/scala-user/BSHLDglW0OE

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment