Created
March 21, 2013 15:29
-
-
Save kevinwright/5213925 to your computer and use it in GitHub Desktop.
Universal lifting macro - courtesy of Dmitry Grigoriev
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** "Universal lifting" macro: turns an already-evaluated value back into a Scala AST. */
object ObjectToExpr {

  /** Converts the value `o` into an expression of context `c1` that rebuilds it.
    *
    * Meant to be called from inside another macro implementation. The expansion defines a local
    * function `recursive(o: Any): Tree` specialised to the static type `O`. Every type reachable
    * from `O` is collected at compile time and gets its own `case`. The expansion then lifts `o`
    * through that function.
    *
    * Supported values:
    *  - `Boolean`, `Byte`, `Short`, `Char`, `Int`, `Long`, `Float`, `Double`, `String`
    *    (emitted as literals);
    *  - `xml.Text`, `xml.Elem` and sibling sequences (`xml.NodeSeq`). For elements, the prefix,
    *    attributes and namespace scope are not preserved;
    *  - `Option`, `List` and `Map` (map keys are lifted too);
    *  - non-abstract case classes with exactly one parameter list;
    *  - modules (objects), which are referenced by their full name rather than rebuilt;
    *  - abstract types for which `MacroHelper.findAndEvaluatePolymorphicAnnotation` returns
    *    subtypes (presumably the `@Polymorphic`-annotated ones; see MacroHelper).
    *
    * Instances reachable several times (compared by reference, not equality) are emitted once as
    * `val xN = ...` and then referred to as `xN`.
    *
    * A type reachable from `O` that is not supported aborts compilation with an
    * "Unsupported type" error. A run-time value matching none of the generated cases goes through
    * `MacroHelper.treeThrow` with an "Unsupported type: <class>" message.
    *
    * Caveats:
    *  - The generated code refers to the caller's context by the name `c`. The calling macro
    *    implementation must therefore have its `Context` in scope as `c`, normally the same one
    *    passed as `c1`.
    *  - The expansion introduces the local names `c2`, `handledInstances`, `handleInstance`,
    *    `getResultTree` and `recursive`. These would shadow same-named values referenced by `o`.
    *
    * @param c1 the calling macro's context
    * @param o  the value to lift
    * @return an expression which, when evaluated, reconstructs `o`
    */
  def objectToExpr[O <: AnyRef](c1: Context, o: O): c1.Expr[O] = macro objectToExprMacro[O]

