import org.apache.spark.sql.functions.udf
import spark.implicits._

val df = sc.parallelize(Seq(
  (1, "John Doe", 21),
  (2, "Jane Doe", 21),
  (3, "Timbit DeSimoneCarolyn", 21),
  (1, "John Doe", 20))).toDF("id", "name", "age")

// Accumulator so we can count how many times the UDF body actually runs.
val dcount = sc.longAccumulator
val fun = udf((a: Int) => {
  dcount.add(1)
  ""
})
val dudf = spark.udf.register("udf", fun)
// "Sad" double eval path
val initial_filtered = df.filter($"age" >= 21)
val with_udf_column = initial_filtered.withColumn("magic", dudf($"id"))
val final_filtered = with_udf_column.filter($"magic".isNotNull)
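// Optional sanity check: explain(true) prints the parsed, analyzed, optimized, and
// physical plans. After predicate pushdown the UDF call should appear in both the
// Filter and the Project, which is where the double evaluation comes from
// (exact plan shape varies by Spark version).
final_filtered.explain(true)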
// We collect() instead of count() because count() does not need the projected
// columns, so it would not trigger the second evaluation.
final_filtered.collect()
dcount.value // Returns 6, meaning the UDF was evaluated twice for each of the 3 matching
// records. Here that's fine, but if dudf held expensive business logic, or an error
// (say from a PCRE) bubbled up out of it, this double evaluation would be bad.
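// One possible workaround (a sketch; behavior varies by Spark version): marking the
// UDF non-deterministic stops Catalyst from pushing the isNotNull filter through the
// projection, so the UDF body should only run once per surviving row.
val ndcount = sc.longAccumulator
val ndfun = udf((a: Int) => {
  ndcount.add(1)
  ""
}).asNondeterministic()
df.filter($"age" >= 21)
  .withColumn("magic", ndfun($"id"))
  .filter($"magic".isNotNull)
  .collect()
ndcount.value // Expect 3 here, i.e. one evaluation per matching record.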