Last active
April 5, 2022 03:05
-
-
Save saswata-dutta/bdc24a95b18f58fef809e26af4e198a0 to your computer and use it in GitHub Desktop.
Spark cumulative sum with gaps
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import spark.implicits._ | |
import org.apache.spark.sql.expressions.Window | |
import org.apache.spark.sql.functions._ | |
import org.apache.spark.sql.types.IntegerType | |
import org.apache.spark.sql.internal.SQLConf.SHUFFLE_PARTITIONS | |
// Partition count Spark SQL uses for shuffle stages (e.g. the groupBy/join steps below).
// Tune it to the data size and key distribution so work spreads evenly across executors;
// Spark's default is 200. The public config key is used rather than the internal
// org.apache.spark.sql.internal.SQLConf constant, which is not a stable API.
spark.conf.set("spark.sql.shuffle.partitions", 100)
// Dense calendar: every date (1..5) that should receive a cumulative count, even with no events.
val all_dates = Seq.range(1, 6).toDF("date")
// Sample events as (key, date) rows; each (key, date, copies) spec expands to `copies` identical rows.
val all_events = Seq(
  ("gmail", 1, 5), ("gmail", 3, 3), ("gmail", 5, 1),
  ("yahoo", 1, 2), ("yahoo", 5, 3),
  ("outlook", 5, 1)
).flatMap { case (k, d, copies) => Seq.fill(copies)((k, d)) }.toDF("key", "date")
// Salting: spread the rows of very popular keys (gmail and yahoo here) over several synthetic
// sub-keys so the first aggregation is not dominated by a single huge task.
// Don't salt blindly: run a one-off analysis first (e.g. find keys with more than a million rows)
// and choose the salted keys and bucket count from it, so the salted sub-keys end up roughly equal in size.
val skewed_keys = array(lit("gmail"), lit("yahoo"))
// Number of salt buckets per skewed key; tune together with the skew analysis above.
val salt_buckets = 10
val salted_events = all_events.withColumn("salt",
    // rand() is uniform in [0, 1), so skewed keys get an integer salt in [0, salt_buckets);
    // all other keys get the single salt 0.
    when(array_contains(skewed_keys, col("key")), (rand() * salt_buckets).cast(IntegerType)).
      otherwise(lit(0))
  ).withColumn("key_salt", concat(col("key"), lit("__"), col("salt"))).drop("salt", "key")
/*
Example output (salt suffixes are random and differ between runs):
+----+----------+
|date|  key_salt|
+----+----------+
|   1|  gmail__7|
|   1|  gmail__8|
|   1|  gmail__0|
|   1|  gmail__5|
|   1|  gmail__0|
|   3|  gmail__6|
|   3|  gmail__2|
|   3|  gmail__6|
|   5|  gmail__7|
|   1|  yahoo__0|
|   1|  yahoo__4|
|   5|  yahoo__8|
|   5|  yahoo__1|
|   5|  yahoo__8|
|   5|outlook__0|
+----+----------+
*/
// Partition count for the second (de-salting) aggregation. Placeholder value to tune: its input is
// already pre-aggregated, so it is typically much smaller than spark.sql.shuffle.partitions.
val agg_partitions = 10
// Two-stage count: first count per (date, salted key) so skewed keys are spread over many tasks,
// then strip the salt and add the partial counts back up per (date, key).
val date_key_count =
  salted_events.groupBy("date", "key_salt").
    count().
    // Remove only the trailing "__<salt>" suffix so keys that themselves contain "__" stay intact.
    withColumn("key", regexp_replace(col("key_salt"), "__\\d+$", "")).
    repartition(agg_partitions, col("date"), col("key")). // fewer tasks for the small second stage
    groupBy("date", "key").agg(sum("count").as("count"))
/*
Example output (row order is not guaranteed):
+----+-------+-----+
|date|    key|count|
+----+-------+-----+
|   5|outlook|    1|
|   5|  yahoo|    3|
|   3|  gmail|    3|
|   1|  gmail|    5|
|   5|  gmail|    1|
|   1|  yahoo|    2|
+----+-------+-----+
*/
// Fill the gaps: build every (date, key) combination (|dates| x |keys| rows) and left-join the
// observed counts onto it, so a date with no events for a key shows up with a null count.
// Only dates present in all_dates are kept; counts for any other date are dropped by the left join.
val all_uniq_keys = date_key_count.select("key").distinct()
val all_date_key = all_dates.crossJoin(all_uniq_keys)
val all_date_key_count = all_date_key.join(date_key_count, Seq("date", "key"), "left")
/*
Example output (shown sorted by key, date for readability; Spark does not guarantee row order):
+----+-------+-----+
|date|    key|count|
+----+-------+-----+
|   1|  gmail|    5|
|   2|  gmail| null|
|   3|  gmail|    3|
|   4|  gmail| null|
|   5|  gmail|    1|
|   1|outlook| null|
|   2|outlook| null|
|   3|outlook| null|
|   4|outlook| null|
|   5|outlook|    1|
|   1|  yahoo|    2|
|   2|  yahoo| null|
|   3|  yahoo| null|
|   4|  yahoo| null|
|   5|  yahoo|    3|
+----+-------+-----+
*/
// Running total per key in date order: the frame spans from the key's first date to the current row.
val window = Window.partitionBy("key").orderBy("date").
  rowsBetween(Window.unboundedPreceding, Window.currentRow)
// sum() skips nulls but returns null when every row in the frame is null (a key's leading gap dates);
// coalesce that to 0 so the cumulative count is defined on every date.
val all_date_key_count_cum_count = all_date_key_count.withColumn("cum_count",
  coalesce(sum("count").over(window), lit(0L)))
/*
Example output (shown sorted by key, date for readability):
+----+-------+-----+---------+
|date|    key|count|cum_count|
+----+-------+-----+---------+
|   1|  gmail|    5|        5|
|   2|  gmail| null|        5|
|   3|  gmail|    3|        8|
|   4|  gmail| null|        8|
|   5|  gmail|    1|        9|
|   1|outlook| null|        0|
|   2|outlook| null|        0|
|   3|outlook| null|        0|
|   4|outlook| null|        0|
|   5|outlook|    1|        1|
|   1|  yahoo|    2|        2|
|   2|  yahoo| null|        2|
|   3|  yahoo| null|        2|
|   4|  yahoo| null|        2|
|   5|  yahoo|    3|        5|
+----+-------+-----+---------+
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment