Spark utils to ship data
import java.nio.charset.{Charset, StandardCharsets}

import org.apache.spark.sql._
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.types._

object SparkDataLoad {
  /** Loads CSV files at the given paths into a DataFrame, using the schema derived from `A`. */
  def fromCsv[A : Encoder](
    path: Set[String],
    encoding: Charset = StandardCharsets.UTF_8,
    useHeader: Boolean = false,
    delimiter: Char = '|',
    quote: Char = '"',
    escape: Char = '\\',
    skipLinesStartingWith: Option[Char] = None,
    dateFormat: String = "yyyyMMdd",
    timestampFormat: String = "yyyy-MM-dd'T'HH:mm:ss.SSSXXX",
    representEmptyValueAs: String = "",
    treatAsNull: String = "",
    treatAsNaN: String = "NaN",
    treatAsPositiveInf: String = "Inf",
    treatAsNegativeInf: String = "-Inf",
    ignoreLeadingWhiteSpace: Boolean = true,
    ignoreTrailingWhiteSpace: Boolean = true,
    inputFileNameColumn: String = "_source_file"
  )(implicit spark: SparkSession): DataFrame = {
    spark.read
      .option("mode", "PERMISSIVE") // keep malformed rows instead of failing the whole job
      .option("encoding", encoding.name())
      .option("header", useHeader)
      .option("delimiter", delimiter.toString)
      .option("quote", quote.toString)
      .option("escape", escape.toString)
      .option("dateFormat", dateFormat)
      .option("timestampFormat", timestampFormat)
      .option("emptyValue", representEmptyValueAs)
      .option("nullValue", treatAsNull)
      .option("nanValue", treatAsNaN)
      .option("positiveInf", treatAsPositiveInf)
      .option("negativeInf", treatAsNegativeInf)
      .option("comment", skipLinesStartingWith.map(_.toString).orNull)
      .option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
      .option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
      .schema(implicitly[Encoder[A]].schema) // derive the schema from the case class A
      .csv(path.toSeq: _*)
      .withColumn(inputFileNameColumn, input_file_name()) // tag each row with its source file
  }
  /** Reads a table (or the result of a SELECT statement) from Snowflake into a DataFrame. */
  def readFromSnowflake(
    account: String = "*****.us-east-1.snowflakecomputing.com",
    username: String = "dev",
    password: String = "***************",
    warehouse: String = "dev",
    database: String = "dev",
    table: String // can be either a SELECT statement or a table name
  )(implicit spark: SparkSession): DataFrame =
    spark.read
      .format("net.snowflake.spark.snowflake")
      .options(
        Map(
          "sfUrl" -> account,
          "sfUser" -> username,
          "sfPassword" -> password,
          "sfDatabase" -> database,
          "sfWarehouse" -> warehouse,
          // the connector expects "query" for SQL statements and "dbtable" for plain table names
          (if (table.toUpperCase.contains("SELECT ")) "query" else "dbtable") -> table
        )
      )
      .load()
  /**
    * Writes a Dataset to Snowflake via a stage-and-swap: the rows are loaded into a
    * transient staging table, which is then renamed over the target table.
    */
  def toSnowflake(
    account: String = "*****.us-east-1.snowflakecomputing.com",
    username: String = "dev",
    password: String = "***************",
    warehouse: String = "dev",
    database: String = "dev",
    schema: String,
    table: String,
    clusterBy: Seq[String] = Nil,
    dataset: Dataset[_],
    isAppend: Boolean = false
  ): Unit = {
    // Maps a Spark column to the corresponding Snowflake column DDL
    def toSnowflakeColumn(field: StructField): String = {
      val col = field.dataType match {
        case _: BooleanType => "BOOLEAN"
        case _: ByteType | _: ShortType | _: IntegerType | _: LongType => "INTEGER"
        case _: DecimalType | _: FloatType | _: DoubleType => "REAL"
        case _: DateType => "DATE"
        case _: TimestampType => "TIMESTAMP_TZ"
        case _: StringType | _: VarcharType => "TEXT"
        case _: ArrayType => "ARRAY"
        case _ => throw new UnsupportedOperationException(s"Unsupported field = $field")
      }
      s"${field.name.toLowerCase} ${if (field.nullable) col else s"$col NOT NULL"}"
    }

    val tempTable = s"${table}_stage"
    val clusterStmt = if (clusterBy.isEmpty) "" else clusterBy.mkString(" CLUSTER BY(", ", ", ")")
    val createTable = dataset.schema.fields
      .map(toSnowflakeColumn)
      .mkString(s"CREATE OR REPLACE TRANSIENT TABLE $schema.$tempTable(\n\t", ",\n\t", s") $clusterStmt")
    val preActions = Seq(
      s"USE DATABASE $database",
      s"USE WAREHOUSE $warehouse",
      s"CREATE SCHEMA IF NOT EXISTS $schema",
      s"USE SCHEMA $schema",
      createTable
    )
    val postActions = Seq(
      s"DROP TABLE IF EXISTS $schema.$table",
      s"ALTER TABLE $schema.$tempTable RENAME TO $table"
    )
    println(((preActions :+ s"COPY DATAFRAME TO $schema.$tempTable") ++ postActions).mkString("", ";\n\n", ";"))
    dataset
      .write
      .format("snowflake")
      .options(Map(
        "sfUrl" -> account,
        "sfUser" -> username,
        "sfPassword" -> password,
        "sfDatabase" -> database,
        "sfWarehouse" -> warehouse,
        "dbtable" -> s"$schema.$tempTable", // load into the staging table; postactions swap it into place
        "preactions" -> preActions.mkString("", ";", ";"),
        "postactions" -> postActions.mkString("", ";", ";")
      ))
      .mode(if (isAppend) SaveMode.Append else SaveMode.Overwrite)
      .save()
  }
}
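
For reference, here is a minimal usage sketch for fromCsv. The Trade case class, the file path, and the session setup are made-up examples, not part of the gist; the Encoder[Trade] that fromCsv needs is derived via spark.implicits._:

import java.sql.Date
import org.apache.spark.sql.SparkSession

case class Trade(id: Long, symbol: String, price: Double, tradedOn: Date)

implicit val spark: SparkSession = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._ // derives Encoder[Trade], which supplies the CSV schema

// Read pipe-delimited files (the default delimiter above) that carry a header row
val trades = SparkDataLoad.fromCsv[Trade](
  path = Set("/data/trades/*.csv"),
  useHeader = true
)
trades.show() // includes the extra _source_file column added by fromCsv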
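
readFromSnowflake picks the connector's "query" option when the argument contains "SELECT " and "dbtable" otherwise, so both calls below work. This sketch reuses the implicit SparkSession from the previous example; the table and query are hypothetical:

// Read a whole table (routed to the "dbtable" option)
val allTrades = SparkDataLoad.readFromSnowflake(table = "trades")

// Read the result of a query (routed to the "query" option)
val recentTrades = SparkDataLoad.readFromSnowflake(
  table = "SELECT * FROM trades WHERE traded_on >= '2020-01-01'"
)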
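
Finally, a sketch of the stage-and-swap write. toSnowflake loads the rows into a transient staging table named {table}_stage created by the preactions, then the postactions drop the old table and rename the stage into place, so readers never observe a half-loaded table. The schema, table, and column names here are invented:

SparkDataLoad.toSnowflake(
  schema = "analytics",
  table = "daily_trades",
  clusterBy = Seq("traded_on"), // becomes CLUSTER BY(traded_on) on the staging table
  dataset = recentTrades
)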