
Last active January 26, 2023 04:31
tzachz/c976a1080b6379ef861c142c16f1364a
Apache Spark UserDefinedAggregateFunction combining maps
```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}

/**
  * UDAF combining maps; values of duplicate keys are combined with `mergeValues`
  * @param keyType DataType of Map key
  * @param valueType DataType of Map value
  * @param mergeValues function to merge values of identical keys (named so it does not clash with the `merge` method below)
  * @tparam K key type
  * @tparam V value type
  */
class CombineMaps[K, V](keyType: DataType, valueType: DataType, mergeValues: (V, V) => V) extends UserDefinedAggregateFunction {
  // input is a single map column; the buffer holds the map accumulated so far
  override def inputSchema: StructType = new StructType().add("map", dataType)
  override def bufferSchema: StructType = inputSchema
  override def dataType: DataType = MapType(keyType, valueType)
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, Map[K, V]())

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val map1 = buffer.getAs[Map[K, V]](0)
    val map2 = input.getAs[Map[K, V]](0)
    // keys only in one map are kept; keys in both get mergeValues(newValue, existingValue)
    val result = map1 ++ map2.map { case (k, v) => k -> map1.get(k).map(mergeValues(v, _)).getOrElse(v) }
    buffer.update(0, result)
  }

  // a partial buffer has the same schema as an input row, so merging buffers is just another update
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)
}

object Example {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.functions._

    val sc: SparkContext = new SparkContext("local", "test")
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val input = sc.parallelize(Seq(
      (1, Map("John" -> 12.5, "Alice" -> 5.5)),
      (1, Map("Jim" -> 16.5, "Alice" -> 4.0)),
      (2, Map("John" -> 13.5, "Jim" -> 2.5))
    )).toDF("id", "scoreMap")

    // instantiate a CombineMaps with the relevant types:
    val combineMaps = new CombineMaps[String, Double](StringType, DoubleType, _ + _)

    // groupBy and aggregate
    val result = input.groupBy("id").agg(combineMaps(col("scoreMap")))
    result.show(truncate = false)
    // +---+--------------------------------------------+
    // |id |CombineMaps(scoreMap)                       |
    // +---+--------------------------------------------+
    // |1  |Map(John -> 12.5, Alice -> 9.5, Jim -> 16.5)|
    // |2  |Map(John -> 13.5, Jim -> 2.5)               |
    // +---+--------------------------------------------+

    sc.stop()
  }
}
```
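Spark calls `update` once per input row, and because `bufferSchema` is the same as `inputSchema`, `merge` reuses `update` to combine partial buffers. The combining step itself is plain Scala, so it can be checked without a cluster. A minimal sketch, assuming Scala 2.12 or later and no Spark dependency (the names `CombineStepSketch`, `combine` and `add` are hypothetical, not part of the gist or of Spark), replays the two `id = 1` rows from the example in both orders:

```scala
object CombineStepSketch {
  // Hypothetical helper mirroring the body of CombineMaps.update:
  // keys only in one map are kept as-is, keys in both get mergeValues(newValue, existingValue).
  def combine[K, V](acc: Map[K, V], incoming: Map[K, V], mergeValues: (V, V) => V): Map[K, V] =
    acc ++ incoming.map { case (k, v) => k -> acc.get(k).map(mergeValues(v, _)).getOrElse(v) }

  // Same merge function as in the example: sum the values of identical keys.
  val add: (Double, Double) => Double = _ + _

  // The two id = 1 rows from the example above.
  val rows: Seq[Map[String, Double]] = Seq(
    Map("John" -> 12.5, "Alice" -> 5.5),
    Map("Jim" -> 16.5, "Alice" -> 4.0)
  )

  // One buffer folding in every row, as update does.
  val inOrder: Map[String, Double] =
    rows.foldLeft(Map.empty[String, Double])(combine(_, _, add))

  // The same rows in the opposite order; Spark does not fix the order of rows within a group.
  val reversed: Map[String, Double] =
    rows.reverse.foldLeft(Map.empty[String, Double])(combine(_, _, add))

  def main(args: Array[String]): Unit = {
    println(inOrder)
    println(inOrder == reversed)
  }
}
```

Both orders give the same map only because `_ + _` is commutative and associative (up to floating-point rounding). With a merge function such as "keep the newer value", the result would depend on an order that Spark does not guarantee.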

man, this is great!!
Thanks for sharing.


d3r1v3d commented Apr 9, 2018

Little late to the party, but shouldn't evaluate use the generic type parameters?

```scala
override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)
```


tzachz commented May 8, 2018

oops, @d3r1v3d - you're right! Thanks, fixed 👍


Very nice example, thank you! I have a question, though. What purpose do the input and buffer schemas serve? I can't seem to get them to do anything. I had expected inputSchema to evaluate whether the correct columns and types were passed in, but that doesn't seem to be true.


dedcode commented Jun 15, 2020

Some cells can be null, so you probably need to check for that in `update` using `if (!input.isNullAt(0))`.
This was very helpful 👍
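The null check suggested above matters because, for a null `scoreMap` cell, `input.getAs[Map[K, V]](0)` returns `null`, and `map2.map { ... }` in `update` then throws a `NullPointerException`. A minimal Spark-free sketch, assuming Scala 2.12 or later (all names are hypothetical, and a plain `null` stands in for the cell that `input.isNullAt(0)` would detect in the real UDAF), skips such cells and leaves the accumulated map unchanged:

```scala
object NullSafeCombineSketch {
  // Same combining rule as CombineMaps.update.
  def combine[K, V](acc: Map[K, V], incoming: Map[K, V], mergeValues: (V, V) => V): Map[K, V] =
    acc ++ incoming.map { case (k, v) => k -> acc.get(k).map(mergeValues(v, _)).getOrElse(v) }

  // A null cell leaves the accumulator untouched; in the UDAF this corresponds to
  // wrapping the body of update in `if (!input.isNullAt(0)) { ... }`.
  def combineNullable[K, V](acc: Map[K, V], incoming: Map[K, V], mergeValues: (V, V) => V): Map[K, V] =
    if (incoming == null) acc else combine(acc, incoming, mergeValues)

  val add: (Double, Double) => Double = _ + _

  // Hypothetical column values; the null models a missing scoreMap cell.
  val cells: Seq[Map[String, Double]] = Seq(Map("John" -> 12.5), null, Map("John" -> 1.0))

  // Folding with the guard skips the null cell instead of failing.
  val combined: Map[String, Double] =
    cells.foldLeft(Map.empty[String, Double])(combineNullable(_, _, add))

  def main(args: Array[String]): Unit = println(combined)
}
```

`merge` needs no extra check in this setup: every buffer starts as the empty map set in `initialize`, and a guarded `update` never writes `null` into it.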
