Created
October 19, 2018 19:36
-
-
Save leifwickland/96091536f1d7f09bbf348747cda0e0a8 to your computer and use it in GitHub Desktop.
Creates a Play JSON `Format` instance for sealed families (ADTs) of `case class`es and, optionally, `case object`s.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import play.api.libs.json._ | |
import scalaz.NonEmptyList | |
import scala.reflect.ClassTag | |
import scala.reflect.runtime.universe.TypeTag | |
/**
 * Provides a Play JSON Format for an ADT with case classes and possibly case object members.
 *
 * Every value is written as a JSON object containing a discriminator field (named by
 * `typeJsonFieldName`) whose value is the member's type name. Case-class members also carry the
 * fields produced by their own `Format`. Reading looks up the discriminator and delegates to the
 * matching member.
 *
 * Heavily inspired by https://gist.github.com/realpeterz/aecb53a67fb723485eb66d544a67d580
 *
 * Use this by extending an `object` with AdtFormat[A] and then implementing `members`.
 *
 * @tparam A Base type of the ADT
 */
trait AdtFormat[A <: Product with Serializable] {

  /**
   * All members of the ADT. Implement this in your derived object using `singleton` and
   * `nonSingleton`. Every member must have a distinct type name; initializing `format` throws an
   * IllegalArgumentException otherwise.
   *
   * Make it a `def` to avoid init order sadness; do not write `val members`.
   */
  def members: NonEmptyList[Member]

  /**
   * The name of the JSON field that indicates which member of the ADT a JSON object encodes.
   * You may override this to change that field's name.
   * Make it a `def` to avoid init order sadness; do not write `val typeJsonFieldName`.
   *
   * NOTE: You must ensure this does not conflict with field names in any of the members of A.
   * Writing a member whose Format emits this field throws an IllegalStateException rather than
   * silently overwriting the discriminator.
   */
  protected def typeJsonFieldName: String = "type"

  /** Create a singleton Member named after `d.productPrefix` with its first letter lower-cased. */
  protected def singleton[D <: A](d: D): Member = Singleton(lowerInitial(d.productPrefix), d)

  /** Create a singleton Member with the given type name. */
  protected def singleton[D <: A](typeName: String, d: D): Member = Singleton(typeName, d)

  /** Create a nonSingleton Member with a given name. */
  protected def nonSingleton[D <: A: TypeTag: ClassTag](name: String, format: Format[D]): Member =
    NonSingleton(name)(format)

  /**
   * Create a nonSingleton Member with a name automatically generated from the type name:
   * prefixes and type arguments are dropped and the first letter is lower-cased,
   * e.g. `com.example.Foo[Int]` becomes `foo`.
   */
  protected def nonSingleton[D <: A: ClassTag](f: Format[D])(implicit t: TypeTag[D]): Member =
    NonSingleton(lowerInitial(simpleTypeName(t.tpe.toString)))(f)

  /** One member of the ADT: its type name plus how to read and write it. */
  protected sealed abstract class Member extends Product with Serializable {
    def typeName: String
    def readFunc: JsValue => JsResult[A]
    def writeFunc: PartialFunction[A, JsValue]
  }

  private case class NonSingleton[D <: A: ClassTag: TypeTag](typeName: String)(format: Format[D])
      extends Member {
    val readFunc: JsValue => JsResult[A] = format.reads

    // Allocated once; merged in front of every object this member writes.
    private val objWithType = JsObject.empty + (typeJsonFieldName -> JsString(typeName))

    // The ClassTag turns `case d: D` into a runtime class test instead of an unchecked pattern.
    def writeFunc: PartialFunction[A, JsValue] = {
      case d: D =>
        format.writes(d) match {
          case o: JsObject if o.keys.contains(typeJsonFieldName) =>
            // `objWithType ++ o` would let o's value replace the discriminator, making the
            // output unreadable by `reads`.
            throw new IllegalStateException(
              s"Format[$memberType] wrote a '$typeJsonFieldName' field, which collides with the " +
                "field that identifies the ADT member; override typeJsonFieldName")
          case o: JsObject => objWithType ++ o
          case _ =>
            val m = s"Format[$memberType] did not produce a JsObject " +
              "which should be impossible because the type is guaranteed to be a Product"
            throw new IllegalStateException(m)
        }
    }

    private def memberType = implicitly[TypeTag[D]].tpe
  }

  private case class Singleton[D <: A](typeName: String, d: D) extends Member {
    // Singletons carry no data, so reading ignores the payload and always yields `d`.
    val readFunc: JsValue => JsSuccess[A] = Function.const(JsSuccess(d))
    private val obj = JsObject.empty + (typeJsonFieldName -> JsString(typeName)) // Allocate once
    // Backquotes make this an equality test against `d`, not a fresh variable binding.
    def writeFunc: PartialFunction[A, JsValue] = { case `d` => obj }
  }

  /**
   * Reads the discriminator and dispatches to the member with that type name.
   * Throws IllegalArgumentException if two members share a type name, because reads and writes
   * would otherwise silently disagree about which member owns that name.
   */
  private def reads: Reads[A] = {
    val memberList = members.list.toList
    val duplicateNames =
      memberList.groupBy(_.typeName).collect { case (name, ms) if ms.size > 1 => name }
    require(
      duplicateNames.isEmpty,
      s"AdtFormat members must have distinct type names; duplicated: " +
        duplicateNames.toList.sorted.mkString(", "))
    val typeNameToReadFunc = memberList.map(m => m.typeName -> m.readFunc).toMap
    Reads { jsValue: JsValue =>
      readTypeName(jsValue).flatMap { typeName =>
        typeNameToReadFunc.get(typeName) match {
          case None => typeNameUnknownError(typeName)
          case Some(readF) => readF(jsValue)
        }
      }
    }
  }

  /**
   * Writes a value using the first member whose `writeFunc` accepts it.
   * Throws IllegalStateException if no member in `members` covers the value.
   */
  private def writes: Writes[A] = {
    // `members` is a NonEmptyList, so reduce cannot fail on an empty list.
    val writeKnown = members.list.toList.map(_.writeFunc).reduce(_ orElse _)
    Writes[A] { a =>
      writeKnown.applyOrElse[A, JsValue](a, unmatched =>
        throw new IllegalStateException(
          s"No AdtFormat member can write a value of type ${unmatched.productPrefix}; " +
            "add it to members"))
    }
  }

  /** Drops any prefix and type arguments: `a.b.Outer#Inner[c.D]` becomes `Inner`. */
  private def simpleTypeName(fullName: String): String = {
    val withoutTypeArgs = fullName.takeWhile(_ != '[')
    withoutTypeArgs.drop(withoutTypeArgs.lastIndexWhere(c => c == '.' || c == '#') + 1)
  }

  /** Lower-cases the first character of `s`; empty strings are returned unchanged. */
  private def lowerInitial(s: String): String = {
    if (s.isEmpty || s.head.isLower) s
    else Character.toLowerCase(s.charAt(0)) + s.substring(1)
  }

  private def typeUndefinedError = JsError(
    JsonValidationError(Seq(s"'$typeJsonFieldName' field undefined")))

  private def typeNameUnknownError(typeName: String) = JsError(
    JsonValidationError(Seq(s"'$typeJsonFieldName' field has an unknown value of '$typeName'")))

  // Any failure to read the discriminator as a String is reported as "undefined".
  private def readTypeName: JsValue => JsResult[String] =
    (JsPath \ typeJsonFieldName).read[String].reads(_).orElse(typeUndefinedError)

  // This must remain at the bottom to avoid init order sadness
  implicit val format: Format[A] = Format[A](reads, writes)
  // NOTE: do not put anything below `format`
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.scalacheck._ | |
import org.specs2.mutable.Specification | |
import org.specs2.ScalaCheck | |
import play.api.libs.json.{Format, Json, JsPath} | |
class AdtFormatSpec extends Specification with ScalaCheck {
  // MyAdt is declared in the companion object; companion members are not in lexical scope in
  // the class body without an explicit import.
  import AdtFormatSpec._

  "default Format for MyAdt" should {
    def typeNameFromJson(m: MyAdt): Option[String] =
      (JsPath \ "type").read[String].reads(Json.toJson(m)).asOpt

    "derive names by lower-casing the first letter of the member name" in {
      typeNameFromJson(MyAdt.CaseClass1("hi")) must beSome("caseClass1")
      typeNameFromJson(MyAdt.CaseClass2(1)) must beSome("caseClass2")
      typeNameFromJson(MyAdt.Enum1) must beSome("enum1")
      typeNameFromJson(MyAdt.Enum2) must beSome("enum2")
    }

    "round-trip for all" in prop { m: MyAdt =>
      Json.parse(Json.toJson(m).toString).validate[MyAdt].asOpt must beSome(m)
    }
  }

  "Format with custom names" should {
    import MyAdt._

    // Shadows the default companion Format for the remainder of this block.
    implicit val format: Format[MyAdt] = new AdtFormat[MyAdt] {
      override def typeJsonFieldName: String = "kind"
      def members = scalaz.NonEmptyList(
        nonSingleton("history", Json.format[CaseClass1]),
        nonSingleton("math", Json.format[CaseClass2]),
        singleton("civics", Enum1),
        singleton("english", Enum2)
      )
    }.format

    def typeNameFromJson(m: MyAdt): Option[String] =
      (JsPath \ "kind").read[String].reads(Json.toJson(m)).asOpt

    "have the expected names" in {
      typeNameFromJson(CaseClass1("hi")) must beSome("history")
      typeNameFromJson(CaseClass2(1)) must beSome("math")
      typeNameFromJson(Enum1) must beSome("civics")
      typeNameFromJson(Enum2) must beSome("english")
    }

    "round-trip for all" in prop { m: MyAdt =>
      Json.parse(Json.toJson(m).toString).validate[MyAdt].asOpt must beSome(m)
    }
  }
}
object AdtFormatSpec {
  /** Sample ADT mixing case objects and case classes. */
  sealed trait MyAdt extends Product with Serializable

  object MyAdt extends AdtFormat[MyAdt] with ArbViaGen[MyAdt] {
    case object Enum1 extends MyAdt
    case object Enum2 extends MyAdt
    final case class CaseClass1(s: String) extends MyAdt
    final case class CaseClass2(i: Int) extends MyAdt

    def members: scalaz.NonEmptyList[Member] = scalaz.NonEmptyList(
      nonSingleton(Json.format[CaseClass1]),
      nonSingleton(Json.format[CaseClass2]),
      singleton(Enum1),
      singleton(Enum2)
    )

    // Covers the full Int range and arbitrary-length (including empty) strings, so the
    // round-trip properties exercise boundary values rather than only 0-999.
    val gen: Gen[MyAdt] = Gen.oneOf(
      Gen.const(Enum1),
      Gen.const(Enum2),
      Gen.alphaStr.map(CaseClass1.apply),
      Arbitrary.arbitrary[Int].map(CaseClass2.apply)
    )
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment