Last active
June 15, 2016 01:14
-
-
Save yilinwei/27a5b22c346c1b6957041957adf96d90 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package iliad | |
package platform | |
import scala.annotation.StaticAnnotation | |
import scala.language.experimental.macros | |
/** Macro annotation that generates a Free-monad algebra for the annotated trait.
  *
  * Expansion is delegated to `FreeMacro.mkTransform`. The type argument `T` names the
  * interface whose methods become the algebra's commands (presumably read by the macro
  * through `annotatedType` — confirm in `SymbolMacro`).
  */
final class free[T] extends StaticAnnotation {

  /** Macro entry point invoked by the compiler on the annotated definitions. */
  def macroTransform(annottees: Any*): Any =
    macro FreeMacro.mkTransform
}
import scala.reflect.macros.whitebox | |
/** Macro bundle behind the `@free` annotation.
  *
  * From the annotated trait `tpt[_]` and the interface type returned by `annotatedType`
  * (inherited from `SymbolMacro`), it generates the trait body plus a companion object holding:
  *  - one case class per interface method, carrying the method's arguments followed by a
  *    continuation (`next: Next` when the method returns `Unit`, `onNext: R => Next` otherwise);
  *  - for overloaded methods, index-suffixed case classes plus an object of `apply` overloads;
  *  - `run` / `runner`, which interpret a command by calling an implementation of the interface;
  *  - an implicit `cats.Functor` instance.
  */
final class FreeMacro(val c: whitebox.Context) extends SymbolMacro {
  import c.universe._

  /** Members of `Object`; these are never turned into commands. */
  lazy val _objectMethods: Set[Symbol] = weakTypeOf[Object].members.toSet

  /** An interface method together with the naming information used for generated code.
    *
    * @param uniqueName base name of the generated case class (suffixed with an index when overloaded)
    * @param methodName name of the interface method
    * @param method     symbol of the interface method
    * @param returnUnit whether the method returns `Unit`
    */
  case class CaseMethod(uniqueName: String, methodName: TermName, method: MethodSymbol, returnUnit: Boolean)

  /** Whether `m` returns `Unit`.
    *
    * Types are compared with `=:=` rather than `==`: `==` is not type equivalence and can
    * report `false` for equivalent types (e.g. an alias of `Unit`).
    */
  def returnsUnit(m: MethodSymbol): Boolean = m.returnType =:= typeOf[Unit]

  /** The continuation field appended to each command case class. */
  def nextParam(m: MethodSymbol): Tree =
    if (returnsUnit(m)) q"val next: Next" else q"val onNext: ${m.returnType} => Next"

  /** References to all of `m`'s parameters (every parameter list flattened), for forwarding. */
  def asArgs(m: MethodSymbol): Seq[Tree] = m.paramLists.flatMap(_.map(p => q"${p.asTerm.name}"))

  /** Reference to the continuation field declared by [[nextParam]]. */
  def nextArg(m: MethodSymbol): Tree = if (returnsUnit(m)) q"next" else q"onNext"

  /** Pattern binders `arg0 @ _, ..., argN @ _` (that is, `n + 1` binders). */
  def unapplyParams(n: Int): List[Bind] =
    // We can't use cq here because the wildcard for the unapply needs to be bounded.
    (0 to n).map(i => Bind(TermName("arg" + i), Ident(termNames.WILDCARD))).toList

  /** Reference to the binder `argN`. */
  def rangeParam(n: Int): Tree = {
    val name = TermName("arg" + n)
    q"$name"
  }

  /** References `arg0, ..., argN` (empty when `n < 0`). */
  def rangeParams(n: Int): List[Tree] = (0 to n).map(rangeParam).toList

  // Have to make case statements, functor and execute instance.

  /** Builds trees for every interface method, grouped by name.
    *
    * A method whose name is unique is passed to `single`. Overloaded methods are numbered
    * (`name0`, `name1`, ...), each is passed to `poly`, and `polyAggregate` combines the
    * resulting trees (it may add extra trees of its own).
    */
  def caseMethodsFold(single: CaseMethod => Tree, poly: CaseMethod => Tree, polyAggregate: (Seq[CaseMethod], Seq[Tree]) => Seq[Tree])(methodsByName: Map[TermName, List[MethodSymbol]]): Seq[Tree] =
    methodsByName.toSeq.flatMap { case (tn, ms) =>
      val mn = tn.decodedName.toString
      ms match {
        case m :: Nil =>
          Seq(single(CaseMethod(mn, TermName(mn), m, returnsUnit(m))))
        case _ =>
          val polyCases = ms.zipWithIndex.map { case (m, idx) =>
            val cm = CaseMethod(s"$mn$idx", TermName(mn), m, returnsUnit(m))
            (cm, poly(cm))
          }
          polyAggregate(polyCases.map(_._1), polyCases.map(_._2))
      }
    }

  /** Upper-cases the first character of `str`. */
  def capitalize(str: String): String = editFirst(_.toUpper, str)

  /** Lower-cases the first character of `str`. */
  def lowerFirst(str: String): String = editFirst(_.toLower, str)

  /** Applies `f` to the first character of `str`; the empty string is returned unchanged. */
  def editFirst(f: Char => Char, str: String): String =
    if (str.isEmpty) str else f(str.head).toString + str.tail

  /** `case class Cm[Next](<method params>, <continuation>) extends tpe[Next]`. */
  def caseDef(tpe: TypeName, cm: CaseMethod): Tree = {
    val next = nextParam(cm.method)
    val params = methodSymbolParamTree(cm.method) :+ next
    q"case class ${TypeName(capitalize(cm.uniqueName))}[Next](..$params) extends $tpe[Next]"
  }

  /** Match case used by `run`: calls the interface method on `runner` with the bound
    * arguments, then continues with `next` (Unit methods) or `onNext(result)`.
    *
    * NOTE(review): only the first parameter list is bound here, so methods with several
    * parameter lists are presumably unsupported — confirm against `methodSymbolParamTree`.
    */
  def mkExecStatement(tpe: TypeName, cm: CaseMethod): Tree = {
    val cs = TermName(capitalize(cm.uniqueName))
    val params = cm.method.paramLists.head.length
    val bind = unapplyParams(params)
    val unapply = pq"$cs(..$bind)"
    val args = rangeParams(params - 1)
    val last = rangeParam(params)
    val next = if (cm.returnUnit) last else q"$last.apply(result)"
    cq"""$unapply =>
      val result = runner.${TermName(cm.methodName.decodedName.toString)}(..$args)
      $next
    """
  }

  /** Case classes for every method. For an overloaded method it also emits an object named
    * after the method, with one `apply` overload per variant that forwards to the companion's
    * numbered case class.
    */
  def caseDefs(tpe: TypeName) = caseMethodsFold(caseDef(tpe, _), caseDef(tpe, _), (cases, trees) => {
    val objectBody = cases.map {
      case CaseMethod(un, mn, m, ru) =>
        val next = nextParam(m)
        val params = methodSymbolParamTree(m) :+ next
        val cn = TypeName(capitalize(un))
        q"def apply[Next](..$params): $cn[Next] = ${tpe.toTermName}.${cn.toTermName}.apply(..${asArgs(m) ++ Seq(nextArg(m))})"
    }
    trees ++ Seq(q"""
      object ${TermName(capitalize(cases.head.methodName.decodedName.toString))} {
        ..$objectBody
      }
    """)
  }) _

  /** Match cases for `run`, one per command. */
  def execDefs(tpe: TypeName) = caseMethodsFold(mkExecStatement(tpe, _), mkExecStatement(tpe, _), (_, trees) => trees) _

  /** Match case for `Functor.map`: rebuilds the command with its continuation mapped by `f`. */
  def mkFunctorStatement(cm: CaseMethod): Tree = {
    val cs = TermName(capitalize(cm.uniqueName))
    val params = cm.method.paramLists.head.length
    val bind = unapplyParams(params)
    val unapply = pq"$cs(..$bind)"
    val last = rangeParam(params)
    val args = rangeParams(params - 1) :+ (if (cm.returnUnit) q"f($last)" else q"$last.andThen(f)")
    cq"$unapply => $cs(..$args)"
  }

  /** Match cases for `Functor.map`, one per command. */
  def functorCases = caseMethodsFold(mkFunctorStatement, mkFunctorStatement, (_, trees) => trees) _

