Last active: August 22, 2023 19:54
-
-
Save chuwy/8f664c5e2ac57d4513702beb9e5f261e to your computer and use it in GitHub Desktop.
Scala macro for building type-safe SQL Fragments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.time.LocalDateTime | |
import scala.quoted.* | |
import scala.deriving.Mirror | |
import quotidian.* | |
import quotidian.syntax.* | |
import io.github.iltotore.iron.* | |
import cats.data.NonEmptyList | |
import skunk.* | |
import skunk.codec.numeric.int8 | |
import skunk.implicits.* | |
import io.foldables.ratio.common.Primitive | |
/** Base type of tables produced by `Table.build`.
 *
 *  Members are resolved structurally through `selectDynamic`. The `build` macro refines each
 *  generated instance with one member per column plus `except`, `select` and `all`, so only
 *  those names type-check on a generated table.
 *
 *  @param tableName SQL name of the table
 *  @param columns   every column of the table, in derivation order
 */
trait Table(val tableName: Table.Name, val columns: Table.Columns) extends Selectable:
  /** Union type of all column names */
  type Names
  /** Tuple of all column names */
  type Columns
  /** Tuple of all columns, with literal type; typed later in `selectDynamic` */
  val all: Any

  /** Resolves a structural member of a generated table.
   *
   *  - `except`: a function returning every column except the given name (or list of names)
   *  - `select`: a function returning the named columns in the requested order; names that
   *    match no column are skipped
   *  - `all`:    the tuple of column names
   *  - anything else is looked up as a column name
   *
   *  `except`/`select` build their result via `Table.Columns(List)`, which rejects an empty list.
   *
   *  @throws NoSuchElementException if `name` matches no column (only reachable when the refined
   *          type is bypassed, e.g. through a cast)
   */
  transparent inline def selectDynamic(name: String): Any =
    if name == Table.ExceptMethodName then
      (toExclude: List[String] | String) =>
        toExclude match
          case str: String => Table.Columns(columns.get.filterNot(c => c.n == str))
          case list: List[String] => Table.Columns(columns.get.filterNot(c => list.contains(c.n)))
    else if name == Table.SelectMethodName then
      (toInclude: List[String]) =>
        Table.Columns(toInclude.flatMap(name => columns.get.find(tc => tc.n == name)))
    else if name == Table.AllMethodName then
      all
    else
      // Same exception type as `Option.get`, but with a message that says what was missing.
      columns.get
        .find(_.n == name)
        .getOrElse(throw new NoSuchElementException(s"Table ${tableName.unbox} has no column '$name'"))

  /** Ready-made statements over this table. */
  object in:
    /** Row count of the whole table. */
    object count:
      /** `SELECT COUNT(*) FROM <table>` as a parameterless fragment. */
      def f: Fragment[Void] =
        sql"SELECT COUNT(*) FROM ${tableName.f}"

      /** The count statement as a query decoding a single `int8` column into `Long`. */
      def q: Query[Void, Long] =
        f.query(int8)
/** Companion of [[Table]]: column and name types, naming configuration, and the `build` macro. */
object Table:

  /** A non-empty, ordered set of typed columns, e.g. the result of `except` or `select`. */
  case class Columns(get: NonEmptyList[TypedColumn[?]]):
    /** Comma-separated column names as a raw SQL fragment. */
    def f: Fragment[Void] =
      sql"#${get.toList.map(_.n).mkString(", ")}"

    /** Column names qualified with the alias `short`, e.g. `u.id`. */
    def as(short: String): List[String] =
      get.map(_.n).toList.map(c => s"$short.$c")

  object Columns:
    /** Builds [[Columns]] from a plain list.
     *
     *  @throws IllegalArgumentException if `list` is empty (e.g. `except` removed every column
     *          or `select` matched none)
     */
    def apply(list: List[TypedColumn[?]]): Columns =
      NonEmptyList.fromList(list) match
        case Some(nel) => Columns(nel)
        case None      => throw new IllegalArgumentException("Columns must contain at least one column")

  /** Name of the structural member that returns all columns except the given one(s). */
  val ExceptMethodName: "except" = "except"
  /** Type of the `except` member, where `A` is the union of column-name literal types. */
  type Except[A] = A | List[A] => Columns
  /** Name of the structural member that returns the tuple of column names. */
  val AllMethodName: "all" = "all"
  /** Name of the structural member that returns only the listed columns. */
  val SelectMethodName: "select" = "select"
  /** Type of the `select` member, where `A` is the union of column-name literal types. */
  type Select[A] = List[A] => Columns

  /** Naming configuration for [[build]].
   *
   *  @param service optional leading path segment placed before the table name
   *  @param table   optional explicit table name; when absent the type name passed to
   *                 `getTableName` is used
   */
  final case class Config(service: Option[String], table: Option[String]):
    /** `service` (if set) and `table` (or `typeName`), each snake-cased and joined with `_`. */
    def getTableName(typeName: String): String =
      val t = table.getOrElse(typeName)
      val path = service match
        case Some(s) => NonEmptyList.of(s, t)
        case None    => NonEmptyList.one(t)
      NameStrategy.FullSnake.transform(path)

  object Config:
    /** No service segment; the table name comes from the type name. */
    inline def default: Config = Config(None, None)

    /** Config with both a service segment and an explicit table name. */
    def apply(service: String, table: String): Config =
      Config(Some(service), Some(table))

  /** How a field path (outermost label first) is turned into a single identifier. */
  enum NameStrategy:
    /** Last segment only, snake-cased. */
    case Snake
    /** Every segment snake-cased, joined with `_`. */
    case FullSnake
    /** Last segment only, unchanged. */
    case Camel
    /** First segment unchanged, remaining segments capitalized, concatenated. */
    case FullCamel

    def transform(path: NonEmptyList[String]): String =
      this match
        case Snake =>
          snakeCase(path.last)
        case FullSnake =>
          path.toList.map(snakeCase).mkString("_")
        case Camel =>
          path.last
        case FullCamel =>
          path match
            case NonEmptyList(head, tail) =>
              (head :: tail.map(_.capitalize)).mkString("")

  /** A column named `n` whose Scala value type is `C`. */
  final case class TypedColumn[C](n: String):
    /** `n = <value>` fragment whose right-hand side is encoded by `c`. */
    infix def eql(c: Encoder[C]): Fragment[C] = sql"#$n = $c"

    /** The same column qualified with a table alias, e.g. `u.id`. */
    def as(table: String): TypedColumn[C] =
      this.copy(n = s"$table.$n")

    /** The column name as a raw SQL fragment. */
    def f: Fragment[Void] =
      sql"#$n"

    /** `n = current_timestamp`; only callable when `C` is `Option[LocalDateTime]`. */
    def currentTimestamp(using C =:= Option[LocalDateTime]): Fragment[Void] =
      sql"#$n = current_timestamp"

    /** `n = n + 1`; only callable when `C` is an `IronType[Int, A]`. */
    def increment[A](using C =:= IronType[Int, A]): Fragment[Void] =
      sql"#$n = #$n + 1"

  /** SQL table name; opaque so it cannot be mixed up with arbitrary strings. */
  opaque type Name = String

  object Name:
    def apply(str: String): Name = str

    extension (name: Name)
      /** The underlying string. */
      def unbox: String = name
      /** The table name as a raw SQL fragment. */
      def f: Fragment[Void] = sql"#${name}"
      /** `<name> AS <short>` as a raw SQL fragment. */
      def as(short: String): Fragment[Void] = sql"#${name} AS #${short}"

  /** Derives a [[Table]] for the product type `T` at compile time.
   *
   *  Fields of `T` that have a `Primitive` instance become columns. Fields whose type is itself
   *  a product are flattened into their own fields. Column names are the snake-cased field labels
   *  (last path segment only) and must be unique.
   *
   *  @param config naming configuration used to compute the table name
   */
  transparent inline def build[T](inline config: Config)(using Mirror.ProductOf[T]): Table =
    ${ buildFromBlockImpl[T]('config) }

  /** Macro implementation of [[build]].
   *
   *  Compilation is aborted when:
   *  - a field is neither `Primitive` nor a product,
   *  - `T` yields no columns, or
   *  - two columns end up with the same name.
   */
  def buildFromBlockImpl[T: Type](using quotes: Quotes)(configExpr: Expr[Config]): Expr[Table] =
    import quotes.reflect.*

    val mirror = MacroMirror.summonProduct[T]
    val tableName = '{ ${configExpr}.getTableName(${Expr(mirror.label)}) }

    // Flatten T's fields into (column name, column type) pairs.
    val nameTypeMap: NonEmptyList[(String, TypeRepr)] =
      def go(root: List[String])(ls: List[(String, TypeRepr)]): List[(NonEmptyList[String], TypeRepr)] =
        ls.flatMap { (label, tpe) =>
          tpe.asType match
            case '[t] =>
              Expr.summon[Primitive[t]] match
                case Some(_) =>
                  // `root` is accumulated innermost-first; reverse to get an outermost-first path.
                  List(NonEmptyList(label, root).reverse -> tpe)
                case None =>
                  MacroMirror.summon[t] match
                    case Right(pm: MacroMirror.ProductMacroMirror[quotes.type, t]) =>
                      go(label :: root)(pm.elemLabels.zip(pm.elemTypes))
                    case _ =>
                      report.errorAndAbort(s"Couldn't synthesize an instance of either Primitive or Table for ${tpe.show}")
        }

      val zipped = mirror.elemLabels.zip(mirror.elemTypes)
      val flattened = go(Nil)(zipped).map((path, tpe) => NameStrategy.Snake.transform(path) -> tpe)
      NonEmptyList.fromList(flattened) match
        case Some(nel) => nel
        case None      => report.errorAndAbort("Could not derive columns. A Table must contain at least one Column")

    // Snake-casing only the last path segment can make nested fields collide; report which ones.
    val columnNames = nameTypeMap.map(_._1)
    val duplicated = columnNames.toList.diff(columnNames.toList.distinct).distinct
    if duplicated.nonEmpty then
      report.errorAndAbort(
        s"Not all column names are unique (${columnNames.toList.mkString(", ")}); duplicated: ${duplicated.mkString(", ")}"
      )

    val columns = Expr.ofList(nameTypeMap.toList.map { (n, t) =>
      t.asType match
        case '[tpe] => '{ new TypedColumn[tpe](${Expr(n)}) }
    })

    // Union of the column names' literal types, left-nested: ("a" | "b") | "c" ...
    val columnName: TypeRepr =
      columnNames.tail.foldLeft[TypeRepr](ConstantType(StringConstant(columnNames.head))) { (acc, name) =>
        OrType(acc, ConstantType(StringConstant(name)))
      }

    val columnNamesTuple = Expr.ofTupleFromSeq(columnNames.toList.map(name => Expr(name)))

    // One `TypedColumn[tpe]` member per column, refined onto `Table`.
    val refinementWithColumns = nameTypeMap.foldLeft(TypeRepr.of[Table]) { case (acc, (name, tpr)) =>
      tpr.asType match
        case '[tpe] =>
          Refinement(
            parent = acc,
            name = name,
            info = TypeRepr.of[TypedColumn].appliedTo(TypeRepr.of[tpe])
          )
    }
    val refinementWithExcept =
      Refinement(parent = refinementWithColumns, name = ExceptMethodName, info = TypeRepr.of[Table.Except].appliedTo(columnName))
    val refinementWithSelect =
      Refinement(parent = refinementWithExcept, name = SelectMethodName, info = TypeRepr.of[Table.Select].appliedTo(columnName))
    val refinementFinal =
      Refinement(parent = refinementWithSelect, name = AllMethodName, info = columnNamesTuple.asTerm.tpe)

    (refinementFinal.asType, columnNamesTuple.asTerm.tpe.asType, columnName.asType) match
      case ('[refinementType], '[allTuple], '[union]) =>
        '{
          (new Table(Name(${tableName}), Columns(NonEmptyList.fromListUnsafe(${columns}))) {
            val all = ${columnNamesTuple}
          }).asInstanceOf[Table { type Names = union; type Columns = allTuple } & refinementType]
        }
/** Converts camelCase/PascalCase to snake_case, keeping acronyms together
 *  (`UserAccount` -> `user_account`, `HTTPServer` -> `http_server`).
 *
 *  Lower-casing uses `Locale.ROOT` so the result does not depend on the JVM default locale;
 *  under e.g. a Turkish locale, `"I".toLowerCase` yields a dotless `ı` and would corrupt the
 *  generated SQL identifiers.
 */
def snakeCase(str: String): String =
  str
    .replaceAll("([A-Z]+)([A-Z][a-z])", "$1_$2") // split an acronym from the following word
    .replaceAll("([a-z\\d])([A-Z])", "$1_$2")    // split lower/digit -> upper boundaries
    .toLowerCase(java.util.Locale.ROOT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
https://gist.github.com/chuwy/8f664c5e2ac57d4513702beb9e5f261e#file-table-scala-L162
`MacroMirror` has an `elems` field that exposes each element's label and type together, along with some other conveniences, which might be useful here.