-
-
Save yoyama/ce83f688717719fc8ca145c3b3ff43fd to your computer and use it in GitHub Desktop.
/**
 * Generates the source code of a case class from a DataFrame schema.
 *
 * val df: DataFrame = ...
 *
 * val s2cc = new Schema2CaseClass
 * import s2cc.implicits._
 *
 * println(s2cc.schemaToCaseClass(df.schema, "MyClass"))
 *
 */
import org.apache.spark.sql.types._

/**
 * Generates the Scala source text of a case class whose fields mirror a Spark [[StructType]].
 *
 * Usage:
 * {{{
 * val s2cc = new Schema2CaseClass
 * import s2cc.implicits._
 * println(s2cc.schemaToCaseClass(df.schema, "MyClass"))
 * }}}
 */
class Schema2CaseClass {
  /** Maps a Spark [[DataType]] to the Scala type name written into the generated source. */
  type TypeConverter = (DataType) => String

  // Words that cannot appear as bare identifiers (Scala 2 keywords plus Scala 3 additions).
  // Backquoting a name that would have been legal anyway is harmless.
  private val ReservedWords: Set[String] = Set(
    "abstract", "case", "catch", "class", "def", "do", "else", "enum", "export", "extends",
    "false", "final", "finally", "for", "forSome", "given", "if", "implicit", "import", "lazy",
    "macro", "match", "new", "null", "object", "override", "package", "private", "protected",
    "return", "sealed", "super", "then", "this", "throw", "trait", "try", "true", "type", "val",
    "var", "while", "with", "yield", "_")

  private val PlainIdentifier = "[A-Za-z_][A-Za-z0-9_]*".r

  /** Wraps a column name in backticks unless it is already a plain, non-reserved identifier. */
  private def quoteIdentifier(name: String): String = name match {
    case PlainIdentifier() if !ReservedWords(name) => name
    case _ => s"`$name`"
  }

  /**
   * Renders `case class <className> (...)` with one constructor parameter per field of `schema`.
   *
   * Nullable fields become `Option[T]`. Field names that are not plain identifiers (e.g. ones
   * containing `-` or spaces, or reserved words such as `type`) are wrapped in backticks so the
   * generated code compiles.
   *
   * @param schema    schema whose fields become the constructor parameters, in order
   * @param className name of the generated case class (used verbatim)
   * @param tc        converter from each field's data type to a Scala type name
   * @return the case class source text, with a leading and a trailing newline
   */
  def schemaToCaseClass(schema: StructType, className: String)(implicit tc: TypeConverter): String = {
    def genField(s: StructField): String = {
      val f = tc(s.dataType)
      val name = quoteIdentifier(s.name)
      if (s.nullable) s"  $name:Option[$f]" else s"  $name:$f"
    }
    // Assembled with mkString instead of stripMargin so field text is never margin-stripped.
    val fieldsStr = schema.map(genField).mkString(",\n")
    s"\ncase class $className (\n$fieldsStr\n)\n"
  }

  object implicits {
    /** Default mapping from Spark SQL data types to Scala type names. */
    implicit val defaultTypeConverter: TypeConverter = convert

    // Arrays and maps recurse into their element / key / value types so the field type is
    // fully parameterized: a bare `scala.collection.Seq` or `scala.collection.Map` does not
    // compile as a field type. Element/value nullability flags are not reflected.
    private def convert(t: DataType): String = t match {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case ArrayType(elementType, _) => s"scala.collection.Seq[${convert(elementType)}]"
      case MapType(keyType, valueType, _) =>
        s"scala.collection.Map[${convert(keyType)}, ${convert(valueType)}]"
      case _: StructType => "org.apache.spark.sql.Row"
      case _ => "String"
    }
  }
}
@RayTsui You can write the string to a file, or append it to a file of case classes.
I love this conversion process! I've had deeply nested schemas, which required me to run this code manually on each level of nesting. I would love a recursive (or other) version that handles nested schemas, which I hope to contribute back unless someone beats me to it ;)
I concur with @gstaubli. Can you please share what you did for the nested schema?
Very good idea. I was also looking for some way to execute the case class creation method and found this:
import scala.tools.reflect.ToolBox
import scala.reflect.runtime.universe._
import scala.reflect.runtime.currentMirror
// Placeholder: replace ??? with your DataFrame.
val df: org.apache.spark.sql.DataFrame = ???
val s2cc = new Schema2CaseClass
import s2cc.implicits._ // brings defaultTypeConverter into implicit scope
val toolbox = currentMirror.mkToolBox()
// ToolBox.compile takes a Tree, not source text, so the generated code is parsed first.
// compile returns a `() => Any` thunk; the class it defines cannot be referenced by name
// from the statically compiled code around it.
val case_class = toolbox.compile(toolbox.parse(s2cc.schemaToCaseClass(df.schema, "YourName")))
For that, the return type of `schemaToCaseClass` would have to be `runtime.universe.Tree`, built using quasiquotes:
/**
 * Builds a `case class` definition for `schema` and returns it as a reflection tree that
 * can be passed to `ToolBox.compile`.
 *
 * Quasiquotes cannot splice plain `String`s into name or parameter-list positions
 * (`$className` would need a `TypeName`, the fields a `List[ValDef]`). The source text is
 * therefore assembled as a string and parsed with a ToolBox.
 *
 * @param schema    schema whose fields become constructor parameters; nullable ones become `Option[T]`
 * @param className name of the generated case class
 * @param tc        converter from each field's data type to a Scala type name
 * @return the parsed (not yet compiled) case class definition
 */
def schemaToCaseClass(schema: StructType, className: String)(implicit tc: TypeConverter): scala.reflect.runtime.universe.Tree = {
  import scala.reflect.runtime.currentMirror
  import scala.tools.reflect.ToolBox

  def genField(s: StructField): String = {
    val f = tc(s.dataType)
    if (s.nullable) s"  ${s.name}:Option[$f]" else s"  ${s.name}:$f"
  }

  val fieldsStr = schema.map(genField).mkString(",\n")
  currentMirror.mkToolBox().parse(s"case class $className (\n$fieldsStr\n)")
}
However, when I tried to apply it back to the resulting DataFrame, I didn't see a way to do that. Sharing what I found in case it helps someone.
Reference - https://stackoverflow.com/questions/31054237/what-are-the-ways-to-convert-a-string-into-runnable-code
Would it be possible to create a macro? I can't seem to make use of the generated class string, as it won't compile: https://stackoverflow.com/questions/51035313/dynamically-create-case-class-from-structtype#51035313
scala> import org.apache.spark.sql.types._
import org.apache.spark.sql.types._
scala>
scala> class Schema2CaseClass {
| type TypeConverter = (DataType) => String
|
| def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
| def genField(s:StructField):String = {
| val f = tc(s.dataType)
| s match {
| case x if(x.nullable) => s" ${s.name}:Option[$f]"
| case _ => s" ${s.name}:$f"
| }
| }
|
| val fieldsStr = schema.map(genField).mkString(",\n ")
| s"""
| |case class $className (
| | $fieldsStr
| |)
| """.stripMargin
| }
|
| object implicits {
| implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
| case _:ByteType => "Byte"
| case _:ShortType => "Short"
| case _:IntegerType => "Int"
| case _:LongType => "Long"
| case _:FloatType => "Float"
| case _:DoubleType => "Double"
| case _:DecimalType => "java.math.BigDecimal"
| case _:StringType => "String"
| case _:BinaryType => "Array[Byte]"
| case _:BooleanType => "Boolean"
| case _:TimestampType => "java.sql.Timestamp"
| case _:DateType => "java.sql.Date"
| case _:ArrayType => "scala.collection.Seq"
| case _:MapType => "scala.collection.Map"
| case _:StructType => "org.apache.spark.sql.Row"
| case _ => "String"
| }}
| }
| }
<console>:12: error: not found: type DataType
type TypeConverter = (DataType) => String
^
<console>:14: error: not found: type StructType
def schemaToCaseClass(schema:StructType, className:String)(implicit tc:TypeConverter):String = {
^
<console>:15: error: not found: type StructField
def genField(s:StructField):String = {
^
<console>:32: error: not found: type DataType
implicit val defaultTypeConverter:TypeConverter = (t:DataType) => { t match {
^
<console>:33: error: not found: type ByteType
case _:ByteType => "Byte"
^
<console>:34: error: not found: type ShortType
case _:ShortType => "Short"
^
<console>:35: error: not found: type IntegerType
case _:IntegerType => "Int"
^
<console>:36: error: not found: type LongType
case _:LongType => "Long"
^
<console>:37: error: not found: type FloatType
case _:FloatType => "Float"
^
<console>:38: error: not found: type DoubleType
case _:DoubleType => "Double"
^
<console>:39: error: not found: type DecimalType
case _:DecimalType => "java.math.BigDecimal"
^
<console>:40: error: not found: type StringType
case _:StringType => "String"
^
<console>:41: error: not found: type BinaryType
case _:BinaryType => "Array[Byte]"
^
<console>:42: error: not found: type BooleanType
case _:BooleanType => "Boolean"
^
<console>:43: error: not found: type TimestampType
case _:TimestampType => "java.sql.Timestamp"
^
<console>:44: error: not found: type DateType
case _:DateType => "java.sql.Date"
^
<console>:45: error: not found: type ArrayType
case _:ArrayType => "scala.collection.Seq"
^
<console>:46: error: not found: type MapType
case _:MapType => "scala.collection.Map"
^
<console>:47: error: not found: type StructType
case _:StructType => "org.apache.spark.sql.Row"
^
If you get these errors, enter the code in the console using `:paste` mode; that way it works for me.
I concur with @gstaubli. Can you please share what you did for the nested schema?
+1
This is really helpful, but here's an improvement. In `defaultTypeConverter`, change the `ArrayType` case to:
case ArrayType(elementType, _) =>
  // Destructure the ArrayType directly (no second match on `t`) and recurse, so the element
  // type is rendered too, e.g. "Seq[String]". The containsNull flag is ignored.
  s"Seq[${defaultTypeConverter(elementType)}]"
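Perhaps I did something wrong here, but I was unable to get this working with the implicits import statement.
For anyone who has the same issue and receives `identifier expected but 'implicit' found`: I had to pass the converter in explicitly, as below. (The error arises because `implicit` is a reserved word; the object is named `implicits`, so `import s2cc.implicits._` should also work.)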
// Pass the converter explicitly instead of relying on an implicit import.
val explicitConverter = s2cc.implicits.defaultTypeConverter
val generatedSource = s2cc.schemaToCaseClass(schema, "MyclassName")(explicitConverter)
println(generatedSource)
That said, super helpful, thanks!
Thanks for sharing this. A small enhancement would be to handle nested `StructType`s in a schema. I have tried incorporating that scenario below:
import java.io.FileWriter
import org.apache.spark.sql.types._
/**
 * Writes Scala case class definitions mirroring a Spark schema to `fileWriter`.
 *
 * Nested `StructType`s (directly, inside arrays, or inside maps) become their own case classes,
 * named after the capitalized column name. Their definitions are written before the class that
 * references them.
 *
 * NOTE(review): structs with the same column name at different nesting levels produce duplicate
 * class names.
 */
class SchemaToCaseClassWriter(fileWriter: FileWriter) {
  type TypeConverter = DataType => String

  /** Writes the case class(es) for `schema`, then closes `fileWriter` even if generation fails. */
  def write(schema: StructType, className: String): Unit =
    try run(schema, className)
    finally fileWriter.close()

  private def run(schema: StructType, className: String): Unit = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s" ${field.name}:Option[$dataType]"
        case _ => s" ${field.name}:$dataType"
      }
    }
    val fieldsStr = schema.map(genField).mkString(",\n ")
    val schemaClass =
      s"""case class $className (
         | $fieldsStr
         |)
         |
         |""".stripMargin
    fileWriter.write(schemaClass)
  }

  // `colName` names any nested case class generated for a StructType found under this column.
  private def defaultTypeConverter(colName: String): TypeConverter = {
    case _: ByteType => "Byte"
    case _: ShortType => "Short"
    case _: IntegerType => "Int"
    case _: LongType => "Long"
    case _: FloatType => "Float"
    case _: DoubleType => "Double"
    case _: DecimalType => "java.math.BigDecimal"
    case _: StringType => "String"
    case _: BinaryType => "Array[Byte]"
    case _: BooleanType => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType => "java.sql.Date"
    case ArrayType(elementType, _) =>
      s"Seq[${defaultTypeConverter(colName)(elementType)}]"
    case MapType(keyType, valueType, _) =>
      // A bare `scala.collection.Map` does not compile as a field type; include key/value types.
      val convert = defaultTypeConverter(colName)
      s"scala.collection.Map[${convert(keyType)}, ${convert(valueType)}]"
    case t: StructType =>
      // Side effect: writes the nested case class before returning its name.
      run(t, colName.capitalize)
      colName.capitalize
    case _ => "String"
  }
}
If a nested schema has structs with the same name at different levels, two classes with the same name will be created, which breaks the generated code. I think we need to use package names to handle that. Any other alternatives?
How can I use the resulting string as a case class? Any example, please?
/* The code below can be used directly: instead of writing to a file, it returns the generated code as a String, which you can keep in a variable. */
import org.apache.spark.sql.types._
/**
 * Generates Scala case class definitions mirroring a Spark schema and returns them as a String.
 *
 * Nested `StructType`s (directly, inside arrays, or inside maps) become their own case classes,
 * named after the capitalized column name. Their definitions are emitted before the class that
 * references them.
 *
 * NOTE(review): structs with the same column name at different nesting levels produce duplicate
 * class names.
 */
class SchemaToCaseClassWriter {
  type TypeConverter = DataType => String

  /** Returns the source of the case class for `schema`, preceded by any nested case classes. */
  def write(schema: StructType, className: String): String = {
    run(schema, className)
  }

  private def run(schema: StructType, className: String): String = {
    // Nested case class definitions discovered while converting this schema's fields. Without
    // collecting them, the generated code would reference nested classes it never defines.
    val nestedDefs = new StringBuilder
    def emitNested(definition: String): Unit = {
      nestedDefs.append(definition)
    }

    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name, emitNested)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s" ${field.name}: Option[$dataType]"
        case _ => s" ${field.name}: $dataType"
      }
    }
    val fieldsStr = schema.map(genField).mkString(",\n ")
    val schemaClass =
      s"""case class $className (
         | $fieldsStr
         |)
         |
         |""".stripMargin
    // fieldsStr has been computed at this point, so nestedDefs is complete.
    nestedDefs.toString + schemaClass
  }

  // `colName` names any nested case class generated for a StructType found under this column;
  // `emitNested` receives that class's source text.
  private def defaultTypeConverter(colName: String, emitNested: String => Unit): TypeConverter = {
    case _: ByteType => "Byte"
    case _: ShortType => "Short"
    case _: IntegerType => "Int"
    case _: LongType => "Long"
    case _: FloatType => "Float"
    case _: DoubleType => "Double"
    case _: DecimalType => "java.math.BigDecimal"
    case _: StringType => "String"
    case _: BinaryType => "Array[Byte]"
    case _: BooleanType => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType => "java.sql.Date"
    case ArrayType(elementType, _) =>
      s"Seq[${defaultTypeConverter(colName, emitNested)(elementType)}]"
    case MapType(keyType, valueType, _) =>
      // A bare `scala.collection.Map` does not compile as a field type; include key/value types.
      val convert = defaultTypeConverter(colName, emitNested)
      s"scala.collection.Map[${convert(keyType)}, ${convert(valueType)}]"
    case t: StructType =>
      emitNested(run(t, colName.capitalize))
      colName.capitalize
    case _ => "String"
  }
}
val writer = new SchemaToCaseClassWriter()
// Example schema: replace with your own StructType (e.g. df.schema).
val schema: StructType = StructType(Seq(
  StructField("id", LongType, nullable = false),
  StructField("name", StringType, nullable = true)
))
val className = "MyClass"
val caseClassString = writer.write(schema, className)
println(caseClassString) // Output the generated case class string
After getting the case class source code as a string, how do I execute it? With Scala reflection, or something else?