  /** Macro implementation of [[objectToExpr]].
    *
    * Works in two phases:
    *  1. Collect every type reachable from `O` (`typesToHandle`).
    *  2. Emit code that defines the run-time helpers (`anyrefTrees`) and a flat
    *     `recursive(o: Any): Tree` matcher over the collected types (`recursiveDef`). That code
    *     then evaluates `c.Expr[O](getResultTree(recursive(o)))`.
    */
  def objectToExprMacro[O <: AnyRef : c.WeakTypeTag](c: Context)(c1: c.Expr[Context], o: c.Expr[O]) = c.Expr[Nothing] {
    import c.universe._
    val mh = new MacroHelper[c.type](c)

    // Types whose values are emitted directly as `Literal(Constant(x))`.
    val constantTypes = List(
      typeOf[Boolean], typeOf[Byte], typeOf[Short], typeOf[Char], typeOf[Int],
      typeOf[Long], typeOf[Float], typeOf[Double], typeOf[String])
    def isConstant(t: Type) = constantTypes.exists(t =:= _)

    def isNonAbstractCaseClass(t: Type) = {
      val s = t.typeSymbol
      s.isClass && !s.isModuleClass && !s.asClass.isAbstractClass && s.asClass.isCaseClass
    }

    // Collections lifted element by element. Their type arguments (including map keys) are
    // collected as well.
    val liftedCollections = List(c.weakTypeOf[Option[_]], c.weakTypeOf[List[_]], c.weakTypeOf[Map[_, _]])
    def isLiftedCollection(t: Type) = liftedCollections.exists(t <:< _)
    // Type arguments of `t` seen as the lifted collection it extends. Going through `baseType`
    // also works for refined/existential types and for subclasses whose own type parameters
    // differ from the collection's. Falls back to Nil when no type reference is available.
    def collectionTypeArgs(t: Type): List[Type] =
      liftedCollections.find(t <:< _) match {
        case Some(base) =>
          t.baseType(base.typeSymbol) match {
            case TypeRef(_, _, args) => args
            case _ => Nil
          }
        case None => Nil
      }

    // Phase 1: recursively collect all types we'll have to handle, so that phase 2 can generate
    // a recursive method with one flat match over them.
    val typesToHandle: List[Type] = {
      val found = new mutable.HashSet[Type]()
      // Guards against infinite recursion when handling <refinement> types.
      val handledPolymorphic = new mutable.HashSet[Symbol]()

      // Finds the first base class of `sym` (possibly `sym` itself) for which the polymorphic
      // annotation yields subtypes. Returns that class together with the subtypes.
      def polymorphicAncestor(sym: Symbol) =
        if (!sym.isClass) None
        else sym.asClass.baseClasses.iterator
          .map(base => (base, mh.findAndEvaluatePolymorphicAnnotation(base)))
          .collectFirst { case (base, Some(subtypes)) => (base, subtypes) }

      def collect(t: Type): Unit = {
        val sym = t.typeSymbol
        if (found.contains(t) || handledPolymorphic.contains(sym)) ()
        else if (isConstant(t) || t =:= typeOf[xml.NodeSeq]) found += t
        else if (isLiftedCollection(t)) {
          found += t
          collectionTypeArgs(t).foreach(collect)
        } else if (isNonAbstractCaseClass(t)) {
          found += t
          // Recursively collect all of the case class's field types.
          val paramss = t.declaration(nme.CONSTRUCTOR).asMethod.paramss
          assert(paramss.size == 1, "Type " + sym.fullName + " must have one parameter list")
          paramss.head.foreach(p => collect(p.typeSignature))
        } else mh.findAndEvaluatePolymorphicAnnotation(sym) match {
          // An abstract polymorphic class is not handled itself, only its subtypes are.
          case Some(subtypes) =>
            handledPolymorphic += sym
            subtypes.foreach(collect)
          case None if sym.isModuleClass =>
            found += t
          case None =>
            // Suppose we pass List(A1, A2), where A1 and A2 are case objects with a common
            // @Polymorphic abstract class A. Then the element type is inferred as
            // `A with Product with Serializable` (a <refinement>). Handle it by walking up to the
            // closest @Polymorphic ancestor.
            polymorphicAncestor(sym) match {
              case Some((ancestor, subtypes)) =>
                handledPolymorphic += ancestor
                subtypes.foreach(collect)
              case None =>
                c.abort(c.enclosingPosition, "Unsupported type " + sym.fullName)
            }
        }
      }

      collect(weakTypeOf[O])
      found.map(_.erasure).toList.sortBy(_.typeSymbol.fullName)
    }

    // Builders for code that, when run by the calling macro, *constructs* trees. For example,
    // `_Apply(f, a)` emits the code `Apply.apply(f, a)`.
    def _Apply(fun: Tree, args: Tree) = Apply(Select(Ident("Apply"), "apply"), List(fun, args))
    def _Constant(value: Tree) = Apply(Select(Ident("Constant"), "apply"), List(value))
    def _Ident(name: Tree) = Apply(Select(Ident("Ident"), "apply"), List(name))
    def _List(xs: Tree*) = Apply(Select(Ident("List"), "apply"), xs.toList)
    def _Literal(value: Tree) = Apply(Select(Ident("Literal"), "apply"), List(value))
    def _Select(qualifier: Tree, name: Tree) = Apply(Select(Ident("Select"), "apply"), List(qualifier, name))
    def _newTermName(name: String) = Apply(Ident(newTermName("newTermName")), List(Literal(Constant(name))))
    // Code building a reference to member `name` of the `xml` package.
    def _xmlMember(name: String) = _Select(_Ident(_newTermName("xml")), _newTermName(name))
    // Code building a reference to `name.apply`.
    def _companionApply(name: String) = _Select(_Ident(_newTermName(name)), _newTermName("apply"))

    // Run-time helpers for the generated code. Extracted into a val, otherwise IDEA mistakenly
    // highlights the whole thing as erroneous.
    val anyrefTrees: List[Tree] = reify {
      val c2 = c1.splice
      import c2.universe._
      // Supports multiple references to the same instance: each lifted instance is bound once
      // to `val xN` and later references reuse `Ident(xN)`.
      // Tuple fields: (instance, index used for ordering, tree to return `Ident(xN)`,
      // definition `val xN = ...`). A list is used rather than Map[AnyRef, ...] because
      // instances must be compared by identity, not equality.
      // Empty maps (`Map.apply()`) are not bound to a val. `val xN = Map.apply()` would be
      // inferred as Map[Nothing, Nothing], and passing xN where a Map[K, V] is expected would
      // then fail to compile.
      val handledInstances = new collection.mutable.ListBuffer[(AnyRef, Int, Tree, Tree)]
      def handleInstance[T <: AnyRef](x: T, gen: => Tree): Tree =
        handledInstances.find(_._1 eq x).map(_._3).getOrElse {
          // Run the generator first: it recursively lifts (and registers) the children. Only
          // afterwards read handledInstances.size, so that children get smaller indices.
          val genDone = gen
          val newIndex = handledInstances.size
          genDone match {
            case Apply(Select(Ident(_map), _apply), Nil)
              if _map.toString == "Map" && _apply.toString == "apply" => genDone
            case _ =>
              val ret = Ident("x" + newIndex)
              handledInstances += ((x, newIndex, ret, ValDef(Modifiers(), newTermName("x" + newIndex), TypeTree(), genDone)))
              ret
          }
        }
      // Wraps the `val xN` definitions, in dependency order, around the root's lifted tree.
      // The root need not be registered itself (constants, None, modules, empty maps).
      def getResultTree(root: Tree): Tree =
        Block(handledInstances.toList.sortBy(_._2).map(_._4), root)
      null // last expression, ignored
    }.tree match {
      case Block(stats, _) => stats // take the block's statements; recursive() must see them
    }

    // Generated-code snippets shared by several cases. Each call returns a fresh tree.
    def recursiveOn(arg: Tree): Tree = Apply(Ident("recursive"), List(arg))
    // `coll.map(v => body)`, where `body` may refer to `v`.
    def mapEach(coll: Tree, body: Tree): Tree =
      Apply(Select(coll, "map"), List(
        Function(List(ValDef(Modifiers(Flag.PARAM), newTermName("v"), TypeTree(), EmptyTree)), body)))
    def mapRecursive(coll: Tree): Tree = mapEach(coll, recursiveOn(Ident("v")))
    // Pattern `x: tpe`.
    def bindTyped(tpe: Type): Tree = Bind(newTermName("x"), Typed(Ident("_"), TypeTree(tpe)))
    def callHandle(gen: Tree) = Apply(Ident("handleInstance"), List(Ident("x"), gen))
    // Case body reporting "Unsupported type: <run-time class of x>" via mh.treeThrow.
    def throwUnsupported = mh.treeThrow(mh.treeConcatStrings(
      Literal(Constant("Unsupported type: ")),
      Apply(Select(Apply(Select(Ident("x"), "getClass"), Nil), "getCanonicalName"), Nil)
    ))

    // The `case` clause(s) of the generated matcher for one collected type.
    def casesFor(t: Type): List[CaseDef] = {
      def caseDef(body: Tree) = CaseDef(bindTyped(t), EmptyTree, body)
      if (isConstant(t)) {
        List(caseDef(_Literal(_Constant(Ident("x")))))
      } else if (t <:< typeOf[xml.NodeSeq]) {
        List(
          // Emits xml.Text(x.data).
          CaseDef(bindTyped(typeOf[xml.Text]), EmptyTree, callHandle(_Apply(
            _xmlMember("Text"),
            _List(_Literal(_Constant(Select(Ident("x"), "data"))))
          ))),
          // Emits xml.Elem(null, x.label, xml.Null, xml.TopScope, x.minimizeEmpty, <lifted children>).
          // The argument list is generated as `children.:::(fixedArgs)`. Because `:::` is
          // right-associative, that is `fixedArgs ::: children`.
          CaseDef(bindTyped(typeOf[xml.Elem]), EmptyTree, callHandle(_Apply(
            _xmlMember("Elem"),
            Apply(
              Select(mapRecursive(Select(Select(Ident("x"), "child"), "toList")), "$colon$colon$colon"),
              List(_List(
                _Literal(_Constant(Literal(Constant(null)))),
                _Literal(_Constant(Select(Ident("x"), "label"))),
                _xmlMember("Null"),
                _xmlMember("TopScope"),
                _Literal(_Constant(Select(Ident("x"), "minimizeEmpty")))
              ))
            )
          ))),
          // Other leaf nodes are unsupported: comments, entity refs, processing instructions and
          // non-Text atoms. This case must precede the NodeSeq one. Such a node is a NodeSeq
          // containing only itself, so the NodeSeq case would recurse on it forever.
          CaseDef(bindTyped(typeOf[xml.SpecialNode]), EmptyTree, throwUnsupported),
          // Several siblings on the top level: emits xml.NodeSeq.fromSeq(Seq(<lifted siblings>)).
          CaseDef(bindTyped(typeOf[xml.NodeSeq]), EmptyTree, callHandle(_Apply(
            _Select(_xmlMember("NodeSeq"), _newTermName("fromSeq")),
            _List(_Apply(_companionApply("Seq"), mapRecursive(Select(Ident("x"), "toList"))))
          )))
        )
      } else if (t <:< weakTypeOf[Option[_]]) {
        List(
          CaseDef(Ident("None"), EmptyTree, _Ident(_newTermName("None"))),
          CaseDef(bindTyped(weakTypeOf[Some[_]]), EmptyTree, callHandle(_Apply(
            _companionApply("Some"),
            _List(recursiveOn(Select(Ident("x"), "get")))
          )))
        )
      } else if (t =:= weakTypeOf[List[_]]) {
        // Emits List(<each element lifted>).
        List(caseDef(callHandle(_Apply(_companionApply("List"), mapRecursive(Ident("x"))))))
      } else if (t =:= weakTypeOf[Map[_, _]]) {
        // Emits Map(Tuple2(<key lifted>, <value lifted>), ...).
        List(caseDef(callHandle(_Apply(
          _companionApply("Map"),
          mapEach(Select(Ident("x"), "toList"), _Apply(
            _companionApply("Tuple2"),
            _List(recursiveOn(Select(Ident("v"), "_1")), recursiveOn(Select(Ident("v"), "_2")))
          ))
        ))))
      } else if (isNonAbstractCaseClass(t)) {
        // Emits New(TypeTree(typeOf[T]), List(List(<x.field1 lifted>, <x.field2 lifted>, ...))).
        val s = t.typeSymbol.asClass
        val paramss = t.declaration(nme.CONSTRUCTOR).asMethod.paramss
        assert(paramss.size == 1, "Class must have one parameter list: " + s.fullName)
        List(caseDef(callHandle(Apply(Ident("New"), List(
          Apply(Ident("TypeTree"), List(TypeApply(Ident("typeOf"), List(Ident(s))))),
          _List(_List(paramss.head.map(p => recursiveOn(Select(Ident("x"), p.name))): _*))
        )))))
      } else if (t.typeSymbol.isModuleClass) {
        // Modules are referenced by full name, not rebuilt. The pattern is the dotted name
        // itself. The body builds the same dotted name through the `_Ident`/`_Select` helpers.
        val fullName = t.typeSymbol.fullName
        val parts = fullName.split('.').toList
        val lifted = parts.tail.foldLeft(_Ident(_newTermName(parts.head)): Tree) {
          (qualifier, part) => _Select(qualifier, _newTermName(part))
        }
        List(CaseDef(mh.treeDottedName(fullName), EmptyTree, lifted))
      } else {
        c.abort(c.enclosingPosition, "Unsupported type: " + t.typeSymbol.fullName)
      }
    }

    // Generated: def recursive(o: Any): Tree = o match { <cases>; case x => <throw unsupported> }
    // Extracted into a val to keep IDEA's highlighting sane.
    val recursiveDef: Tree = DefDef(
      Modifiers(),
      newTermName("recursive"),
      List(),
      List(List(ValDef(Modifiers(Flag.PARAM), newTermName("o"), TypeTree(typeOf[Any]), EmptyTree))),
      Ident(newTypeName("Tree")),
      Match(
        Ident("o"),
        typesToHandle.flatMap(casesFor) :+ CaseDef(Bind(newTermName("x"), Ident("_")), EmptyTree, throwUnsupported)
      )
    )

    // Generated: getResultTree(recursive(o)). The argument is evaluated first, which fills
    // handledInstances before getResultTree reads it.
    val liftRoot: Tree = Apply(Ident("getResultTree"), List(recursiveOn(o.tree)))

    Block(
      anyrefTrees :+ recursiveDef,
      // Generated: c.Expr[O](getResultTree(recursive(o)).asInstanceOf[c.universe.Tree]).
      // The cast is needed because c2.universe.Tree is not statically c.universe.Tree.
      Apply(
        TypeApply(Select(Ident("c"), "Expr"), List(TypeTree(c.weakTypeOf[O]))),
        List(TypeApply(
          Select(liftRoot, "asInstanceOf"),
          List(Select(Select(Ident(newTermName("c")), newTermName("universe")), newTypeName("Tree")))
        ))
      )
    )
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Original discussion (with scaladoc & usage example) is here: https://groups.google.com/forum/#!topic/scala-user/BSHLDglW0OE