Created
April 29, 2016 13:07
-
-
Save lancegatlin/247fadb3768ccf8e9636e4c9b546eae8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package s_mach.codetools.bigcaseclass

import s_mach.string._

/**
 * Prints Scala source for a case class whose field count exceeds CASECLASS_MAX_FIELDS.
 *
 * The fields are split, in order, into groups of at most CASECLASS_MAX_FIELDS. Each group
 * becomes a normal "sub" case class in the companion object (`Name1`, `Name2`, ...), and the
 * outer case class holds one sub case class instance per group (`_1`, `_2`, ...). The outer
 * class also gets forwarding accessors, a flat `copy` and flat `Product` methods, and the
 * companion gets a flat `apply`, so callers can use it like one wide record.
 */
object BigCaseClassPrinter {

  /**
   * @param name   name of the generated outer case class
   * @param fields all fields of the record, in declaration order
   * @return Scala source for the outer case class followed by its companion object
   */
  def print(name: String, fields: Vector[CaseClassField]) : String = {
    val subCaseClasses = fields.grouped(CASECLASS_MAX_FIELDS).toVector
    s"""
       |case class $name(
       |${printSubCaseClassParams(name, subCaseClasses).indent(2)}
       |) {
       |${printForwardingMethods(subCaseClasses).indent(2)}
       |
       |${printBigCopy(name, subCaseClasses).indent(2)}
       |
       |${printProductMethods(fields.size, subCaseClasses.size).indent(2)}
       |}
       |
       |object $name {
       |${printSubCaseClasses(name, subCaseClasses).indent(2)}
       |${printBigApply(name, fields, subCaseClasses).indent(2)}
       |}
    """.stripMargin.trim
  }

  /**
   * Outer constructor parameters, one per sub case class, e.g.:
   * {{{
   * _1 : Licensee.Licensee1,
   * _2 : Licensee.Licensee2
   * }}}
   */
  private def printSubCaseClassParams(
    name: String,
    subCaseClasses: Vector[Vector[CaseClassField]]
  ) : String =
    subCaseClasses.zipWithIndex.map { case (_, i) =>
      Seq(s"_${i+1}", ":", s"$name.$name${i+1}")
    }.printGrid(" ", ",\n")

  /**
   * One accessor per field that forwards to the sub case class holding it, e.g.:
   * {{{
   * def licenseeId : Long = _1.licenseeId
   * def nonAgentAccessRate : scala.Option[Double] = _2.nonAgentAccessRate
   * }}}
   * A field's comment, if any, is appended as a `//` line comment.
   */
  private def printForwardingMethods(subCaseClasses: Vector[Vector[CaseClassField]]) : String =
    subCaseClasses.iterator.zipWithIndex.flatMap { case (subfields, i) =>
      subfields.map { f =>
        Seq("def", f.name, ":", f._type, "=", s"_${i+1}.${f.name}", f.optComment.map("// " + _).getOrElse(""))
      }
    }.toVector.printGrid(" ", "\n")

  /**
   * Flat `Product` methods that index across all sub case classes, e.g. for 33 fields:
   * {{{
   * override def productElement(i: Int) : Any = i match {
   *   case n if n < 22 => _1.productElement(i)
   *   case n if n < 44 => _2.productElement(i - 22)
   *   case _ => throw new IndexOutOfBoundsException(i.toString)
   * }
   * override def productArity : Int = 33
   * override def productIterator = _1.productIterator ++ _2.productIterator
   * }}}
   *
   * @param fieldCount        total number of fields across all sub case classes
   * @param subCaseClassCount number of sub case classes
   */
  private def printProductMethods(fieldCount: Int, subCaseClassCount: Int) : String = {
    val cases = (0 until subCaseClassCount).map { i =>
      val offset = i * CASECLASS_MAX_FIELDS
      // The first group needs no index adjustment; avoid emitting "i - 0"
      val localIndex = if(offset == 0) "i" else s"i - $offset"
      s"case n if n < ${offset + CASECLASS_MAX_FIELDS} => _${i+1}.productElement($localIndex)"
    }.mkString("\n").indent(2)
    val iterators = (1 to subCaseClassCount).map(n => s"_$n.productIterator").mkString(" ++ ")
    // The fallback reports the offending index, as compiler-synthesized productElement does
    s"""
       |override def productElement(i: Int) : Any = i match {
       |$cases
       |  case _ => throw new IndexOutOfBoundsException(i.toString)
       |}
       |override def productArity : Int = $fieldCount
       |override def productIterator = $iterators
    """.stripMargin.trim
  }

  /**
   * The sub case class declarations (`Name1`, `Name2`, ...), each printed as a normal
   * case class holding one group of fields, e.g.:
   * {{{
   * case class Licensee1(
   *   licenseeId : Long,
   *   code : String,
   *   ...
   * )
   * case class Licensee2(
   *   ssoProvider : scala.Option[String],
   *   ...
   * )
   * }}}
   */
  private def printSubCaseClasses(
    name: String,
    subCaseClasses: Vector[Vector[CaseClassField]]
  ) : String =
    subCaseClasses.iterator.zipWithIndex.map { case (subfields, i) =>
      CaseClassPrinter.printNormCaseClass(name + (i+1).toString, subfields)
    }.mkString("\n")

  /**
   * A flat `copy` taking every field, each defaulting to its current value, e.g.:
   * {{{
   * def copy(
   *   licenseeId : Long = _1.licenseeId,
   *   code : String = _1.code,
   *   ...
   * ) : Licensee = Licensee(
   *   licenseeId = licenseeId,
   *   code = code,
   *   ...
   * )
   * }}}
   * The body calls the companion's flat `apply` with named arguments.
   */
  private def printBigCopy(
    name: String,
    subCaseClasses: Vector[Vector[CaseClassField]]
  ) : String = {
    val copyParmsStr = subCaseClasses.iterator.zipWithIndex.flatMap { case (subfields, i) =>
      subfields.map { f =>
        Seq(f.name, ":", f._type, "=", s"_${i+1}.${f.name}")
      }
    }.toVector.printGrid(" ", ",\n")
    val caseClassParmsStr = subCaseClasses.iterator.flatMap { subfields =>
      subfields.map { f =>
        Seq(f.name, "=", f.name)
      }
    }.toVector.printGrid(" ", ",\n")
    s"""
       |def copy(
       |${copyParmsStr.indent(2)}
       |) : $name = $name(
       |${caseClassParmsStr.indent(2)}
       |)
    """.stripMargin.trim
  }

