
@yoyama
Created January 20, 2017 07:36
Generate a case class from a Spark DataFrame/Dataset schema.
```scala
import org.apache.spark.sql.types._

/**
 * Generate a case class from DataFrame.schema
 *
 * val df: DataFrame = ...
 *
 * val s2cc = new Schema2CaseClass
 * import s2cc.implicits._
 *
 * println(s2cc.schemaToCaseClass(df.schema, "MyClass"))
 */
class Schema2CaseClass {
  type TypeConverter = DataType => String

  def schemaToCaseClass(schema: StructType, className: String)(implicit tc: TypeConverter): String = {
    // Render one field as "name:Type", wrapping nullable fields in Option.
    def genField(s: StructField): String = {
      val f = tc(s.dataType)
      s match {
        case x if x.nullable => s"  ${s.name}:Option[$f]"
        case _ => s"  ${s.name}:$f"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n")
    s"""
       |case class $className (
       |$fieldsStr
       |)
       |""".stripMargin
  }

  object implicits {
    // Default Spark SQL type -> Scala type mapping. ArrayType, MapType and
    // StructType map to untyped collections / Row rather than typed fields.
    implicit val defaultTypeConverter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case _: ArrayType => "scala.collection.Seq"
      case _: MapType => "scala.collection.Map"
      case _: StructType => "org.apache.spark.sql.Row"
      case _ => "String"
    }
  }
}
```
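The default converter maps `ArrayType` and `MapType` to the bare `scala.collection.Seq` and `scala.collection.Map`, which are not valid field types once the generated case class is compiled. A minimal sketch of a recursive alternative follows, assuming Spark 3.x on Scala 2.13 (the `//> using` lines are scala-cli directives, and the names `RecursiveTypeConverter` and `wrap` are hypothetical, not part of the gist). It reads the element, key, and value types and wraps nullable elements in `Option`:

```scala
//> using scala "2.13"
//> using dep "org.apache.spark::spark-sql:3.5.1"
import org.apache.spark.sql.types._

// Hypothetical recursive converter: fills in type parameters for arrays and
// maps instead of emitting bare Seq / Map.
object RecursiveTypeConverter {
  def convert(t: DataType): String = t match {
    case _: ByteType      => "Byte"
    case _: ShortType     => "Short"
    case _: IntegerType   => "Int"
    case _: LongType      => "Long"
    case _: FloatType     => "Float"
    case _: DoubleType    => "Double"
    case _: DecimalType   => "java.math.BigDecimal"
    case _: StringType    => "String"
    case _: BinaryType    => "Array[Byte]"
    case _: BooleanType   => "Boolean"
    case _: TimestampType => "java.sql.Timestamp"
    case _: DateType      => "java.sql.Date"
    // Recurse into the element type; nullable elements become Option[...].
    case ArrayType(elem, containsNull) =>
      s"Seq[${wrap(convert(elem), containsNull)}]"
    // Recurse into key and value types; only values can be null in Spark.
    case MapType(key, value, valueContainsNull) =>
      s"Map[${convert(key)}, ${wrap(convert(value), valueContainsNull)}]"
    // Nested structs are still left as Row in this sketch.
    case _: StructType    => "org.apache.spark.sql.Row"
    case _                => "String"
  }

  private def wrap(tpe: String, nullable: Boolean): String =
    if (nullable) s"Option[$tpe]" else tpe
}
```

Because `TypeConverter` is just `DataType => String`, it could be passed explicitly, e.g. `s2cc.schemaToCaseClass(df.schema, "MyClass")(RecursiveTypeConverter.convert)`. Whether your Spark version's encoders accept `Option` inside collections is worth checking before relying on the wrapped element types.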
@zpappa

zpappa commented Apr 26, 2020

Perhaps I did something wrong here, but I was unable to get this working with the implicits import. For anyone who hits the same issue and gets the compiler error `identifier expected but 'implicit' found`, I had to pass the converter in explicitly, as below:

```scala
println(s2cc.schemaToCaseClass(schema, "MyclassName")(s2cc.implicits.defaultTypeConverter))
```

That said, super helpful, thanks!
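The error comes from the usage comment at the top of the gist, which says `import s2cc.implicit._`: `implicit` is a reserved word in Scala, and the nested object is actually named `implicits`, so `import s2cc.implicits._` should let the compiler find the converter without passing it by hand. The sketch below reproduces the same shape with a hypothetical, stripped-down class `MiniGen` (assuming Spark 3.x on Scala 2.13) and shows both forms resolving to the same converter:

```scala
//> using scala "2.13"
//> using dep "org.apache.spark::spark-sql:3.5.1"
import org.apache.spark.sql.types._

// Hypothetical stand-in with the same structure as Schema2CaseClass:
// a type alias, a method taking it implicitly, and a nested `implicits` object.
class MiniGen {
  type TypeConverter = DataType => String

  def render(t: DataType)(implicit tc: TypeConverter): String = tc(t)

  object implicits {
    implicit val defaultTypeConverter: TypeConverter = {
      case _: IntegerType => "Int"
      case _              => "String"
    }
  }
}

object ImportDemo {
  val gen = new MiniGen
  // `import gen.implicit._` would not parse: `implicit` is a keyword.
  import gen.implicits._

  // Converter found implicitly through the import above.
  def viaImport: String = gen.render(IntegerType)

  // Converter passed explicitly, as in the workaround above.
  def viaExplicit: String = gen.render(IntegerType)(gen.implicits.defaultTypeConverter)
}
```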

@JituS

JituS commented Apr 13, 2021

Thanks for sharing this. A small enhancement would be to handle nested StructTypes in a schema. I have tried incorporating that scenario below:

```scala
import java.io.FileWriter

import org.apache.spark.sql.types._

class SchemaToCaseClassWriter(fileWriter: FileWriter) {
  type TypeConverter = DataType => String

  def write(schema: StructType, className: String): Unit = {
    run(schema, className)
    fileWriter.close()
  }

  private def run(schema: StructType, className: String): Unit = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${field.name}:Option[$dataType]"
        case _ => s"  ${field.name}:$dataType"
      }
    }

    val fieldsStr = schema.map(genField).mkString(",\n")
    val schemaClass =
      s"""case class $className (
         |$fieldsStr
         |)
         |
         |""".stripMargin
    fileWriter.write(schemaClass)
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      // Recurse into the element type so arrays get a typed Seq.
      case ArrayType(elementType, _) =>
        s"Seq[${defaultTypeConverter(colName)(elementType)}]"
      case _: MapType => "scala.collection.Map"
      // Write the nested class first, then refer to it by its capitalized column name.
      case t: StructType =>
        run(t, colName.capitalize)
        colName.capitalize
      case _ => "String"
    }
    converter
  }
}
```

@maxmithun

If a schema with nested structure has structs with the same name at different levels, two classes with the same name will be generated, which breaks the code when it is used. I think we need to use a package name to handle that. Are there any other alternatives?
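One alternative to packages, sketched here under the assumption of Spark 3.x on Scala 2.13, is to name each nested class after its full path from the root (for example `Order_Address` and `Order_Customer_Address`), so structs that share a field name at different levels get distinct class names. `PathNamedGenerator` and its helpers are hypothetical names, and the type mapping is trimmed to a few primitives to keep the focus on naming:

```scala
//> using scala "2.13"
//> using dep "org.apache.spark::spark-sql:3.5.1"
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer

// Hypothetical generator: nested case classes are named by their path from
// the root class, so same-named structs at different levels do not collide.
object PathNamedGenerator {
  def generate(schema: StructType, className: String): String = {
    val out = ListBuffer.empty[String]
    emit(schema, className, out)
    out.mkString("\n")
  }

  private def emit(schema: StructType, className: String, out: ListBuffer[String]): Unit = {
    val fields = schema.fields.map { f =>
      // A nested struct under this field is named <ParentClass>_<FieldName>.
      val tpe = typeOf(f.dataType, s"${className}_${f.name.capitalize}", out)
      val full = if (f.nullable) s"Option[$tpe]" else tpe
      s"  ${f.name}: $full"
    }
    // typeOf has already appended any nested classes, so they precede this one.
    out += s"case class $className(\n${fields.mkString(",\n")}\n)"
  }

  private def typeOf(t: DataType, nestedName: String, out: ListBuffer[String]): String =
    t match {
      case s: StructType =>
        emit(s, nestedName, out)
        nestedName
      case ArrayType(elem, _) => s"Seq[${typeOf(elem, nestedName, out)}]"
      case _: IntegerType     => "Int"
      case _: LongType        => "Long"
      case _: DoubleType      => "Double"
      case _: BooleanType     => "Boolean"
      case _                  => "String" // simplified fallback for this sketch
    }
}
```

The trade-off is longer class names; structurally identical structs still produce separate classes unless an extra pass compares and merges them.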

@srimunugoti

How can I use the resulting string as a case class? Could someone share an example, please?
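The generator only produces source text, so the class has to be compiled before Spark can use it. One common approach, sketched here with a hypothetical helper `GeneratedSourceWriter` and an assumed package name `generated`, is to write the string into the project's source tree (for example under `src/main/scala`) and rebuild:

```scala
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path}

// Hypothetical helper: saves a generated case class as a .scala source file
// under <srcRoot>/generated so the next build compiles it.
object GeneratedSourceWriter {
  def writeSource(generated: String, className: String, srcRoot: Path): Path = {
    val dir = srcRoot.resolve("generated")
    Files.createDirectories(dir)
    val file = dir.resolve(s"$className.scala")
    // Prepend an (assumed) package declaration to the generated text.
    val source = s"package generated\n\n${generated.trim}\n"
    Files.write(file, source.getBytes(StandardCharsets.UTF_8))
  }
}
```

After the next build compiles it, the class can be used with the Dataset API, e.g. `import spark.implicits._` followed by `df.as[generated.MyClass]`. Keeping it a top-level class, rather than one defined inside a method, lets Spark derive its encoder.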

@7873737376

7873737376 commented Feb 8, 2024

We can use the code below directly: it returns the generated case class as a String instead of writing it to a file, so the result can be kept in a variable.

```scala
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer

class SchemaToCaseClassWriter {
  type TypeConverter = DataType => String

  // Every generated class is collected here, nested classes before their
  // parents, so nested definitions are not lost when only a String is returned.
  private val classes = ListBuffer.empty[String]

  def write(schema: StructType, className: String): String = {
    classes.clear()
    run(schema, className)
    classes.mkString
  }

  private def run(schema: StructType, className: String): Unit = {
    def genField(field: StructField): String = {
      val converter = defaultTypeConverter(field.name)
      val dataType = converter(field.dataType)
      field match {
        case x if x.nullable => s"  ${field.name}: Option[$dataType]"
        case _ => s"  ${field.name}: $dataType"
      }
    }

    // Generating the fields first appends any nested classes to `classes`.
    val fieldsStr = schema.map(genField).mkString(",\n")
    classes +=
      s"""case class $className (
         |$fieldsStr
         |)
         |
         |""".stripMargin
  }

  private def defaultTypeConverter(colName: String): TypeConverter = {
    val converter: TypeConverter = {
      case _: ByteType => "Byte"
      case _: ShortType => "Short"
      case _: IntegerType => "Int"
      case _: LongType => "Long"
      case _: FloatType => "Float"
      case _: DoubleType => "Double"
      case _: DecimalType => "java.math.BigDecimal"
      case _: StringType => "String"
      case _: BinaryType => "Array[Byte]"
      case _: BooleanType => "Boolean"
      case _: TimestampType => "java.sql.Timestamp"
      case _: DateType => "java.sql.Date"
      case ArrayType(elementType, _) =>
        s"Seq[${defaultTypeConverter(colName)(elementType)}]"
      case _: MapType => "scala.collection.Map"
      // Generate the nested class, then refer to it by name.
      case t: StructType =>
        run(t, colName.capitalize)
        colName.capitalize
      case _ => "String"
    }
    converter
  }
}

val writer = new SchemaToCaseClassWriter()
// Replace with your own StructType schema
val schema = StructType(Seq(
  StructField("id", IntegerType, nullable = false),
  StructField("name", StringType)
))
val className = "MyClass"
val caseClassString = writer.write(schema, className)
println(caseClassString) // Output the generated case class string
```
