Skip to content

Instantly share code, notes, and snippets.

@josep2
Created December 19, 2016 16:35
Show Gist options
  • Save josep2/e04faecad61f08b7717aa8185b8f44d9 to your computer and use it in GitHub Desktop.
Save josep2/e04faecad61f08b7717aa8185b8f44d9 to your computer and use it in GitHub Desktop.
/**
 * Spark SQL user-defined aggregate function that sums the products of two
 * Double columns, i.e. computes `sum(value * value2)` over the aggregated rows.
 *
 * Rows in which either input is null are skipped. A group with no contributing
 * rows evaluates to 0.0, because the buffer is initialised to 0.0.
 *
 * NOTE(review): newer Spark releases deprecate `UserDefinedAggregateFunction`
 * in favour of `Aggregator` — confirm against the Spark version in use.
 */
class ProductSum extends UserDefinedAggregateFunction {

  /** The two Double input columns whose row-wise product is accumulated. */
  override def inputSchema: StructType =
    StructType(
      StructField("value", DoubleType) ::
        StructField("value2", DoubleType) :: Nil
    )

  /** Single Double slot holding the running sum of products. */
  override def bufferSchema: StructType =
    StructType(StructField("product", DoubleType) :: Nil)

  /** The aggregate result is a Double. */
  override def dataType: DataType = DoubleType

  /** The same input always yields the same output. */
  override def deterministic: Boolean = true

  /** Starts the running sum at zero. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  /**
   * Adds `value * value2` of one input row to the running sum.
   *
   * @param buffer aggregation buffer; slot 0 is the running sum
   * @param input  row laid out according to [[inputSchema]]
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Reading a null cell as a primitive Double yields 0.0, and 0.0 * NaN or
    // 0.0 * Infinity is NaN, which would poison the whole sum. Skip rows with
    // a null operand explicitly instead of relying on that coercion.
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0) * input.getDouble(1)
    }
  }

  /**
   * Combines two partial sums, storing the result in `buffer1`.
   *
   * @param buffer1 buffer that receives the merged sum
   * @param buffer2 partial result to fold into `buffer1`
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  /** Returns the accumulated sum of products held in slot 0 of the buffer. */
  override def evaluate(buffer: Row): Any =
    buffer.getDouble(0)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment