Skip to content

Instantly share code, notes, and snippets.

@tzachz
Last active January 26, 2023 04:31
Show Gist options
  • Select an option

  • Save tzachz/c976a1080b6379ef861c142c16f1364a to your computer and use it in GitHub Desktop.

Select an option

Save tzachz/c976a1080b6379ef861c142c16f1364a to your computer and use it in GitHub Desktop.
Apache Spark UserDefinedAggregateFunction combining maps
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Row, SQLContext}
/**
 * UDAF that combines all maps of a group into a single map.
 *
 * A key present in only one input map is copied as-is. When the same key appears in
 * several maps, its values are combined with `merge`. Null input cells are skipped,
 * so a group containing only nulls evaluates to an empty map.
 *
 * @param keyType   Spark DataType of the map keys
 * @param valueType Spark DataType of the map values
 * @param merge     function combining two values of an identical key; it is called as
 *                  merge(incomingValue, accumulatedValue)
 * @tparam K key type
 * @tparam V value type
 */
class CombineMaps[K, V](keyType: DataType, valueType: DataType, merge: (V, V) => V) extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("map", dataType)

  // The aggregation buffer has the same layout as the input: a single map column.
  override def bufferSchema: StructType = inputSchema

  override def dataType: DataType = MapType(keyType, valueType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, Map[K, V]())

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A null cell contributes nothing. Without this guard, map2 would be null and
    // map2.map below would throw a NullPointerException.
    if (!input.isNullAt(0)) {
      val map1 = buffer.getAs[Map[K, V]](0)
      val map2 = input.getAs[Map[K, V]](0)
      val result = map1 ++ map2.map { case (k, v) => k -> map1.get(k).map(merge(v, _)).getOrElse(v) }
      buffer.update(0, result)
    }
  }

  // Partial buffers share the input layout (one map column at index 0), so merging two
  // buffers is the same operation as folding an input row into a buffer.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = update(buffer1, buffer2)

  override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)
}
/**
 * Demonstrates [[CombineMaps]] by summing per-name scores grouped by id, on a local Spark context.
 */
object Example {
  def main(args: Array[String]): Unit = {
    import org.apache.spark.sql.functions._
    val sc: SparkContext = new SparkContext("local", "test")
    try {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
      val input = sc.parallelize(Seq(
        (1, Map("John" -> 12.5, "Alice" -> 5.5)),
        (1, Map("Jim" -> 16.5, "Alice" -> 4.0)),
        (2, Map("John" -> 13.5, "Jim" -> 2.5))
      )).toDF("id", "scoreMap")
      // instantiate a CombineMaps with the relevant types; values of duplicate keys are summed:
      val combineMaps = new CombineMaps[String, Double](StringType, DoubleType, _ + _)
      // groupBy and aggregate
      val result = input.groupBy("id").agg(combineMaps(col("scoreMap")))
      result.show(truncate = false)
      // Sample output:
      // +---+--------------------------------------------+
      // |id |CombineMaps(scoreMap)                       |
      // +---+--------------------------------------------+
      // |1  |Map(John -> 12.5, Alice -> 9.5, Jim -> 16.5)|
      // |2  |Map(John -> 13.5, Jim -> 2.5)               |
      // +---+--------------------------------------------+
    } finally {
      // Always release the SparkContext, even if the job above throws.
      sc.stop()
    }
  }
}
@mrbrahman
Copy link
Copy Markdown

This is great! Thank you.

@fjavieralba
Copy link
Copy Markdown

man, this is great!!
Thanks for sharing.

@d3r1v3d
Copy link
Copy Markdown

d3r1v3d commented Apr 9, 2018

Little late to the party, but shouldn't evaluate use the generic parameter types?

override def evaluate(buffer: Row): Any = buffer.getAs[Map[K, V]](0)

@tzachz
Copy link
Copy Markdown
Author

tzachz commented May 8, 2018

oops, @d3r1v3d - you're right! Thanks, fixed 👍

@bradleysmithllc
Copy link
Copy Markdown

Very nice example, thank you! I have a question, though. What purpose do the input and buffer schemas serve? I can't seem to get them to do anything. I had expected inputSchema to evaluate whether the correct columns and types were passed in, but that doesn't seem to be true.

@dedcode
Copy link
Copy Markdown

dedcode commented Jun 15, 2020

Some cells can be null, so you probably need to guard against that with `if (!input.isNullAt(0))`.
This was very helpful 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment