Last active
November 8, 2015 01:45
-
-
Save philnguyen/8966d1a66fc61743374d to your computer and use it in GitHub Desktop.
Concise, type-safe and unboxed Enum by Scala macros
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.language.experimental.macros | |
import scala.reflect.macros.blackbox.Context | |
import scala.annotation.StaticAnnotation | |
import scala.annotation.compileTimeOnly | |
/** Macro annotation that rewrites a simple class declaration into a
  * type-safe, unboxed enumeration.
  *
  * The annotated class body lists the variant names as bare identifiers,
  * e.g. `@ValueEnum class Day {Monday; Tuesday}`. The rewrite itself is
  * performed by `EnumMacro.impl`. Java interoperability is not a goal.
  * A complete example, with its expansion, is at the end of this file.
  *
  * The `@compileTimeOnly` guard reports "please enable macro paradise" if a
  * use of the annotation survives compilation without being expanded.
  */
@compileTimeOnly("please enable macro paradise")
class ValueEnum extends StaticAnnotation {
  // Hook that macro paradise calls with the annotated definition(s).
  def macroTransform(annottees: Any*): Any =
    macro EnumMacro.impl
}
object EnumMacro {

  /** Macro-paradise implementation behind the `@ValueEnum` annotation.
    *
    * Expects exactly one annottee: a class whose body is a sequence of bare
    * identifiers naming the variants. It expands into
    *  - a value class `Name private (val ord: Int) extends AnyVal` providing
    *    `toString`, `next` and `prev`, and
    *  - a companion object with `count` and one `val` per variant, numbered
    *    from 0 in declaration order.
    *
    * Compilation is aborted with a positioned error when the annottee is not a
    * lone class, or when the class declares no variants. Class members other
    * than bare identifiers are reported with a warning and dropped.
    *
    * @param c         macro context supplied by the compiler
    * @param annottees the annotated definition(s)
    * @return the generated value class and its companion object
    */
  def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    // Extract the class name and the variant identifiers from the annottee.
    val (name, variants) = annottees.map(_.tree).toList match {
      case ClassDef(_, className, _, Template(_, _, body)) :: Nil =>
        body.foreach {
          case _: Ident => ()
          // The parser inserts the primary constructor into every class body.
          case DefDef(_, termNames.CONSTRUCTOR, _, _, _, _) => ()
          case EmptyTree => ()
          case other =>
            c.warning(other.pos, "@ValueEnum ignores class members other than bare variant names")
        }
        (className, body.collect { case variant: Ident => variant })
      case trees =>
        c.abort(
          c.enclosingPosition,
          s"@ValueEnum can only annotate a single class declaration, got: ${trees.mkString(", ")}")
    }

    // Without this guard the code below would index into an empty list and
    // generate an empty (invalid) `match`.
    if (variants.isEmpty)
      c.abort(c.enclosingPosition, s"@ValueEnum class $name must declare at least one variant")

    val variantCount = variants.length
    val variantMax = variantCount - 1
    val ord = TermName("ord")
    val companion = name.toTermName
    // Source names of the variants, used as `toString` results and in error messages.
    val labels = variants.map(_.name.decodedName.toString)

    val defToString = {
      val cases = labels.zipWithIndex.map { case (label, i) => cq"$i => $label" }
      q"override def toString: String = $ord match { case ..$cases }"
    }

    // One `val Variant = new Name(i)` per variant, in declaration order.
    val defInstances = variants.zipWithIndex.map { case (variant, i) =>
      q"val ${variant.name.toTermName} = new $name($i)"
    }

    // The messages are built at expansion time, so the generated code does not
    // need `any2stringadd` or an import of the companion.
    val defNext = q"""
      def next: $name =
        if ($ord < $variantMax) new $name($ord + 1)
        else throw new IllegalArgumentException(${labels.last + " does not have next"})
    """
    val defPrev = q"""
      def prev: $name =
        if ($ord > 0) new $name($ord - 1)
        else throw new IllegalArgumentException(${labels.head + " does not have prev"})
    """

    val defns = q"""
      class $name private (val $ord: Int) extends AnyVal {
        $defToString
        $defNext
        $defPrev
      }
      object $companion {
        def count: Int = $variantCount
        ..$defInstances
      }
    """
    c.Expr[Any](defns)
  }
}
/** EXAMPLE:

  // This declaration
  @ValueEnum
  class Day {Monday; Tuesday; Wednesday; Thursday; Friday; Saturday; Sunday}

  // expands to (roughly) the following:
  class Day private (val ord: Int) extends AnyVal {
    import Day._
    override def toString: String = ord match {
      case 0 => "Monday"
      case 1 => "Tuesday"
      case 2 => "Wednesday"
      case 3 => "Thursday"
      case 4 => "Friday"
      case 5 => "Saturday"
      case 6 => "Sunday"
    }
    def next: Day =
      if (ord < 6) new Day(ord + 1)
      else throw new IllegalArgumentException(Sunday + " does not have next")
    def prev: Day =
      if (ord > 0) new Day(ord - 1)
      else throw new IllegalArgumentException(Monday + " does not have prev")
  }
  object Day extends scala.AnyRef {
    def count = 7
    val Monday = new Day(0)
    val Tuesday = new Day(1)
    val Wednesday = new Day(2)
    val Thursday = new Day(3)
    val Friday = new Day(4)
    val Saturday = new Day(5)
    val Sunday = new Day(6)
  }
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment