Created
October 22, 2021 15:03
-
-
Save jodersky/3916b90a18a1ee2f60891ce82077bbb9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Macro helper that turns a term containing `container.apply(...)` calls (where
  * `container.tpe <:< Container[_]`) into a lambda whose parameters stand in for the
  * results of those calls.
  *
  * @param qctx the quotes context of the macro expansion this extractor runs in
  */
class FunctionExractor(using qctx: Quotes) {
  import qctx.reflect._

  private case class ToReplace(
    outerSymbol: Symbol, // where the container originally comes from, used to collect inputs to the dep
    innerTpe: TypeRepr // the original container's partition type
  )

  /** Matches a call `container.apply(...)` whose receiver conforms to `Container[_]`.
    *
    * Yields the whole call and the receiver. Shared by the collecting and the
    * replacing traversal so both always recognise exactly the same trees; the
    * positional pairing of calls and lambda parameters relies on that.
    */
  private object ContainerApply {
    def unapply(tree: Tree): Option[(Term, Term)] = tree match {
      case call @ Apply(Select(container, "apply"), _)
          if container.tpe <:< TypeRepr.of[Container[_]] =>
        Some((call, container))
      case _ =>
        None
    }
  }

  /**
   * Given
   *
   *   {
   *     a1()
   *     a2()
   *     ...
   *     a3()
   *   }
   *
   * where aX <:< Container
   *
   * transforms to
   *
   *   (a1, a2, a3) => a1...a2...a3
   *
   * @param term the term to rewrite
   * @return a triple of (the container calls found, in traversal order; a block
   *         defining a method named `extract` followed by a closure over it;
   *         the type of `term`)
   */
  private def extract(term: Term): (List[ToReplace], Block, TypeRepr) = {
    // Collects one entry per container call. A matched call is not descended into,
    // so container calls nested in its receiver or arguments are not collected.
    class Collector() extends TreeAccumulator[List[ToReplace]] {
      def foldTree(x: List[ToReplace], tree: Tree)(owner: Symbol): List[ToReplace] = tree match {
        case ContainerApply(call, container) =>
          ToReplace(container.symbol, call.tpe) :: x
        case _ =>
          super.foldOverTree(x, tree)(owner)
      }
    }

    // Replaces every container call with a reference to the next method parameter.
    // The iterator is consumed in traversal order, which has to agree with the
    // order in which Collector found the calls.
    class Replacer(paramss: List[List[Tree]]) extends TreeMap {
      private val params = paramss.flatten
      private val it = params.iterator
      override def transformTerm(tree: Term)(owner: Symbol): Term = tree match {
        case ContainerApply(_, _) =>
          Ref(it.next().symbol)
        case _ =>
          super.transformTerm(tree)(owner)
      }
    }

    // Entries are prepended during traversal; reverse to restore traversal order.
    val replacements = Collector().foldTree(Nil, term)(Symbol.spliceOwner).reverse

    val paramNames = replacements.indices.map(idx => s"param$idx").toList

    val mt = MethodType(paramNames)(
      _ => replacements.map(_.innerTpe),
      _ => term.tpe
    )

    val methodSym = Symbol.newMethod(
      Symbol.spliceOwner,
      "extract",
      mt
    )

    val dd = DefDef(
      methodSym,
      paramss =>
        // The body now lives inside `methodSym`, so definitions nested in it must
        // be re-owned; otherwise the resulting tree has inconsistent owners.
        Some(Replacer(paramss).transformTerm(term.changeOwner(methodSym))(methodSym))
    )

    val cl = Closure(Ref(dd.symbol), None)
    val b = Block(List(dd), cl)
    (replacements, b, term.tpe)
  }

  /** Rewrites `body` into a lambda over its container calls.
    *
    * Prints the type of `body` and the generated lambda to stderr during expansion
    * (debug output).
    *
    * @param body the expression to rewrite
    * @return the generated lambda as an untyped expression
    */
  def impl(body: Expr[_]): Expr[Any] = {
    val (_, lambda, rtpe) = extract(body.asTerm)
    System.err.println(rtpe.show)
    System.err.println(lambda.asExpr.show)
    lambda.asExpr
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment