Created
November 5, 2015 16:24
-
-
Save msiegenthaler/7b5eb8d833c0e0f74f3b to your computer and use it in GitHub Desktop.
Macro that auto-creates lifted functions for Free monads.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package free | |
import scala.annotation.{StaticAnnotation, compileTimeOnly} | |
import scala.language.experimental.macros | |
import scala.language.higherKinds | |
import scala.reflect.macros.whitebox | |
/**
 * Macro annotation that adds, to the annotated class or object, a Free monad type alias,
 * an implicit monad instance for it and one lifting function per operation case class.
 *
 * Usage:
 * {{{
 * sealed trait Op[+A]
 * case class MyOp(a: String) extends Op[Unit]
 * @AddLiftingFunctions[Op]('Mon) object monadic
 * import monadic._
 * val a: Mon[Unit] = myOp("hello")
 * }}}
 *
 * Expansion requires the macro paradise compiler plugin. If the expansion reports that `Op`
 * has no subclasses (a compilation-order problem), use `FreeMacro.liftFunctions` instead.
 *
 * @tparam Op sealed trait of the operations
 * @param typeName symbol literal naming the generated Free monad type alias (e.g. `'Mon` yields `Mon[A]`)
 */
@compileTimeOnly("enable macro paradise to expand macro annotations")
class AddLiftingFunctions[Op[_]](typeName: Symbol) extends StaticAnnotation {
  // Expanded at compile time by FreeMacro.addLiftFunctionsAnnotation_impl, which reads the type
  // argument and `typeName` back from the annotation tree rather than from these parameters.
  def macroTransform(annottees: Any*): Any = macro FreeMacro.addLiftFunctionsAnnotation_impl
}
/**
 * Def-macro entry points that generate the Free monad helpers as an anonymous object.
 * This is the alternative to [[AddLiftingFunctions]] when the annotation cannot yet see the
 * subclasses of the operation trait (compilation-order problem).
 *
 * Usage:
 * {{{
 * sealed trait Op[+A]
 * case class MyOp(a: String) extends Op[Unit]
 * val monadic = FreeMacro.liftFunctions[Op]('Mon)
 * import monadic._
 * val a: Mon[Unit] = myOp("hello")
 * }}}
 */
object FreeMacro {
  /**
   * Expands to `new { ... }` containing the type alias named by `typeName`, an implicit monad
   * instance and one plain lifting method per known direct subclass of `F`.
   *
   * @tparam F sealed trait of the operations
   * @param typeName symbol literal naming the generated Free monad type alias
   */
  def liftFunctions[F[_]](typeName: Symbol): Any = macro FreeMacro.liftFunctions_impl[F]
  /**
   * Variant of [[liftFunctions]] whose implementation (`liftFunctionsVampire_impl`) requests the
   * lifting functions as "vampire" macro methods, which avoid the structural-type warning when
   * members of the anonymous object are called.
   *
   * @tparam F sealed trait of the operations
   * @param typeName symbol literal naming the generated Free monad type alias
   */
  def liftFunctionsVampire[F[_]](typeName: Symbol): Any = macro FreeMacro.liftFunctionsVampire_impl[F]
}
//Private stuff below: implementation details, not meant to be used directly.

/**
 * Marker annotation placed on generated "vampire" macro methods. Its single argument is the
 * companion of the operation case class; the vampire macro implementations in the `FreeMacro`
 * bundle read it back from the annotation tree at expansion time.
 *
 * Vampire-body technique, see http://meta.plasm.us/posts/2013/07/12/vampire-methods-for-structural-types/
 *
 * @param tree the operation case class companion (only inspected as a tree, never evaluated at runtime)
 */
class vampire(tree: Any) extends StaticAnnotation
/**
 * Macro bundle behind [[AddLiftingFunctions]] and the `FreeMacro` def macros.
 *
 * For a sealed operation trait `Op[A]` it generates
 *  - `type <Name>[A] = cats.free.Free[Coyoneda[Op, ?], A]`,
 *  - `implicit val monad`, obtained from `cats.free.Free.freeMonad`,
 *  - one lifting function per known direct subclass, e.g. `myOp(text: String): <Name>[Unit]` for
 *    `case class MyOp(text: String) extends Op[Unit]`, delegating to `_root_.free.Free2.liftFC`
 *    (defined elsewhere in package `free`).
 */
class FreeMacro(val c: whitebox.Context) {
  import c.universe._

  /** Implementation of `FreeMacro.liftFunctions`: lifting functions are plain methods. */
  def liftFunctions_impl[F[_]](typeName: Expr[Any])(implicit t: c.WeakTypeTag[F[_]]) =
    generateAnonClass[F](typeName, vampire = false)

  /** Implementation of `FreeMacro.liftFunctionsVampire`: lifting functions are vampire macro methods. */
  def liftFunctionsVampire_impl[F[_]](typeName: Expr[Any])(implicit t: c.WeakTypeTag[F[_]]) =
    generateAnonClass[F](typeName, vampire = true)

  /**
   * Expands to `new { ... }` holding the generated type alias, monad instance and lifting functions.
   *
   * @param typeNameExpr symbol literal naming the Free monad type alias
   * @param vampire      whether the lifting functions are emitted as vampire macro methods
   */
  private def generateAnonClass[F[_]](typeNameExpr: Expr[Any], vampire: Boolean)(implicit t: c.WeakTypeTag[F[_]]) = {
    val typeName = TermName(symbolLiteralName(typeNameExpr.tree))
    // Honour the caller's choice: liftFunctionsVampire_impl depends on this flag reaching `generate`.
    val mod = generate(typeName, t.tpe.typeSymbol, vampire)
    c.Expr(q"new { ..$mod }")
  }

  /**
   * Implementation of the [[AddLiftingFunctions]] macro annotation: adds the generated members to
   * the annotated class/trait or object and leaves any further annottees untouched.
   */
  def addLiftFunctionsAnnotation_impl(annottees: Expr[Any]*): Expr[Any] = {
    // Recover `Op` and `'Mon` from the untyped annotation application `new AddLiftingFunctions[Op]('Mon)`.
    val q"new $_[$opIdent](${typeNameTree: Tree}).macroTransform(..$_)" = c.macroApplication
    // The type argument is still untyped here; typecheck a dummy expression to resolve it to a symbol.
    val opBase = c.typecheck(q"???.asInstanceOf[$opIdent[Unit]]").tpe.typeSymbol
    val typeName = TermName(symbolLiteralName(typeNameTree))
    val mod = annottees.map(_.tree).toList match {
      case ClassDef(mods, name, tparams, Template(parents, self, body)) :: rest ⇒ // class/trait
        // Insert the generated members after the first body element (presumably the primary constructor).
        val (initBody, restBody) = body.splitAt(1)
        val t2 = Template(parents, self, initBody ++ generate(typeName, opBase) ++ restBody)
        ClassDef(mods, name, tparams, t2) :: rest
      case ModuleDef(mods, name, Template(parents, self, body)) :: rest ⇒ // object
        val t2 = Template(parents, self, generate(typeName, opBase) ++ body)
        ModuleDef(mods, name, t2) :: rest
      case _ ⇒
        // Anything else, including an empty annottee list, is rejected with a readable error.
        c.abort(c.enclosingPosition, "AddLiftingFunctions annotation only supported on classes and objects")
    }
    c.Expr(q"..$mod")
  }

  /**
   * Extracts `"Mon"` from the tree of the symbol literal `'Mon` (i.e. `scala.Symbol("Mon")`).
   * Aborts compilation with a readable message for any other tree.
   */
  private def symbolLiteralName(tree: Tree): String = tree match {
    case Apply(_, Literal(Constant(name: String)) :: Nil) ⇒ name
    case other ⇒ c.abort(c.enclosingPosition, s"Expected a symbol literal such as 'Mon but got: $other")
  }

  /**
   * Generates the type alias, the implicit monad instance and one lifting function per known
   * direct subclass of `opBase`.
   *
   * @param name       name of the Free monad type alias to generate
   * @param opBase     the sealed operation base trait
   * @param useVampire emit the lifting functions as vampire macro methods instead of plain methods
   * @return the generated member definitions
   */
  private def generate(name: Name, opBase: Symbol, useVampire: Boolean = false): List[Tree] = {
    val freeTypeName = name.toTypeName
    val freeTypeTree =
      q"""type $freeTypeName[A] =
        _root_.cats.free.Free[({ type λ[α] = _root_.cats.free.Coyoneda[${opBase.asType}, α]})#λ, A]"""
    val monadDeclTree =
      q"""implicit val monad =
        _root_.cats.free.Free.freeMonad[({ type λ[α] = _root_.cats.free.Coyoneda[${opBase.asType}, α]})#λ]"""
    val opClass = opBase.asClass
    if (!opClass.isSealed)
      c.abort(c.enclosingPosition, s"The base class ${opBase.name} of the free monad is not sealed")
    // Only subclasses the compiler has already seen are reported here, hence the compilation-order hint below.
    val subclasses = opClass.knownDirectSubclasses.toList
    if (subclasses.isEmpty)
      c.abort(c.enclosingPosition, s"The base class ${opBase.name} of the free monad has no subclasses. " +
        "If you're sure you have subclasses and use @AddLiftingFunctions then this is a compilation order problem. " +
        "In that case please use FreeMacro.liftFunctions.")
    val functions = subclasses.map {
      case s: ClassSymbol ⇒
        forImplementation(opBase.asType.toType, freeTypeName, useVampire)(s)
      case other ⇒
        c.abort(c.enclosingPosition, s"Unsupported subclass ${other.name} of ${opBase.name}")
    }
    freeTypeTree :: monadDeclTree :: functions
  }

  /**
   * Creates `myOp(text: String): FT[Unit]` from `case class MyOp(text: String) extends Op[Unit]`.
   * Inspired by https://gist.github.com/travisbrown/43c9dc072bfb2bba2611
   *
   * @param base       the operation base type (`Op`)
   * @param freeType   name of the generated Free monad type alias (`FT`)
   * @param useVampire generate a vampire macro method instead of a plain method
   * @param opImpl     the concrete operation case class
   */
  private def forImplementation(base: Type, freeType: TypeName, useVampire: Boolean)(opImpl: ClassSymbol): Tree = {
    val name = TermName(classNameFunctionName(opImpl.name.toString))
    //TODO handle MyOperation[A] extends Op[A]
    // The operation's result type, e.g. Unit for `MyOp(...) extends Op[Unit]`.
    val A = opImpl.typeSignature.baseType(base.typeSymbol).typeArgs.head
    val companion = opImpl.companion
    val params = caseClassFields(opImpl.typeSignature)
    if (useVampire) {
      // One bundle implementation exists per arity (vampire0_impl .. vampire5_impl).
      if (params.size > 5) c.abort(c.enclosingPosition, s"More parameters in ${companion.name} than supported " +
        "by the FreeMacro. Please tell the maintainer to extend it.")
      // Parameters are named in1..inN to line up with the vampireN_impl parameter names.
      val paramDefs = params.zipWithIndex.map {
        case ((_, tpe), index) ⇒
          val paramName = TermName("in" + (index + 1))
          q"""$paramName: $tpe"""
      }
      val vampireImpl = TermName(s"vampire${paramDefs.size}_impl")
      q"""@_root_.free.vampire($companion)
          def $name(..$paramDefs): $freeType[$A] = macro _root_.free.FreeMacro.$vampireImpl"""
    } else {
      val paramNames = params.map(_._1)
      val paramDefs = params.map { p ⇒ q"""${p._1}: ${p._2}""" }
      q"""def $name(..$paramDefs): $freeType[$A] = _root_.free.Free2.liftFC($companion(..$paramNames))"""
    }
  }

  // Vampire macro implementations, one per arity. Each expands a call to a generated vampire method
  // into `Free2.liftFC(<companion>(args...))`, taking the companion from the method's @vampire
  // annotation. Being macros, the calls avoid the structural type warning of plain anonymous-class members.
  def vampire0_impl() =
    q"_root_.free.Free2.liftFC($companionFromVampire())"
  def vampire1_impl(in1: Expr[Any]) =
    q"_root_.free.Free2.liftFC($companionFromVampire($in1))"
  def vampire2_impl(in1: Expr[Any], in2: Expr[Any]) =
    q"_root_.free.Free2.liftFC($companionFromVampire($in1, $in2))"
  def vampire3_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any]) =
    q"_root_.free.Free2.liftFC($companionFromVampire($in1, $in2, $in3))"
  def vampire4_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any], in4: Expr[Any]) =
    q"_root_.free.Free2.liftFC($companionFromVampire($in1, $in2, $in3, $in4))"
  def vampire5_impl(in1: Expr[Any], in2: Expr[Any], in3: Expr[Any], in4: Expr[Any], in5: Expr[Any]) =
    q"_root_.free.Free2.liftFC($companionFromVampire($in1, $in2, $in3, $in4, $in5))"

  /** The companion stored in `@vampire(<companion>)`: the first argument of the annotation's constructor call. */
  private def companionFromVampire = macroAnnotation[vampire].tree.children.tail.head

  /** Returns the annotation of type `T` on the macro method being expanded; aborts compilation if it is absent. */
  private def macroAnnotation[T](implicit t: WeakTypeTag[T]): Annotation =
    c.macroApplication.symbol.annotations
      .find(_.tree.tpe <:< t.tpe)
      .getOrElse(c.abort(c.enclosingPosition, s"Annotation ${t.tpe.typeSymbol.name} not found."))

  /** Converts MyOperation to myOperation (lower-cases only the first character; empty input is returned as is). */
  private def classNameFunctionName(className: String): String =
    if (className.isEmpty) className
    else className.head.toLower.toString + className.tail

  /**
   * Extracts [(text, String), (number, Int)] from "case class MyClass(text: String, number: Int)",
   * in the order `decls` yields the case accessors.
   */
  private def caseClassFields(tpe: Type): Iterable[(TermName, Type)] = {
    tpe.decls.collect {
      case accessor: MethodSymbol if accessor.isCaseAccessor ⇒
        accessor.typeSignature match {
          case NullaryMethodType(returnType) ⇒ (accessor.name, returnType)
        }
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment