Skip to content

Instantly share code, notes, and snippets.

@leifwickland
Created October 19, 2018 19:36
Show Gist options
  • Save leifwickland/96091536f1d7f09bbf348747cda0e0a8 to your computer and use it in GitHub Desktop.
Save leifwickland/96091536f1d7f09bbf348747cda0e0a8 to your computer and use it in GitHub Desktop.
Creates a Play JSON Format instance for sealed families (ADTs) with `case class`es and optionally `case object`s
import play.api.libs.json._
import scalaz.NonEmptyList
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
/**
 * Provides a Play JSON Format for an ADT whose members are case classes and, optionally, case
 * objects.
 *
 * Heavily inspired by https://gist.github.com/realpeterz/aecb53a67fb723485eb66d544a67d580
 *
 * Use this by extending an `object` with AdtFormat[A] and then implementing `members`.
 * Every value is written as a JSON object whose `typeJsonFieldName` field holds the member's
 * type name; case-class members also carry the fields written by their own Format. Reading
 * dispatches on that field.
 *
 * @tparam A Base type of the ADT
 */
trait AdtFormat[A <: Product with Serializable] {

  /**
   * Every member of the ADT. Implement this in your derived object using `singleton` and
   * `nonSingleton`.
   * Make it a `def` to avoid init order sadness; do not write `val members`.
   *
   * NOTE: type names should be unique; on a duplicate, reading uses the last member listed
   * under that name. Writing a value that no member matches throws a `scala.MatchError`.
   */
  def members: NonEmptyList[Member]

  /**
   * The name of the JSON field that records which member of the ADT a value belongs to.
   * Defaults to "type"; override this to write the member name to a different field.
   * Make it a `def` to avoid init order sadness; do not write `val typeJsonFieldName`.
   *
   * NOTE: You must ensure this does not conflict with field names in any of the members of A.
   * When writing, a member field with the same name replaces the type field.
   */
  protected def typeJsonFieldName: String = "type"

  /** Create a singleton Member whose type name is `d`'s product prefix with its initial lowered. */
  protected def singleton[D <: A](d: D): Member = Singleton(lowerInitial(d.productPrefix), d)

  /** Create a singleton Member with the given type name. */
  protected def singleton[D <: A](typeName: String, d: D): Member = Singleton(typeName, d)

  /** Create a nonSingleton Member with the given type name, (de)serialized via `format`. */
  protected def nonSingleton[D <: A: TypeTag: ClassTag](name: String, format: Format[D]): Member =
    NonSingleton(name)(format)

  /**
   * Create a nonSingleton Member whose type name is the simple name of `D` with its initial
   * lowered (e.g. `CaseClass1` becomes "caseClass1").
   */
  protected def nonSingleton[D <: A: ClassTag](f: Format[D])(implicit t: TypeTag[D]): Member = {
    // Strip type arguments before looking for the last '.', so that a qualified type argument
    // (e.g. `Box[foo.Bar]`) is not mistaken for the type's own prefix.
    val longName = t.tpe.toString.takeWhile(_ != '[')
    val lastDot = longName.lastIndexOf('.')
    val typeName = if (lastDot < 0) longName else longName.drop(lastDot + 1)
    NonSingleton(lowerInitial(typeName))(f)
  }

  /** One member of the ADT: its type name plus how to read and write it. */
  protected sealed abstract class Member extends Product with Serializable {
    /** Value stored in the `typeJsonFieldName` field for this member. */
    def typeName: String
    /** Reads this member from the whole JSON object, type field included. */
    def readFunc: JsValue => JsResult[A]
    /** Writes values of this member; undefined for values of other members. */
    def writeFunc: PartialFunction[A, JsValue]
  }

  /** A case-class member: delegates to `format` and merges in the type field when writing. */
  private case class NonSingleton[D <: A: ClassTag: TypeTag](typeName: String)(format: Format[D])
      extends Member {
    val readFunc: JsValue => JsResult[A] = format.reads
    private val objWithType = JsObject.empty + (typeJsonFieldName -> JsString(typeName)) // Allocate once
    def writeFunc: PartialFunction[A, JsValue] = {
      // The ClassTag lets this type test be checked at runtime despite erasure.
      case d: D =>
        format.writes(d) match {
          case o: JsObject => objWithType ++ o
          case _ =>
            val m = s"Format[${implicitly[TypeTag[D]].tpe}] did not produce a JsObject " +
              "which should be impossible because the type is guaranteed to be a Product"
            throw new IllegalStateException(m)
        }
    }
  }

  /** A case-object member: reading always yields `d`; writing emits only the type field. */
  private case class Singleton[D <: A](typeName: String, d: D) extends Member {
    val readFunc: JsValue => JsSuccess[A] = Function.const(JsSuccess(d))
    private val obj = JsObject.empty + (typeJsonFieldName -> JsString(typeName)) // Allocate once
    // Matches by equality against the single instance `d`.
    def writeFunc: PartialFunction[A, JsValue] = { case `d` => obj }
  }

  /** Reads the type field, then dispatches to the member registered under that name. */
  private def reads: Reads[A] = {
    // Built once, when `format` is initialized; on duplicate type names the last member wins.
    val typeNameToReadFunc = members.list.toList.map { m =>
      m.typeName -> m.readFunc
    }.toMap
    Reads { jsValue: JsValue =>
      readTypeName(jsValue).flatMap { typeName =>
        typeNameToReadFunc.get(typeName) match {
          case None => typeNameUnknownError(typeName)
          case Some(readF) => readF(jsValue)
        }
      }
    }
  }

  /** Tries each member's writeFunc in `members` order; `reduce` is safe on a NonEmptyList. */
  private def writes: Writes[A] =
    Writes(members.list.toList.map(_.writeFunc).reduce(_ orElse _))

  /** Lower-cases the first character of `s`, leaving the rest untouched. */
  private def lowerInitial(s: String): String = {
    if (s.isEmpty || s.head.isLower) s
    else Character.toLowerCase(s.charAt(0)) + s.substring(1)
  }

  /** Error returned when the type field is missing or cannot be read as a string. */
  private def typeUndefinedError = JsError(
    JsonValidationError(Seq(s"'$typeJsonFieldName' field undefined")))

  /** Error returned when the type field names no known member. */
  private def typeNameUnknownError(typeName: String) = JsError(
    JsonValidationError(Seq(s"'$typeJsonFieldName' field has an unknown value of '$typeName'")))

  /** Extracts the type field as a string, replacing any failure with `typeUndefinedError`. */
  private def readTypeName: JsValue => JsResult[String] =
    (JsPath \ typeJsonFieldName).read[String].reads(_).orElse(typeUndefinedError)

  // This must remain at the bottom to avoid init order sadness
  implicit val format: Format[A] = Format[A](reads, writes)
  // NOTE: do not put anything below `format`
}
import org.scalacheck._
import org.specs2.mutable.Specification
import org.specs2.ScalaCheck
import play.api.libs.json.{Format, Json, JsPath}
class AdtFormatSpec extends Specification with ScalaCheck {
  // MyAdt is declared in the companion object below; companion members are not in scope in
  // the class body without an explicit import.
  import AdtFormatSpec._

  "default Format for MyAdt" should {
    // NOTE(review): the Arbitrary[MyAdt] presumably comes from ArbViaGen on MyAdt — TODO confirm.
    "round-trip for all" in prop { m: MyAdt =>
      Json.parse(Json.toJson(m).toString).validate[MyAdt].asOpt must beSome(m)
    }
    "use the auto-generated names" in {
      def typeNameFromJson(m: MyAdt): String =
        (JsPath \ "type").read[String].reads(Json.toJson(m)).get
      typeNameFromJson(MyAdt.CaseClass1("hi")) ==== "caseClass1"
      typeNameFromJson(MyAdt.CaseClass2(1)) ==== "caseClass2"
      typeNameFromJson(MyAdt.Enum1) ==== "enum1"
      typeNameFromJson(MyAdt.Enum2) ==== "enum2"
    }
  }
  "Format with custom names" should {
    import MyAdt._
    // Deliberately named `format`: this local definition shadows the `MyAdt.format` brought in
    // by the wildcard import above, so implicit lookup in this block resolves to this instance.
    implicit val format: Format[MyAdt] = new AdtFormat[MyAdt] {
      override def typeJsonFieldName: String = "kind"
      def members = scalaz.NonEmptyList(
        nonSingleton("history", Json.format[CaseClass1]),
        nonSingleton("math", Json.format[CaseClass2]),
        singleton("civics", Enum1),
        singleton("english", Enum2)
      )
    }.format
    "have the expected names" in {
      def typeNameFromJson(m: MyAdt): String =
        (JsPath \ "kind").read[String].reads(Json.toJson(m)).get
      typeNameFromJson(CaseClass1("hi")) ==== "history"
      typeNameFromJson(CaseClass2(1)) ==== "math"
      typeNameFromJson(Enum1) ==== "civics"
      typeNameFromJson(Enum2) ==== "english"
    }
    "round-trip for all" in prop { m: MyAdt =>
      Json.parse(Json.toJson(m).toString).validate[MyAdt].asOpt must beSome(m)
    }
  }
}
object AdtFormatSpec {
  /** Sample ADT with both case-object and case-class members. */
  sealed trait MyAdt extends Product with Serializable

  /**
   * Companion of MyAdt: the default AdtFormat (auto-generated type names written to the
   * "type" field) plus a generator covering every member.
   */
  object MyAdt extends AdtFormat[MyAdt] with ArbViaGen[MyAdt] {
    case object Enum1 extends MyAdt
    case object Enum2 extends MyAdt
    final case class CaseClass1(s: String) extends MyAdt
    final case class CaseClass2(i: Int) extends MyAdt

    // Auto-named members: "caseClass1", "caseClass2", "enum1", "enum2".
    def members = scalaz.NonEmptyList(
      nonSingleton(Json.format[CaseClass1]),
      nonSingleton(Json.format[CaseClass2]),
      singleton(Enum1),
      singleton(Enum2)
    )

    // NOTE(review): presumably ArbViaGen derives an Arbitrary[MyAdt] from `gen`. If it reads
    // `gen` eagerly during its own initialization, this strict `val` is still null at that
    // point — TODO confirm against ArbViaGen.
    val gen: Gen[MyAdt] = Gen.oneOf(
      Gen.const(Enum1),
      Gen.const(Enum2),
      Gen.choose[Int](0, 999).map(_.toString).map(CaseClass1.apply),
      Gen.choose[Int](0, 999).map(CaseClass2.apply)
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment