Spark cumulative sum with gaps
import spark.implicits._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS
// Set the shuffle partition count based on data size and distribution so work
// spreads evenly across executors; the default is 200.
spark.conf.set(SHUFFLE_PARTITIONS.key, 100)
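// On Spark 3.x, adaptive query execution can coalesce shuffle partitions automatically
// after each stage, reducing the need to hand-tune the value above; a sketch, assuming
// Spark >= 3.0 is available:
// spark.conf.set("spark.sql.adaptive.enabled", "true")
// spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")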
val all_dates = (1 to 5).toDF("date")
val all_events = (
  List.fill(5)(("gmail", 1)) ++ List.fill(3)(("gmail", 3)) ++ List.fill(1)(("gmail", 5)) ++
  List.fill(2)(("yahoo", 1)) ++ List.fill(3)(("yahoo", 5)) ++
  List.fill(1)(("outlook", 5))
).toDF("key", "date")
// Popular keys like "gmail" may need salting so their rows spread evenly across executors.
// Don't salt blindly: run a one-off pre-analysis first, e.g. flag keys that occur more than
// a million times, and size the bucket count so each salted sub-key ends up roughly equal.
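// A minimal sketch of that one-off pre-analysis, assuming a threshold of one million
// occurrences marks a key as skewed (SKEW_THRESHOLD is a placeholder; tune it to real data):
val SKEW_THRESHOLD = 1000000L
val skewedKeyList = all_events.groupBy("key").count().
  filter(col("count") > SKEW_THRESHOLD).
  select("key").as[String].collect()
val skewed_keys_dynamic = array(skewedKeyList.map(lit): _*) // drop-in for the hard-coded list below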
// 10 salt buckets for the skewed keys; everything else gets salt 0 (a single bucket)
val skewed_keys = array(lit("gmail"), lit("yahoo"))
val salted_events = all_events.
  withColumn("salt",
    when(array_contains(skewed_keys, col("key")), (rand * 10).cast(IntegerType)).
      otherwise(lit(0))).
  withColumn("key_salt", concat(col("key"), lit("__"), col("salt"))).
  drop("salt", "key")
/*
+----+----------+
|date| key_salt|
+----+----------+
| 1| gmail__7|
| 1| gmail__8|
| 1| gmail__0|
| 1| gmail__5|
| 1| gmail__0|
| 3| gmail__6|
| 3| gmail__2|
| 3| gmail__6|
| 5| gmail__7|
| 1| yahoo__0|
| 1| yahoo__4|
| 5| yahoo__8|
| 5| yahoo__1|
| 5| yahoo__8|
| 5|outlook__0|
+----+----------+
*/
val date_key_count =
  salted_events.groupBy("date", "key_salt").
    count().
    // recover the original key by stripping the salt suffix
    withColumn("key", element_at(split(col("key_salt"), "__"), 1)).
    // repartition down to fewer tasks now that the salt has evened out the skew;
    // 10 is a placeholder partition count, tune it to the data
    repartition(10, col("date"), col("key")).
    groupBy("date", "key").agg(sum("count").as("count"))
/*
+----+-------+-----+
|date| key|count|
+----+-------+-----+
| 5|outlook| 1|
| 5| yahoo| 3|
| 3| gmail| 3|
| 1| gmail| 5|
| 5| gmail| 1|
| 1| yahoo| 2|
+----+-------+-----+
*/
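// Optional sanity check (a sketch): salting then de-salting must preserve the total number
// of events, so the aggregated counts should add up to the raw event-row count.
assert(date_key_count.agg(sum("count")).as[Long].head == all_events.count())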
val all_uniq_keys = date_key_count.select("key").distinct
val all_date_key = all_dates.crossJoin(all_uniq_keys)
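// all_uniq_keys is tiny, so hinting a broadcast keeps the cross join cheap; a sketch
// (broadcast comes from org.apache.spark.sql.functions, already imported above):
// val all_date_key = all_dates.crossJoin(broadcast(all_uniq_keys))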
val all_date_key_count = all_date_key.join(date_key_count, Seq("date", "key"), "left")
/*
+----+-------+-----+
|date| key|count|
+----+-------+-----+
| 1| gmail| 5|
| 2| gmail| null|
| 3| gmail| 3|
| 4| gmail| null|
| 5| gmail| 1|
| 1|outlook| null|
| 2|outlook| null|
| 3|outlook| null|
| 4|outlook| null|
| 5|outlook| 1|
| 1| yahoo| 2|
| 2| yahoo| null|
| 3| yahoo| null|
| 4| yahoo| null|
| 5| yahoo| 3|
+----+-------+-----+
*/
// Running total per key ordered by date; sum() skips null counts, so gap days simply
// carry the previous cumulative value forward.
val window = Window.partitionBy("key").orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
val all_date_key_count_cum_count = all_date_key_count.withColumn("cum_count", sum("count").over(window))
/*
+----+-------+-----+---------+
|date| key|count|cum_count|
+----+-------+-----+---------+
| 1| gmail| 5| 5|
| 2| gmail| null| 5|
| 3| gmail| 3| 8|
| 4| gmail| null| 8|
| 5| gmail| 1| 9|
| 1|outlook| null| null|
| 2|outlook| null| null|
| 3|outlook| null| null|
| 4|outlook| null| null|
| 5|outlook| 1| 1|
| 1| yahoo| 2| 2|
| 2| yahoo| null| 2|
| 3| yahoo| null| 2|
| 4| yahoo| null| 2|
| 5| yahoo| 3| 5|
+----+-------+-----+---------+
*/
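// Optional variant (a sketch): coalesce missing daily counts to 0 so cum_count shows 0
// instead of null on days before a key's first event (e.g. outlook on days 1 to 4 above).
val all_date_key_count_cum_zero = all_date_key_count.
  withColumn("cum_count", sum(coalesce(col("count"), lit(0))).over(window))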