  /** Assembles the rewritten trait `tpt` and its companion object.
    *
    * @param tpt          name of the annotated trait
    * @param existingBody statements of the user-written companion object, appended unchanged
    */
  def mkFree(tpt: TypeName, existingBody: List[Tree]): Tree = {
    val tpe = annotatedType
    /* Not entirely sure why, but vals generated from the type annotation don't seem to have all
     * the type information that is expected, which means we need to check the paramLists... */
    val methods = tpe.members
      .filter(t => t.isMethod && !_objectMethods.contains(t) && t.asMethod.paramLists.nonEmpty && t.name != TermName("$init$"))
      .groupBy(_.name)
      .map { case (mn, ms) => mn.toTermName -> ms.map(_.asMethod).toList }
    q"""
      trait $tpt[Next] {
        def run(runner: $tpe): Next = ${tpt.toTermName}.run(this, runner)
        def toFree: _root_.cats.free.Free[$tpt, Next] = _root_.cats.free.Free.liftF(this)
      }
      object ${tpt.toTermName} {
        ..${caseDefs(tpt)(methods)}
        def run[A](fa: $tpt[A], runner: $tpe): A = fa match {
          case ..${execDefs(tpt)(methods)}
        }
        def runner(runner: $tpe): _root_.cats.~>[$tpt, _root_.cats.Id] = new (_root_.cats.~>[$tpt, _root_.cats.Id]) {
          def apply[A](fa: $tpt[A]): _root_.cats.Id[A] = run(fa, runner)
        }
        implicit val ${TermName(s"${lowerFirst(tpt.decodedName.toString)}Functor")}: _root_.cats.Functor[$tpt] = new _root_.cats.Functor[$tpt] {
          def map[A, B](fa: $tpt[A])(f: A => B): $tpt[B] = fa match { case ..${functorCases(methods)} }
        }
        ..$existingBody
      }
    """
  }

  /** Macro entry point: accepts a single-type-parameter trait with an empty body, optionally
    * followed by its companion object, and aborts compilation for any other shape.
    */
  def mkTransform(annottees: Expr[Any]*): Tree = {
    annottees.map(_.tree) match {
      case List(q"abstract trait $tpt[$_]") => mkFree(tpt, List())
      case List(q"abstract trait $tpt[$_]", q"object $_ { ..$body }") => mkFree(tpt, body)
      case _ => c.abort(c.enclosingPosition, "Cannot free trait as annotated type is not in the expected format")
    }
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package iliad | |
import shapeless._ | |
import scala.annotation.implicitNotFound | |
import scala.reflect.ClassTag | |
/** Dispatches a value of type `F[A]` to the `TC` instance of one of the candidate type
  * constructors listed in the HList `L`, then hands both to a callback.
  *
  * The instances in the companion object pick the candidate by comparing runtime classes.
  *
  * @tparam L  candidate list, shaped `H1[A] :: H2[A] :: ... :: HNil`
  * @tparam F  common type constructor of the candidates
  * @tparam TC type class whose instance is looked up per candidate
  * @tparam A  type argument of the dispatched value
  * @tparam B  result type of the callback
  */
@implicitNotFound("Cannot find PolyTC typeclass for ${TC} in list ${L}.")
trait PolyTC[L <: HList, F[_], TC[_[_]], A, B] {

  /** Calls `f` with `a` and the `TC` instance belonging to the candidate matching `a`. */
  def apply(a: F[A])(f: (F[A], TC[F]) => B): B
}
/** Instances of [[PolyTC]], derived by walking the HList of candidates from head to tail. */
object PolyTC {

  /** Plain alias for [[PolyTC]] (no additional type refinement). */
  type Aux[L <: HList, F[_], TC[_[_]], A, B] = PolyTC[L, F, TC, A, B]

  /** End of the candidate list: reaching it means no candidate's class matched the value.
    *
    * @throws IllegalStateException always, naming the runtime class of the unmatched value
    */
  implicit def base[F[_], TC[_[_]], A, B]: Aux[HNil, F, TC, A, B] = new Aux[HNil, F, TC, A, B] {
    def apply(a: F[A])(f: (F[A], TC[F]) => B): B = {
      // Report the value that failed to match; interpolating the callback `f` (as before)
      // only printed the function object, which says nothing about the failure.
      val unmatched = Option(a).fold("null")(_.getClass.getName)
      throw new IllegalStateException(s"Could not find class matching $unmatched within poly list!")
    }
  }

  /** Candidate `H` at the head of the list: if the value's runtime class is exactly the
    * erased class of `H`, the callback receives `H`'s `TC` instance; otherwise the search
    * continues with the tail of the list.
    */
  implicit def recurse[H[_] <: F[_], T <: HList, F[_], TC[_[_]], A, B](implicit tc: TC[H], tail: Aux[T, F, TC, A, B], ct: ClassTag[H[_]]): Aux[H[A] :: T, F, TC, A, B] = new Aux[H[A] :: T, F, TC, A, B] {
    def apply(fa: F[A])(g: (F[A], TC[F]) => B): B =
      // Exact class comparison: instances of subclasses of H do not match this candidate.
      if (fa.getClass == ct.runtimeClass)
        // TC is invariant, so TC[H] cannot be widened to TC[F] statically; the runtime class
        // check above is what justifies this cast.
        g(fa, tc.asInstanceOf[TC[F]])
      else
        tail(fa)(g)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment