Last active
August 17, 2020 21:59
-
-
Save nafg/883814df176e0cec495429806a1e01f2 to your computer and use it in GitHub Desktop.
A simple code generator for Slick, implemented as an SBT plugin, that builds the generated source with Scalameta syntax trees rather than string templates.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.sql.Types | |
import scala.annotation.tailrec | |
import scala.concurrent.Await | |
import scala.concurrent.ExecutionContext.Implicits.global | |
import scala.concurrent.duration.Duration | |
import scala.meta._ | |
import slick.dbio.DBIO | |
import slick.jdbc.meta.{MColumn, MQName, MTable} | |
import slick.jdbc.{JdbcBackend, JdbcProfile} | |
import sbt.Keys._ | |
import sbt._ | |
object SlickMetaGenPlugin extends AutoPlugin { | |
// Only activate on projects that also enable SlickConfigPlugin, whose
// autoImport (see below) supplies the `slickConfig` setting read by our tasks.
override def requires = SlickConfigPlugin
object autoImport {
  /** Everything the generator needs to know about one column.
    *
    * @param columnName     raw database column name
    * @param tableFieldTerm identifier used for the `column[...]` val in the Table class
    * @param rowFieldTerm   identifier used for the field in the row case class
    * @param scalaType      Scala type of the field (already Option-wrapped if nullable)
    * @param scalaDefault   default-value expression for the row field, if any
    */
  case class ColumnInfo(columnName: String,
                        tableFieldTerm: Term.Name,
                        rowFieldTerm: Term.Name,
                        scalaType: Type,
                        scalaDefault: Option[Term])

  /** Everything the generator needs to know about one table: its qualified DB
    * name, the names of the generated Table class and row case class, and its columns.
    */
  case class SchemaInfo(tableName: MQName, tableClassName: String, rowClassName: String, columns: List[ColumnInfo])

  val slickPackage = settingKey[String]("The package to put the definitions in")
  val slickContainer =
    settingKey[String]("The container object / trait to put the table definitions in (also determines the filename)")
  val slickProfileClass = settingKey[String]("The slick profile class")
  val slickTables = settingKey[MTable => Boolean]("Which tables to codegen")
  val slickSchemaInfo = taskKey[List[SchemaInfo]]("Generate the SchemaInfo")
  val slickMetaGenModels = taskKey[Seq[File]]("Generate type definitions from the database")
  val slickMetaGenTables = taskKey[Seq[File]]("Generate schema definitions from the database")
  val slickMetaGenExtraImports = settingKey[List[String]]("Extra imports")
}
import SlickConfigPlugin.autoImport._ | |
import autoImport._ | |
/** Converts a snake_case identifier to camelCase: each underscore that is
  * followed by another character is dropped and that character is upper-cased.
  * A trailing underscore is preserved (it has no following character), and a
  * doubled underscore collapses to a single one (`'_'.toUpper` is `'_'`).
  *
  * Rewritten with a tail-recursive index loop: the original inner recursion
  * was not tail-recursive and could blow the stack on very long names, and it
  * also allocated an intermediate List[Char].
  */
def toCamel(s: String) = {
  @tailrec
  def loop(i: Int, acc: StringBuilder): String =
    if (i >= s.length) acc.toString
    else if (s.charAt(i) == '_' && i + 1 < s.length)
      loop(i + 2, acc.append(s.charAt(i + 1).toUpper)) // consume '_' plus the next char
    else
      loop(i + 1, acc.append(s.charAt(i)))
  loop(0, new StringBuilder(s.length))
}
/** Maps JDBC column metadata to a [[ColumnInfo]], choosing the Scala field type
  * and a default-value expression for the generated row case class.
  */
def getColumnInfo: MColumn => ColumnInfo = {
  case col @ MColumn(_, name, sqlType, typeName, _, _, _, nullable, _, columnDef, _, _, _, _, _, isAutoInc) =>
    // Auto-increment defaults (sequence expressions) make no sense as Scala defaults.
    val plainDefault = if (isAutoInc.contains(true)) None else columnDef
    val (baseType, baseDefault) =
      (sqlType, typeName) match {
        case (_, "lo")                  => t"java.sql.Blob" -> None
        case (Types.NUMERIC, "numeric") => t"BigDecimal" -> plainDefault.map(s => q"BigDecimal($s)")
        case (Types.DOUBLE, "float8")   => t"Double" -> plainDefault.map(s => Lit.Double(s.toDouble))
        case (Types.BIT, "bool")        => t"Boolean" -> plainDefault.map(s => Lit.Boolean(s.toBoolean))
        case (Types.INTEGER, _)         => t"Int" -> plainDefault.map(s => Lit.Int(s.toInt))
        case (Types.VARCHAR, "varchar" | "text") =>
          // NOTE(review): assumes the default is a plain quoted literal; a cast
          // suffix like 'x'::text would survive the strip — confirm against the DB.
          t"String" -> plainDefault.map(s => Lit.String(s.stripPrefix("'").stripSuffix("'")))
        case (Types.DATE, "date") =>
          // Only now() is recognized; any other date default is dropped.
          t"java.time.LocalDate" ->
            plainDefault.collect { case "now()" => q"java.time.LocalDate.now()" }
        case (_, _) =>
          System.err.println("Don't know how to handle " + col)
          t"Nothing" -> None
      }
    // Nullable columns get Option-wrapped, with the default lifted into Some/None.
    val (fieldType, fieldDefault) =
      if (nullable.contains(true))
        t"Option[$baseType]" -> Some(baseDefault.fold[Term](q"None")(t => q"Some($t)"))
      else
        baseType -> baseDefault
    val fieldName = Term.Name(toCamel(name))
    ColumnInfo(name, fieldName, fieldName, fieldType, fieldDefault)
}
/** Reads the column metadata for `table` (as a DBIO action) and assembles a
  * [[SchemaInfo]]; the Table class is named after the table, the row class
  * gets a "Row" suffix.
  */
def getSchemaInfo(table: MTable) =
  for (columns <- table.getColumns) yield {
    val baseName = toCamel(table.name.name.capitalize)
    SchemaInfo(table.name, baseName, baseName + "Row", columns.toList.map(getColumnInfo))
  }
/** Builds the statements for one generated row case class, annotated with
  * `@JsonCodec` for circe derivation.
  */
def rowStats(schemaInfo: SchemaInfo): List[Stat] = {
  val params =
    for (col <- schemaInfo.columns)
      yield Term.Param(Nil, col.rowFieldTerm, Some(col.scalaType), col.scalaDefault)
  val rowName = Type.Name(schemaInfo.rowClassName)
  List(q"@JsonCodec case class $rowName(..$params)")
}
/** True when `schema` is the default schema, which is omitted from generated table names. */
def isDefaultSchema(schema: String) = schema match {
  case "public" => true
  case _        => false
}
/** Builds the `def *` default projection for a generated Table class, mapping
  * the column tuple to/from the row case class via Slick's `<>`.
  * For more than 22 columns (Scala's tuple limit) the columns are folded into
  * nested tuples and mapped with a hand-built partial function / extractor.
  */
def mkStar(rowClassName: String, columns: List[ColumnInfo]) = {
  val companion = Term.Name(rowClassName)
  val terms = columns.map(_.tableFieldTerm)
  val numCols = columns.length
  val (tuple, factory, extractor) =
    if (numCols <= 22)
      // Simple case: plain tuple with the companion's tupled apply/unapply.
      (Term.Tuple(terms), q"($companion.apply _).tupled", q"$companion.unapply")
    else {
      // Repeatedly group the first 22 values into one nested value until a
      // single value remains; works for Terms and Pats alike.
      @tailrec
      def group22[A](values: List[A])(group: List[A] => A): A = values match {
        case List(one) => one
        case _ =>
          val (first, second) = values.splitAt(22)
          group22(group(first) +: second)(group)
      }
      // tuple: nested Term.Tuple of the column refs;
      // factory: partial function destructuring the nested tuple and calling the companion;
      // extractor: function from the row class to Some(nested tuple of its fields).
      (group22[Term](terms)(Term.Tuple(_)),
       Term.PartialFunction(
         List(
           p"""
           case ${group22[Pat](terms.map(Pat.Var(_)))(Pat.Tuple(_))} =>
             $companion(..$terms)
           """
         )
       ),
       q"(rec: ${Type.Name(rowClassName)}) => Some(${group22[Term](terms.map(t => q"rec.$t"))(Term.Tuple(_))})")
    }
  q"def * = $tuple.<>({$factory}, $extractor)"
}
/** Builds the statements for one generated table: the Table subclass (with its
  * column vals and `*` projection) plus a lazy TableQuery val of the same name.
  */
def tableStats: SchemaInfo => List[Stat] = {
  case SchemaInfo(tableName, tableClassName, rowClassName, columns) =>
    // One `val x = column[T]("name")` per column.
    val fields = columns.map {
      case ColumnInfo(columnName, tableFieldName, _, scalaType, _) =>
        q"""
        val ${Pat.Var(tableFieldName)} = column[$scalaType]($columnName)
        """
    }
    val star = mkStar(rowClassName, columns)
    // Constructor args after the tag: optional schema name, then the table name.
    // The default schema is omitted; catalogs are not supported at all.
    val params = tableName match {
      case MQName(None, Some(schema), name) if !isDefaultSchema(schema) => List(q"Some($schema)", Lit.String(name))
      case MQName(None, _, name) => List(Lit.String(name))
      case MQName(Some(_), _, _) => sys.error("catalog not supported")
    }
    List(
      q"""
      class ${Type.Name(tableClassName)}(_tableTag: Tag)
        extends Table[${Type.Name(rowClassName)}](_tableTag, ..$params) {
        $star
        ..$fields
      }
      """,
      q"""
      lazy val ${Pat.Var(Term.Name(tableClassName))} = TableQuery[${Type.Name(tableClassName)}]
      """
    )
}
/** Parses a dotted path like "com.example.Foo" into the corresponding nested
  * `Term.Select` tree (left-nested, i.e. Select(Select(Name("com"), ...), "Foo")).
  */
def toRef(s: String): Term.Ref = {
  val parts = s.split('.').toList
  parts.tail.foldLeft[Term.Ref](Term.Name(parts.head)) { (qualifier, segment) =>
    Term.Select(qualifier, Term.Name(segment))
  }
}
/** Parses each string as an importer clause and wraps them all in a single
  * `import` statement; returns Nil when there is nothing to import.
  */
def imports(strings: List[String]): List[Stat] =
  strings match {
    case Nil => Nil
    case ss  => List(q"import ..${ss.map(_.parse[Importer].get)}")
  }
/** Wires the generator into every project that enables this plugin:
  * slickSchemaInfo reads the schema from the live database, and the two
  * slickMetaGen* tasks render it into managed source files.
  */
override def projectSettings =
  Seq(
    slickContainer := "Tables",
    // The profile class name is read from the Slick config (e.g. "slick.jdbc.PostgresProfile$").
    slickProfileClass := slickConfig.value.getString("profile"),
    // By default, codegen every table except Flyway's bookkeeping table.
    slickTables := (_.name != MQName(None, None, "flyway_schema_history")),
    slickMetaGenExtraImports := Nil,
    slickSchemaInfo := {
      val config = slickConfig.value
      val profileName = slickProfileClass.value
      val tablesPred = slickTables.value
      // Load the profile object's singleton instance via its static MODULE$ field.
      val slickProfile = Class.forName(profileName).getField("MODULE$").get(null).asInstanceOf[JdbcProfile]
      val db = JdbcBackend.Database.forConfig("", config)
      try {
        val tablesAction = slickProfile.defaultTables.map(_.filter(tablesPred))
        val infoAction = tablesAction.flatMap(tables => DBIO.sequence(tables.toList.map(getSchemaInfo)))
        // Blocking is acceptable here: SBT tasks are synchronous, and the
        // connection is always closed afterwards.
        Await.result(db.run(infoAction), Duration.Inf)
      } finally db.close()
    },
    // Writes <container>.scala containing one @JsonCodec row case class per table.
    slickMetaGenModels := {
      val container = slickContainer.value
      val outputDir = (Compile / sourceManaged).value
      val pkg = slickPackage.value
      val schemaInfos = slickSchemaInfo.value
      val filename = container + ".scala"
      val file = outputDir / pkg.replace(".", "/") / filename
      IO.write(
        file,
        q"""
        package ${toRef(pkg)} {
          import io.circe.generic.JsonCodec
          ..${imports(slickMetaGenExtraImports.value)}
          ..${schemaInfos.flatMap(rowStats)}
        }
        """.syntax
      )
      Seq[File](file)
    },
    // Writes <container>.scala containing the Table classes and TableQuery vals,
    // importing the configured profile's api._.
    slickMetaGenTables := {
      val container = slickContainer.value
      val outputDir = (Compile / sourceManaged).value
      val pkg = slickPackage.value
      // Strip the trailing '$' from the object class name to get a source-level reference.
      val slickProfileName = toRef(slickProfileClass.value.stripSuffix("$"))
      val schemaInfos = slickSchemaInfo.value
      val filename = container + ".scala"
      val file = outputDir / pkg.replace(".", "/") / filename
      IO.write(
        file,
        q"""
        package ${toRef(pkg)} {
          import $slickProfileName.api._
          ..${imports(slickMetaGenExtraImports.value)}
          object ${Term.Name(container)} {
            ..${schemaInfos.flatMap(tableStats)}
          }
        }
        """.syntax
      )
      Seq[File](file)
    }
  )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.