Created
February 15, 2018 13:43
-
-
Save lovasoa/0c3a180b15d169cf3d2d4bccacbdc620 to your computer and use it in GitHub Desktop.
Spark UDAF for computing the element-wise mean of vectors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType | |
import org.apache.spark.ml.linalg.{Vector, Vectors} | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types._ | |
/**
 * Spark user-defined aggregate function computing the element-wise mean of a
 * column of `org.apache.spark.ml.linalg.Vector` values.
 *
 * Null input rows are skipped: they contribute neither to the sum nor to the
 * count. All non-empty input vectors must have the same size; a size mismatch
 * fails the aggregation with an `IllegalArgumentException` instead of silently
 * truncating the result. A zero-length input vector is counted but adds
 * nothing to the sum. If no non-empty vector was aggregated, the result is an
 * empty dense vector.
 *
 * Usage sketch:
 * {{{
 *   val vectorMean = new VectorMean
 *   df.groupBy("key").agg(vectorMean(col("features")))
 * }}}
 */
class VectorMean extends UserDefinedAggregateFunction {

  /** One input column, "value", holding an ML vector. */
  override def inputSchema: StructType =
    StructType(StructField("value", VectorType) :: Nil)

  /**
   * Aggregation buffer: "count" is the number of non-null vectors seen so far,
   * "sum" is their running element-wise sum. A zero-length "sum" means no
   * non-empty vector has been added yet, so the dimension is still unknown.
   */
  override def bufferSchema: StructType = StructType(Seq(
    StructField("count", LongType),
    StructField("sum", VectorType)
  ))

  /** The result is a vector. */
  override def dataType: DataType = VectorType

  /** The same input always yields the same output. */
  override def deterministic: Boolean = true

  /** Starts with a zero count and an empty (dimensionless) sum. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = Vectors.zeros(0)
  }

  /** Folds one input row into the buffer; null vectors are ignored. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Without this guard a null input reaches vectorSum and throws an NPE.
    if (!input.isNullAt(0)) {
      val oldCount = buffer.getLong(0)
      val oldVec = buffer.getAs[Vector](1)
      val inputVec = input.getAs[Vector](0)
      buffer(0) = oldCount + 1L
      buffer(1) = vectorSum(oldVec, inputVec)
    }
  }

  /**
   * Element-wise sum of two vectors as a dense vector. A zero-length operand
   * is treated as a zero vector of the other operand's size, which lets the
   * empty initial buffer act as the identity.
   *
   * @param a first operand
   * @param b second operand
   * @return a dense vector whose i-th entry is `a(i) + b(i)`
   * @throws IllegalArgumentException if both operands are non-empty and their
   *                                  sizes differ
   */
  def vectorSum(a: Vector, b: Vector): Vector = {
    val aa = if (a.size > 0) a else Vectors.zeros(b.size)
    val bb = if (b.size > 0) b else Vectors.zeros(a.size)
    require(aa.size == bb.size,
      s"Cannot sum vectors of different sizes: ${a.size} and ${b.size}")
    val xs = aa.toArray
    val ys = bb.toArray
    // Array.tabulate instead of Tuple2.zipped: zipped is deprecated in Scala
    // 2.13, and it silently truncated to the shorter operand.
    Vectors.dense(Array.tabulate(xs.length)(i => xs(i) + ys(i)))
  }

  /** Combines two partial aggregates by adding their counts and their sums. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = vectorSum(buffer1.getAs[Vector](1), buffer2.getAs[Vector](1))
  }

  /** Returns the running sum divided element-wise by the count. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(0)
    val sum = buffer.getAs[Vector](1)
    // When count == 0 the sum is still the zero-length initial vector, so the
    // map runs over no elements and no division by zero occurs.
    Vectors.dense(sum.toArray.map(_ / count))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment