Spark Scala Snippets
////////////////////////////////////////////////////////////////////////////////////////////////////
// Custom UDF
////////////////////////////////////////////////////////////////////////////////////////////////////
import org.apache.spark.sql.functions._
import spark.implicits._
// define the function (null is treated as non-numeric)
def isNumeric: (String => Boolean) = {
  case null => false
  case value => value.forall(Character.isDigit)
}
val ISNUMERIC = udf(isNumeric)
// register the UDF with Spark SQL
sqlContext.udf.register("ISNUMERIC", isNumeric)
// usage via Spark API/code:
import UDF._
df.withColumn("newColumn",
  when(ISNUMERIC(col("testColumn")), lit("testColumn is numeric"))
    .otherwise(lit("testColumn is not numeric"))
)
// usage via select expression
df.selectExpr("ISNUMERIC(testColumn) as newColumn").show()
// usage via callUDF
df.select(callUDF("ISNUMERIC", col("someColumn")).as("newColumn")).show()
// example of a UDF built from a curried function:
val addP = (p: Int) => udf( (x: Int) => x + p )
df.withColumn("col3", addP(100)($"col2")).show
// output:
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   1| 101|
|   B|   2| 102|
|   C|   3| 103|
+----+----+----+
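// Note: with Spark 2.x+ a UDF is more commonly registered through the SparkSession
// (a minimal sketch, assuming a SparkSession named `spark` is in scope):
spark.udf.register("ISNUMERIC", (s: String) => s != null && s.forall(Character.isDigit))
spark.sql("select ISNUMERIC('12345') as isNum").show()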
////////////////////////////////////////////////////////////////////////////////////////////////////
// Selecting a list of columns from a DataFrame
////////////////////////////////////////////////////////////////////////////////////////////////////
val columns = Seq("col1", "col2", "col3", "col4", "col5")
val result = df.select(columns.head, columns.tail: _*)
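// equivalent form using the Column-based select overload (no head/tail split needed):
val result2 = df.select(columns.map(col): _*)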
////////////////////////////////////////////////////////////////////////////////////////////////////
// listing registered functions (built-ins and UDFs)
////////////////////////////////////////////////////////////////////////////////////////////////////
spark.catalog.listFunctions.show(1000, false)
spark.catalog.listFunctions.filter(_.name.contains("GET")).show(1000, false)
////////////////////////////////////////////////////////////////////////////////////////////////////
// show tables
////////////////////////////////////////////////////////////////////////////////////////////////////
sqlContext.sql("show tables").show(1000, false)
////////////////////////////////////////////////////////////////////////////////////////////////////
// track time elapsed
////////////////////////////////////////////////////////////////////////////////////////////////////
val startTime = System.currentTimeMillis()
// additional code here
val endTime = System.currentTimeMillis()
val elapsedSeconds = (endTime - startTime) / 1000L
println(s"Code completed in ${elapsedSeconds} seconds")
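// the same pattern as a small reusable helper (a sketch; the name `timed` is arbitrary):
def timed[T](label: String)(block: => T): T = {
  val start = System.currentTimeMillis()
  val result = block  // run the code being measured
  println(s"$label completed in ${(System.currentTimeMillis() - start) / 1000L} seconds")
  result
}
// usage: val rowCount = timed("count") { df.count() }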
////////////////////////////////////////////////////////////////////////////////////////////////////
// inner join on two DataFrames (when the join columns have the same name)
////////////////////////////////////////////////////////////////////////////////////////////////////
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|col1
  1|col1a
  2|col1b
  3|col1c
  4|col1d
  """.stripMargin
)
val df2 = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|col2
  1|col2a
  2|col2b
  3|col2c
  """.stripMargin
)
df.join(df2, Seq("id"), "inner").orderBy(col("id")).show(100, false)
// output:
+---+-----+-----+
|id |col1 |col2 |
+---+-----+-----+
|1  |col1a|col2a|
|2  |col1b|col2b|
|3  |col1c|col2c|
+---+-----+-----+
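// SpecDataFactory.loadData is a project-specific test helper, not part of Spark.
// A minimal sketch of an equivalent that parses the pipe-delimited text above into a DataFrame
// (assumes all columns are strings and the literal "null" means SQL NULL):
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
def loadPipeDelimited(sqlContext: SQLContext, data: String): DataFrame = {
  val lines = data.split("\n").map(_.trim).filter(_.nonEmpty)
  val header = lines.head.split("\\|")
  val schema = StructType(header.map(name => StructField(name, StringType, nullable = true)))
  val rows = lines.tail.map { line =>
    Row(line.split("\\|", -1).map(v => if (v == "null") null else v): _*)
  }
  sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(rows), schema)
}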
////////////////////////////////////////////////////////////////////////////////////////////////////
// left join on a foreign key (when the join columns do not have the same name)
////////////////////////////////////////////////////////////////////////////////////////////////////
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id1|col1
  1|col1a
  2|col1b
  3|col1c
  4|col1d
  """.stripMargin
)
val df2 = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id2|col2
  1|col2a
  2|col2b
  3|col2c
  """.stripMargin
)
df.join(df2, df("id1") === df2("id2"), "left").orderBy(col("id1")).show(100, false)
// output:
+---+-----+----+-----+
|id1|col1 |id2 |col2 |
+---+-----+----+-----+
|1  |col1a|1   |col2a|
|2  |col1b|2   |col2b|
|3  |col1c|3   |col2c|
|4  |col1d|null|null |
+---+-----+----+-----+
////////////////////////////////////////////////////////////////////////////////////////////////////
// calculating age
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
val df = Seq(
  ("1980-09-08", "2018-04-03"),
  ("1980-09-08", "2018-09-07"),
  ("1980-09-08", "2018-09-08")
).toDF("birth", "current")
df.selectExpr(
  "floor(months_between(current, birth)/12) as age"
).show(100, false)
// output:
+---+
|age|
+---+
|37 |
|37 |
|38 |
+---+
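// the same calculation via the Column API instead of selectExpr:
df.select(floor(months_between(col("current"), col("birth")) / 12).as("age")).show(100, false)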
////////////////////////////////////////////////////////////////////////////////////////////////////
// JSON string to DataFrame
////////////////////////////////////////////////////////////////////////////////////////////////////
val jsonStr = """
{
  "something": {
    "key": "foo",
    "value": "bar"
  }
}
"""
val rdd = spark.sparkContext.parallelize(Seq(jsonStr))
val df = spark.sqlContext.read.json(rdd)
df.show(100, false)
// output
+---------+
|something|
+---------+
|[foo,bar]|
+---------+
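// read.json(RDD[String]) is deprecated as of Spark 2.2; the Dataset[String] overload is preferred:
import spark.implicits._
val dfFromJson = spark.read.json(Seq(jsonStr).toDS())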
////////////////////////////////////////////////////////////////////////////////////////////////////
// concatenate with string separator
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
Seq(
  ("some", "other", "thing"),
  ("foo", null, "bar")
)
  .toDF("col1", "col2", "col3")
  .withColumn(
    "concatenated",
    concat_ws("_",
      col("col1"),
      col("col2"),
      col("col3")
    )
  )
  .select(col("concatenated"))
  .show(100, false)
// output
+----------------+
|concatenated    |
+----------------+
|some_other_thing|
|foo_bar         |
+----------------+
////////////////////////////////////////////////////////////////////////////////////////////////////
// retrieving a value from a DataFrame column
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
val df = Seq(
  ("red"),
  ("blue")
)
  .toDF("color")
val firstColor = df.take(1)(0).getAs[String]("color")
////////////////////////////////////////////////////////////////////////////////////////////////////
// loading a DataFrame from CSV stored on S3
////////////////////////////////////////////////////////////////////////////////////////////////////
val s3aPath = "s3a://bucket/path/file.csv"
val df = sqlContext.read
  .options(
    Map(
      "header" -> "true",
      "delimiter" -> "|",
      "escape" -> "\\",
      "quote" -> "\"",
      "nullValue" -> ""  // treat empty strings as null (option values must be non-null strings)
    )
  )
  .csv(s3aPath)
////////////////////////////////////////////////////////////////////////////////////////////////////
// Adding days to a date
////////////////////////////////////////////////////////////////////////////////////////////////////
import java.util.Calendar
import java.text.SimpleDateFormat
val dateString = "2015-01-01"
val dateParsed = new SimpleDateFormat("yyyy-MM-dd").parse(dateString)
val cal = Calendar.getInstance()
cal.setTime(dateParsed)
cal.add(Calendar.DATE, 15)
println(cal.getTime())
// Fri Jan 16 00:00:00 EST 2015
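// alternatives: Spark SQL's date_add, or java.time instead of the legacy Calendar API:
spark.sql("select date_add('2015-01-01', 15) as datePlus15").show()
import java.time.LocalDate
val datePlus15 = LocalDate.parse("2015-01-01").plusDays(15)  // 2015-01-16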
////////////////////////////////////////////////////////////////////////////////////////////////////
// Extract text using regular expression
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
val df = List(
  ("12345"),
  ("67abcde"),
  ("fghij89"),
  (null)
).toDF("col1")
  .withColumn("digits", regexp_extract(col("col1"), "[0-9]+", 0))
  .show(1000, false)
// output:
+-------+------+
|col1   |digits|
+-------+------+
|12345  |12345 |
|67abcde|67    |
|fghij89|89    |
|null   |null  |
+-------+------+
////////////////////////////////////////////////////////////////////////////////////////////////////
// Join with null-safe equality operator
////////////////////////////////////////////////////////////////////////////////////////////////////
val df1 = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id1|col1
  1|a
  2|b
  3|c
  null|d
  """.stripMargin
)
val df2 = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id2|col2
  1|e
  2|f
  4|g
  null|h
  """.stripMargin
)
df1.join(df2, df1("id1") <=> df2("id2")).show(100, false)
// output:
+----+----+----+----+
|id1 |col1|id2 |col2|
+----+----+----+----+
|1   |a   |1   |e   |
|2   |b   |2   |f   |
|null|d   |null|h   |
+----+----+----+----+
////////////////////////////////////////////////////////////////////////////////////////////////////
// window, partitionBy, row_number example
////////////////////////////////////////////////////////////////////////////////////////////////////
import org.apache.spark.sql.expressions.Window
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|line_num|col1
  1|6|a
  1|5|b
  1|4|c
  2|3|d
  2|2|e
  2|1|f
  """.stripMargin
)
val w = Window.partitionBy(col("id")).orderBy("line_num")
df
  .withColumn("rn", row_number.over(w))
  .filter(col("rn").equalTo(1))
  .drop("rn")
  .show(100, false)
// output:
+---+--------+----+
|id |line_num|col1|
+---+--------+----+
|1  |4       |c   |
|2  |1       |f   |
+---+--------+----+
////////////////////////////////////////////////////////////////////////////////////////////////////
// window, partitionBy, row_number, min/max example
////////////////////////////////////////////////////////////////////////////////////////////////////
import org.apache.spark.sql.expressions.Window
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|col1
  1|6
  1|5
  1|4
  2|3
  2|2
  2|1
  """.stripMargin
)
val w = Window.partitionBy(col("id")).orderBy(col("id"))
df
  .withColumn("maxVal", max(col("col1")).over(w))
  .withColumn("minVal", min(col("col1")).over(w))
  .withColumn("rn", row_number.over(w))
  .filter(col("rn").equalTo(1))
  .drop("rn", "col1")
  .show(100, false)
// output:
+---+------+------+
|id |maxVal|minVal|
+---+------+------+
|1  |6     |4     |
|2  |3     |1     |
+---+------+------+
////////////////////////////////////////////////////////////////////////////////////////////////////
// Union a list of DataFrames
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
val df1 = Seq(
  (1, "a"),
  (2, "b")
).toDF("id", "col1")
val df2 = Seq(
  (3, "c"),
  (4, "d")
).toDF("id", "col1")
val df3 = Seq(
  (5, "e"),
  (6, "f")
).toDF("id", "col1")
val dfs = List(df1, df2, df3)
val unioned = dfs.reduce(_ union _)
unioned.show(100, false)
// output:
+---+----+
|id |col1|
+---+----+
|1  |a   |
|2  |b   |
|3  |c   |
|4  |d   |
|5  |e   |
|6  |f   |
+---+----+
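// Note: union resolves columns by position; when the DataFrames share column names
// but in a different order, Spark 2.3+ can match on name instead:
val unionedByName = dfs.reduce(_ unionByName _)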
////////////////////////////////////////////////////////////////////////////////////////////////////
// pivot example
////////////////////////////////////////////////////////////////////////////////////////////////////
import spark.implicits._
val df = Seq(
  ("1", "red"),
  ("2", "red"),
  ("3", "blue")
).toDF("member_id", "colors")
df.show(100)
df
  .groupBy("member_id")
  .pivot("colors")
  .agg(
    first("colors")
  )
  .orderBy("member_id")
  .show(100, false)
// before:
+---------+------+
|member_id|colors|
+---------+------+
|1        |red   |
|2        |red   |
|3        |blue  |
+---------+------+
// after:
+---------+----+----+
|member_id|blue|red |
+---------+----+----+
|1        |null|red |
|2        |null|red |
|3        |blue|null|
+---------+----+----+
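// supplying the pivot values up front skips the extra pass Spark needs to discover them
// and fixes the output column order:
df
  .groupBy("member_id")
  .pivot("colors", Seq("blue", "red"))
  .agg(first("colors"))
  .orderBy("member_id")
  .show(100, false)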
////////////////////////////////////////////////////////////////////////////////////////////////////
// unpivot example
////////////////////////////////////////////////////////////////////////////////////////////////////
import com.arcadia.notebook.sparkext.Implicits._
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|alpha1|alpha2|alpha3|beta1|beta2|gamma1|gamma2|gamma3
  1|A1|A2|A3|B1|B2|C1|C2|C3
  2|A4|A5|A6|B3|B4|C4|C5|C6
  """.stripMargin
)
df.unpivot(
  List("aVal", "bVal", "cVal"),
  List("aName", "bName", "cName"),
  List(
    List("alpha1", "alpha2", "alpha3"),
    List("beta1", "beta2"),
    List("gamma1", "gamma2", "gamma3")
  )
).show(100, false)
// output:
+---+------+----+-----+----+------+----+
|id |aName |aVal|bName|bVal|cName |cVal|
+---+------+----+-----+----+------+----+
|1  |alpha1|A1  |beta1|B1  |gamma1|C1  |
|1  |alpha2|A2  |beta2|B2  |gamma2|C2  |
|1  |alpha3|A3  |     |    |gamma3|C3  |
|2  |alpha1|A4  |beta1|B3  |gamma1|C4  |
|2  |alpha2|A5  |beta2|B4  |gamma2|C5  |
|2  |alpha3|A6  |     |    |gamma3|C6  |
+---+------+----+-----+----+------+----+
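// `unpivot` above comes from a project-specific implicit extension (com.arcadia.notebook.sparkext),
// not stock Spark. With plain Spark SQL, a single column group can be unpivoted via stack():
df.selectExpr(
  "id",
  "stack(3, 'alpha1', alpha1, 'alpha2', alpha2, 'alpha3', alpha3) as (aName, aVal)"
).show(100, false)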
////////////////////////////////////////////////////////////////////////////////////////////////////
// Some/None/Option example
////////////////////////////////////////////////////////////////////////////////////////////////////
def someMethod(entityName: Option[String] = None): Unit = {
  entityName match {
    case Some(name) => println(s"entityName: ${name}")
    case None => println("entityName: None")
    // note: Some/None cover every Option, so a trailing `case _` would be unreachable
  }
}
someMethod(Some("thing"))
someMethod(None)
someMethod()
// output:
entityName: thing
entityName: None
entityName: None
////////////////////////////////////////////////////////////////////////////////////////////////////
// empty/null coalesce method examples
////////////////////////////////////////////////////////////////////////////////////////////////////
import org.apache.spark.sql.Column
// returns `in` when it is non-null and non-empty, otherwise `alternative`
def emptyCoalesce(in: Column, alternative: Column): Column = {
  when(in.isNotNull && length(in) =!= 0, in).otherwise(alternative)
}
// coalesce that also treats blank (empty after trim) values as null
def coalesceWithEmpty(columns: Column*): Column = {
  val trimmedColumns = columns.map(column =>
    column match {
      case null => null
      case _ => {
        val trimmed = trim(column)
        when(length(trimmed).eqNullSafe(0), null).otherwise(trimmed)
      }
    }
  )
  coalesce(trimmedColumns: _*)
}
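// usage sketch (the column names here are only illustrative):
df.withColumn("displayName", coalesceWithEmpty(col("nickname"), col("firstName"), col("lastName")))
  .withColumn("safeCol1", emptyCoalesce(col("col1"), lit("unknown")))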
////////////////////////////////////////////////////////////////////////////////////////////////////
// groupBy/agg example
////////////////////////////////////////////////////////////////////////////////////////////////////
val df = SpecDataFactory.loadData(
  spark.sqlContext,
  """
  id|col1
  1|1
  1|2
  1|3
  2|4
  2|5
  2|6
  3|7
  3|8
  """.stripMargin
)
df
  .groupBy("id")
  .agg(
    count("*").as("countX"),
    sum("col1").as("sumCol1"),
    min("col1").as("minCol1"),
    max("col1").as("maxCol1"),
    collect_list("col1").as("listCol1")
  )
  .orderBy("id")
  .show(100, false)
// output:
+---+------+-------+-------+-------+---------+
|id |countX|sumCol1|minCol1|maxCol1|listCol1 |
+---+------+-------+-------+-------+---------+
|1  |3     |6      |1      |3      |[1, 2, 3]|
|2  |3     |15     |4      |6      |[4, 5, 6]|
|3  |2     |15     |7      |8      |[7, 8]   |
+---+------+-------+-------+-------+---------+
////////////////////////////////////////////////////////////////////////////////////////////////////
// output Spark execution plan / RDD lineage
////////////////////////////////////////////////////////////////////////////////////////////////////
df.explain(true)               // logical and physical query plans
println(df.rdd.toDebugString)  // RDD lineage (not the SQL query plan)
////////////////////////////////////////////////////////////////////////////////////////////////////
// create or replace temp view
////////////////////////////////////////////////////////////////////////////////////////////////////
df.createOrReplaceTempView("tableName")
////////////////////////////////////////////////////////////////////////////////////////////////////
// Writing a DataFrame to CSV in S3
////////////////////////////////////////////////////////////////////////////////////////////////////
val defaultWriteOptions: Map[String, String] =
  Map[String, String](
    "header" -> "true",
    "compression" -> "gzip",
    "quoteAll" -> "true",
    "escape" -> "\"")
df
  .write
  .mode("Overwrite")
  .options(defaultWriteOptions)
  .csv(fileName.replaceAllLiterally("s3://", "s3a://"))
// Examples using CommonNotebookApiImpl:
CommonNotebookApiImpl.writeDataFrameToFile("s3a://bucket/path/name.csv", df)
CommonNotebookApiImpl.writeDataFrameToFile("s3a://bucket/path/name.csv", df, options, maxWritePartitions)
// NOTE: setting maxWritePartitions to 1 writes a single CSV file in S3, but funnels all data
// through one task, which hurts write parallelism and performance on large outputs.
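// CommonNotebookApiImpl is project-specific; with plain Spark the output file count can be
// controlled by repartitioning before the write (coalesce(1) yields a single part file):
df
  .coalesce(1)
  .write
  .mode("Overwrite")
  .options(defaultWriteOptions)
  .csv("s3a://bucket/path/name.csv")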
////////////////////////////////////////////////////////////////////////////////////////////////////
// Using foldLeft to iterate and apply a column method on a DataFrame
////////////////////////////////////////////////////////////////////////////////////////////////////
val domainColumns = df.columns.filter(_.startsWith("domain__"))
val updatedDf = domainColumns.foldLeft(df)({ case (df2, domainColumn) => df2.drop(domainColumn) })
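// the same foldLeft pattern works for any per-column transformation, e.g. renaming
// (the "x_" prefix is only illustrative):
val renamedDf = domainColumns.foldLeft(df)((acc, c) => acc.withColumnRenamed(c, s"x_$c"))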
//////////////////////////////////////////////////////////////////////////////////////////////////// |