// Source: GitHub Gist philnguyen/8966d1a66fc61743374d (last active November 8, 2015 01:45)
// Concise, type-safe and unboxed Enum by Scala macros
import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import scala.annotation.StaticAnnotation
import scala.annotation.compileTimeOnly
/**
 * Macro annotation that turns a simple class declaration into a type-safe, unboxed enum.
 *
 * The annotated class body should be a list of bare identifiers, e.g.
 * `@ValueEnum class Day { Monday; Tuesday }`. The expansion is:
 *  - a value class wrapping an `Int` ordinal, with `toString`, `next` and `prev`;
 *  - a companion object with `count` and one `val` per identifier.
 *
 * The generated enum makes no attempt at Java interoperability.
 * There is an example usage at the end of the file.
 *
 * Expanding a macro annotation needs compiler support:
 *  - on Scala 2.11/2.12, the macro paradise plugin;
 *  - on Scala 2.13, the `-Ymacro-annotations` flag.
 *
 * Without it, `@compileTimeOnly` reports the error below instead of silently leaving
 * the class unexpanded.
 */
@compileTimeOnly("please enable macro paradise (Scala 2.11/2.12) or -Ymacro-annotations (Scala 2.13+)")
class ValueEnum extends StaticAnnotation {
  /** Entry point invoked by the compiler; delegates the tree rewrite to [[EnumMacro.impl]]. */
  def macroTransform(annottees: Any*): Any = macro EnumMacro.impl
}
/** Implementation of the [[ValueEnum]] macro annotation. */
object EnumMacro {

  /**
   * Expands `@ValueEnum class Name { A; B; C }` into:
   *  - a value class `Name private (val ord: Int) extends AnyVal` with `toString`
   *    (the variant's name), `next` and `prev`;
   *  - a companion object `Name` with `count` and one `val` per variant, numbered
   *    from 0 in declaration order.
   *
   * @param c         macro context supplied by the compiler
   * @param annottees the annotated definitions; must be exactly one class
   * @return the generated class and companion object
   */
  def impl(c: Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
    import c.universe._

    // Extract the class name and variant names. A hand-written companion would arrive
    // as a second annottee and clash with the generated one, so it is rejected.
    val (name, variants) = annottees.map(_.tree).toList match {
      case ClassDef(_, className, _, Template(_, _, body)) :: Nil =>
        // Bare identifiers in the body are the variants. Every other statement,
        // including the synthetic constructor, is ignored.
        (className, body.collect { case Ident(variant) => variant.toTermName })
      case trees =>
        c.abort(
          c.enclosingPosition,
          s"Illegal input: @ValueEnum must annotate a single class with no companion object, got: ${trees.mkString(", ")}"
        )
    }

    // next/prev refer to the first and last variant, so at least one must exist.
    // Abort with a positioned compile error instead of crashing the expansion.
    if (variants.isEmpty)
      c.abort(c.enclosingPosition, s"@ValueEnum class $name must declare at least one variant")

    val variantCount = variants.length
    val variantMax = variantCount - 1
    val ord = TermName("ord")
    val companion = name.toTermName
    val indexed = variants.zipWithIndex

    // Map each ordinal back to the variant's source name.
    val defToString = {
      val cases = indexed.map { case (v, i) => cq"$i => ${v.decodedName.toString}" }
      q"override def toString: String = $ord match { case ..$cases }"
    }

    // One companion member per variant, in declaration order.
    val defInstances = indexed.map { case (v, i) => q"val $v: $name = new $name($i)" }

    // Messages are precomputed as string literals. The generated code therefore needs
    // neither the companion's vals nor `any2stringadd` (`x + "..."`, deprecated in 2.13).
    val noNextMsg = s"${variants.last.decodedName} does not have next"
    val noPrevMsg = s"${variants.head.decodedName} does not have prev"

    val defNext = q"""
      def next: $name =
        if ($ord < $variantMax) new $name($ord + 1)
        else throw new IllegalArgumentException($noNextMsg)
    """
    val defPrev = q"""
      def prev: $name =
        if ($ord > 0) new $name($ord - 1)
        else throw new IllegalArgumentException($noPrevMsg)
    """

    val defns = q"""
      class $name private (val $ord: Int) extends AnyVal {
        $defToString
        $defNext
        $defPrev
      }
      object $companion {
        def count: Int = $variantCount
        ..$defInstances
      }
    """
    c.Expr[Any](defns)
  }
}
/** EXAMPLE:
// This declaration
@ValueEnum
class Day {Monday; Tuesday; Wednesday; Thursday; Friday; Saturday; Sunday}
// expands to approximately the following (pretty-printed; the exact trees depend on the macro version):
class Day private (val ord: Int) extends AnyVal {
import Day._
override def toString: String = ord match {
case 0 => "Monday"
case 1 => "Tuesday"
case 2 => "Wednesday"
case 3 => "Thursday"
case 4 => "Friday"
case 5 => "Saturday"
case 6 => "Sunday"
}
def next: Day =
if (ord < 6) new Day(ord + 1)
else throw new IllegalArgumentException(Sunday + " does not have next")
def prev: Day =
if (ord > 0) new Day(ord - 1)
else throw new IllegalArgumentException(Monday + " does not have prev")
}
object Day extends scala.AnyRef {
def count = 7
val Monday = new Day(0)
val Tuesday = new Day(1)
val Wednesday = new Day(2)
val Thursday = new Day(3)
val Friday = new Day(4)
val Saturday = new Day(5)
val Sunday = new Day(6)
}
*/
// End of gist.