  /**
   * A flat companion `apply` taking every field (with any declared defaults) and building
   * each sub case class, e.g.:
   * {{{
   * def apply(
   *   licenseeId : Long,
   *   ...
   * ) : Licensee = Licensee(
   *   _1 = Licensee1(
   *     licenseeId = licenseeId,
   *     ...
   *   ),
   *   _2 = Licensee2(
   *     ssoProvider = ssoProvider,
   *     ...
   *   )
   * )
   * }}}
   */
  private def printBigApply(
    name: String,
    fields: Vector[CaseClassField],
    subCaseClasses: Vector[Vector[CaseClassField]]
  ) : String = {
    val allFieldsStr = CaseClassPrinter.printFieldDecls(fields)
    val applySubCaseClassesStr = subCaseClasses.iterator.zipWithIndex.map { case (subfields, i) =>
      val subFieldParms = subfields.indices.map { j =>
        val v = Vector(subfields(j).name, "=", subfields(j).name)
        // Every named argument except the last is followed by a comma
        if(j == subfields.indices.last) {
          v
        } else {
          v.updated(v.indices.last, v.last + ",")
        }
      }.printGrid(" ", "\n")
      s"""
         |_${i+1} = $name${i+1}(
         |${subFieldParms.indent(2)}
         |)
      """.stripMargin.trim
    }.mkString(",\n")
    s"""
       |def apply(
       |${allFieldsStr.indent(2)}
       |) : $name = $name(
       |${applySubCaseClassesStr.indent(2)}
       |)
    """.stripMargin.trim
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package s_mach.codetools.bigcaseclass

/**
 * Description of one field of a case class to be printed as Scala source.
 *
 * @param name       Scala identifier used for the field
 * @param _type      Scala type of the field, as source text (e.g. `Option[String]`)
 * @param optDefault default value as Scala source text; printed after `= ` when present
 * @param optComment text printed after the field as a `// ` line comment when present
 */
case class CaseClassField(
  name       : String,
  _type      : String,
  optDefault : Option[String],
  optComment : Option[String]
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package s_mach.codetools.bigcaseclass

import s_mach.string._

/**
 * Prints Scala source for case class declarations. Records with more than
 * CASECLASS_MAX_FIELDS fields are handed off to BigCaseClassPrinter.
 */
object CaseClassPrinter {

  /**
   * @param name   name of the generated case class
   * @param fields fields of the case class, in declaration order
   * @return Scala source for the case class (and, for big case classes, its companion)
   */
  def print(name: String, fields: Vector[CaseClassField]) : String =
    if(fields.size > CASECLASS_MAX_FIELDS) BigCaseClassPrinter.print(name, fields)
    else printNormCaseClass(name, fields)

  /**
   * @param name   name of the generated case class
   * @param fields fields of the case class, in declaration order
   * @return Scala source for a single `case class` declaration with no body
   */
  def printNormCaseClass(name: String, fields: Vector[CaseClassField]) : String = {
    val declarations = printFieldDecls(fields).indent(2)
    s"""
       |case class $name(
       |$declarations
       |)
    """.stripMargin.trim
  }

  /**
   * Prints parameter declarations of the form `name : Type = default, // comment`, one
   * per line, laid out as a grid via printGrid.
   *
   * Every declaration except the last ends with a comma (attached to its last code cell,
   * before any comment). When at least one field has a default, fields without one get an
   * empty placeholder cell so their comments fall into the same grid column.
   *
   * @param fields fields to declare, in order
   * @return the declarations, separated by newlines
   */
  def printFieldDecls(fields: Vector[CaseClassField]) : String = {
    val padMissingDefault = fields.exists(_.optDefault.isDefined)
    val lastIndex = fields.size - 1
    fields.zipWithIndex.map { case (field, j) =>
      val cells = Vector(field.name, ":", field._type) ++ field.optDefault.map("= " + _).toVector
      val punctuated =
        if(j < lastIndex) cells.init :+ (cells.last + ",")
        else cells
      val placeholder =
        if(padMissingDefault && field.optDefault.isEmpty) Vector("")
        else Vector.empty[String]
      punctuated ++ placeholder ++ field.optComment.map("// " + _).toVector
    }.printGrid(" ","\n")
  }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package s_mach.codetools.bigcaseclass

import s_mach.string._

/**
 * Generates Scala case class source from SQL `CREATE TABLE` DDL.
 *
 * WARN: parsing is regex based (see the `parse*Regex` vals) and only handles common
 * declaration shapes; it is not a real DDL parser.
 */
object DdlToCaseClassPrinter {

  /**
   * @param formatTableName   maps a SQL table name to the generated case class name
   * @param formatColumnName  maps a SQL column name to the generated field name
   * @param sqlToScalaTypeMap maps a normalized SQL type (e.g. "int", "unsigned int",
   *                          "tinyint(1)") to the Scala type source used for the field
   */
  case class Config(
    formatTableName : String => String = {
      import WordSplitter.Underscore
      _.toCamelCase
    },
    formatColumnName : String => String = {
      import WordSplitter.Underscore
      _.toCamelCase
    },
    sqlToScalaTypeMap: Map[String, String] = stdSqlToScalaTypeMap
  )

  val scala_String = "String"
  val scala_ArrayByte = "Array[Byte]"
  val scala_Boolean = "Boolean"
  val scala_Byte = "Byte"
  val scala_Short = "Short"
  val scala_Int = "Int"
  val scala_Long = "Long"
  val scala_BigInt = "BigInt"
  val scala_Float = "Float"
  val scala_Double = "Double"
  val scala_BigDecimal = "BigDecimal"
  val java_util_Date = "java.util.Date"

  /**
   * Default SQL-to-Scala type mapping. Keys are lower-case type names, optionally prefixed
   * with "unsigned " when the column is declared unsigned.
   */
  val stdSqlToScalaTypeMap : Map[String, String] = Map(
    "char" -> scala_String,
    "varchar" -> scala_String,
    "tinytext" -> scala_String,
    "text" -> scala_String,
    "mediumtext" -> scala_String,
    "longtext" -> scala_String,
    "clob" -> scala_String,
    "set" -> scala_String,
    "enum" -> scala_String,
    "tinyblob" -> scala_ArrayByte,
    "blob" -> scala_ArrayByte,
    "mediumblob" -> scala_ArrayByte,
    "longblob" -> scala_ArrayByte,
    "bit" -> scala_Boolean,
    // tinyint(1) is conventionally used for boolean columns
    "tinyint(1)" -> scala_Boolean,
    "unsigned tinyint(1)" -> scala_Boolean,
    "tinyint" -> scala_Byte,
    "unsigned tinyint" -> scala_Short, // scala has no concept of unsigned so need to promote
    "smallint" -> scala_Short,
    "unsigned smallint" -> scala_Int, // scala has no concept of unsigned so need to promote
    "mediumint" -> scala_Int,
    "unsigned mediumint" -> scala_Int, // unsigned 24-bit values still fit in Int
    "int" -> scala_Int,
    "unsigned int" -> scala_Long, // scala has no concept of unsigned so need to promote
    // NOTE(review): signed bigint would fit in Long; kept as BigInt for compatibility
    "bigint" -> scala_BigInt,
    "unsigned bigint" -> scala_BigInt,
    "float" -> scala_Float,
    "unsigned float" -> scala_Float,
    "double" -> scala_Double,
    "unsigned double" -> scala_Double,
    "decimal" -> scala_BigDecimal,
    "unsigned decimal" -> scala_BigDecimal,
    "date" -> java_util_Date,
    "datetime" -> java_util_Date,
    "timestamp" -> scala_Long,
    "time" -> java_util_Date,
    "year" -> scala_Int
  )

  // TODO: replace with real DDL parser
  // group(1) = table name, group(2) = everything between the outer parentheses
  val parseCreateTableRegex = "(?i)CREATE TABLE [`]?(\\w+)[`]?\\s*\\((.+)\\)".r
  // group(1) = column name, group(2) = type, group(3) = optional "(...)" type modifier,
  // group(4) = rest of the declaration. A declaration ends at a comma OR at the end of the
  // column list, so the last column is not dropped when no key clause follows it.
  val parseColumnDeclRegex = "(?i)[`]?(\\w+)[`]?\\s+(\\w+)\\s*(\\(.+?\\))?([^,]*)(?:,|$)".r
  // Rejects declaration matches that are really key/constraint clauses
  val parseColumnDeclFilter = "(?i)(?<=(\\s|^))(PRIMARY|KEY|CONSTRAINT|FOREIGN|USING)(?=(\\s|$))".r
  // Only DEFAULT NULL and single-quoted defaults are recognized
  val parseDefaultRegex = "(?i)DEFAULT (NULL|'.*?')".r

  /**
   * Generates case class source for every `CREATE TABLE` statement matched in `ddl`.
   *
   * @param ddl SQL DDL text
   * @param cfg naming and type-mapping configuration
   * @return one generated block per matched table (case class, header comment and the
   *         original DDL), joined by newlines
   * @throws RuntimeException if a column's SQL type is not in `cfg.sqlToScalaTypeMap`
   */
  def print(
    ddl: String,
    cfg: Config = Config()
  ) : String = {
    // Collapse all whitespace, including newlines, to single spaces so each statement
    // can be matched as one line
    val tidyDdl = ddl.replaceAll("\\s+", " ")
    parseCreateTableRegex.findAllMatchIn(tidyDdl).map { tblMatch =>
      val tableName = tblMatch.group(1)
      val columns = tblMatch.group(2)
      // Parse out column declarations and filter unintentional matches to lines like "PRIMARY KEY"
      val fields : Vector[CaseClassField] =
        parseColumnDeclRegex
          .findAllMatchIn(columns)
          .filter(m => parseColumnDeclFilter.findFirstIn(m.group(0)).isEmpty)
          .zipWithIndex
          .map { case (columnMatch, i) => parseField(columnMatch, i, cfg) }
          .toVector
      val caseClassName = cfg.formatTableName(tableName)
      val caseClassStr = CaseClassPrinter.print(caseClassName, fields)
      val now = new java.util.Date()
      s"""
         |/**
         | * Case class for a row in table $tableName
         | * WARN: auto-generated using s_mach.codetools.bigcaseclass.DdlToCaseClassPrinter
         | * WARN: field order MUST correspond to SQL column order
         | * Regex for quick find/replace:
         | * ${"(\\w+)\\s*:\\s*(.+?)(\\s*=\\s*(.+?))*[,]*\\s* // (\\d+)"}
         | * $now
         | **/
         |$caseClassStr
         |/* Auto-generated from:
         |$ddl
         |*/
      """.stripMargin.trim
    }.mkString("\n")
  }

  /**
   * Converts one column declaration matched by parseColumnDeclRegex into a field.
   *
   * @param columnMatch a match of parseColumnDeclRegex
   * @param index       position of the column among the kept declarations; prefixed to
   *                    the field comment
   * @param cfg         naming and type-mapping configuration
   * @throws RuntimeException if the column's SQL type is not in `cfg.sqlToScalaTypeMap`
   */
  private def parseField(
    columnMatch: scala.util.matching.Regex.Match,
    index: Int,
    cfg: Config
  ) : CaseClassField = {
    val decl = columnMatch.group(0)
    val columnName = columnMatch.group(1)
    val rawColumnType = columnMatch.group(2)
    // Optional "(...)" type modifier; null when absent
    val columnTypeMod = columnMatch.group(3)
    val suffix = columnMatch.group(4)
    // Locale.ROOT so SQL keywords lower-case predictably regardless of the JVM locale
    val lcSuffix = suffix.toLowerCase(java.util.Locale.ROOT)
    val isNullable = !lcSuffix.contains("not null")
    val columnType = {
      val unsignedPrefix = if(lcSuffix.contains("unsigned")) "unsigned " else ""
      val baseType =
        if(rawColumnType.equalsIgnoreCase("tinyint") && columnTypeMod == "(1)") "tinyint(1)"
        else rawColumnType
      unsignedPrefix + baseType
    }
    // Exact lookup first (custom maps may use any case), then fall back to lower case:
    // SQL type names are case-insensitive and the standard map uses lower-case keys
    val baseScalaType =
      cfg.sqlToScalaTypeMap.get(columnType)
        .orElse(cfg.sqlToScalaTypeMap.get(columnType.toLowerCase(java.util.Locale.ROOT)))
        .getOrElse(throw new RuntimeException(s"Unmapped SQL type: $columnType! $decl"))
    val scalaType = if(isNullable) s"Option[$baseScalaType]" else baseScalaType
    CaseClassField(
      name = cfg.formatColumnName(columnName),
      _type = scalaType,
      optDefault = parseOptDefault(suffix, baseScalaType, isNullable),
      optComment = Some(s"$index $decl")
    )
  }

  /**
   * Translates a column's DEFAULT clause into Scala source for the field default.
   *
   * @param suffix        declaration text after the column type and its modifier
   * @param baseScalaType Scala type of the column before any Option wrapping
   * @param isNullable    whether the field type is wrapped in Option
   * @return the default as Scala source, or None when no recognized DEFAULT clause exists
   */
  private def parseOptDefault(
    suffix: String,
    baseScalaType: String,
    isNullable: Boolean
  ) : Option[String] =
    parseDefaultRegex.findFirstMatchIn(suffix).map { m =>
      val matched = m.group(1)
      if(matched.equalsIgnoreCase("NULL")) {
        // Translate NULL to None
        "None"
      } else {
        // Strip the surrounding single quotes
        val rawDefault = matched.tail.init
        // Adjust the scala value based on the baseScalaType
        val baseDefault = baseScalaType match {
          // Strings need to be double-quoted in scala
          case "String" => "\"" + rawDefault + "\""
          // Chars need to be single-quoted in scala
          case "Char" => s"'$rawDefault'"
          // SQL uses 0 for false and any other value as true
          case "Boolean" => if(rawDefault == "0") "false" else "true"
          case _ => rawDefault
        }
        // If the column is nullable then wrap the default value in Some
        if(isNullable) s"Some($baseDefault)" else baseDefault
      }
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package s_mach.codetools

/** Shared definitions for the big-case-class code generators. */
package object bigcaseclass {

  /**
   * Maximum number of fields printed into a single case class. Larger field lists are
   * split into groups of this size by BigCaseClassPrinter.
   * NOTE(review): presumably mirrors Scala's 22-arity limit (tuples/functions) — confirm.
   */
  val CASECLASS_MAX_FIELDS: Int = 22
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.