Created December 19, 2016 16:35
-
-
Save josep2/e04faecad61f08b7717aa8185b8f44d9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Spark SQL user-defined aggregate that sums the row-wise product of its two
 * Double input columns, i.e. the equivalent of `SUM(value * value2)` per group.
 *
 * Rows in which either input is null are skipped. A group with no non-null
 * rows evaluates to 0.0, because the aggregation buffer is seeded with 0.0.
 */
class ProductSum extends UserDefinedAggregateFunction {

  // Two Double input columns whose per-row product is accumulated.
  override def inputSchema: StructType =
    StructType(Seq(StructField("value", DoubleType), StructField("value2", DoubleType)))

  // A single running total of the products seen so far.
  override def bufferSchema: StructType =
    StructType(StructField("product", DoubleType) :: Nil)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  /**
   * Adds `value * value2` for the current row to the running total.
   *
   * Nulls are checked explicitly. `getAs[Double]` on a null field unboxes to
   * 0.0, so without this guard a null paired with NaN or Infinity would
   * contribute NaN (0.0 * Inf = NaN) instead of being ignored.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0) * input.getDouble(1)
    }
  }

  // Combines two partial totals (e.g. from different partitions).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // The final result is simply the accumulated total.
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.