Created April 22, 2019 19:16
-
-
Save travishegner/33b5af41371eb1adf6f78556aaa48e3b to your computer and use it in GitHub Desktop.
User Defined Aggregate Function: Vector Sum
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.trilliumstaffing.hadoop.tools.udaf | |
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} | |
import org.apache.spark.sql.types.{DataType, StructField, StructType} | |
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType | |
import org.apache.spark.ml.linalg.{Vector, Vectors} | |
import breeze.linalg.{Vector => BV} | |
import org.apache.spark.sql.Row | |
/**
 * Spark SQL user-defined aggregate function that sums a column of
 * `org.apache.spark.ml.linalg.Vector` values element-wise.
 *
 * Null input values are skipped. If a group contains no non-null vector, the
 * result is null. If exactly one non-null vector is seen, it is returned as-is;
 * otherwise the result is a dense vector.
 *
 * All non-null vectors in a group must have the same size; a mismatch raises an
 * `IllegalArgumentException` naming both sizes.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated as of Spark 3.0 in
 * favour of `Aggregator` registered via `functions.udaf` — consider migrating if
 * this runs on Spark 3+.
 */
class VectorSum extends UserDefinedAggregateFunction {

  /** Result type: an ML vector. */
  override def dataType: DataType = VectorType

  /** Declared deterministic to Spark. */
  override def deterministic: Boolean = true

  /** A single input column, `value`, of vector type. */
  override def inputSchema: StructType = StructType(Array(StructField("value", VectorType)))

  /** A single buffer slot, `sum`, holding the running total (null until a non-null vector is seen). */
  override def bufferSchema: StructType = StructType(Array(StructField("sum", VectorType)))

  /** Starts each group with a null (empty) running sum. */
  override def initialize(b: MutableAggregationBuffer): Unit = {
    b(0) = null
  }

  /** Folds one input row into the running sum; a null input leaves the sum unchanged. */
  override def update(b: MutableAggregationBuffer, r: Row): Unit = {
    b(0) = combine(vectorAt(b, 0), vectorAt(r, 0)).orNull
  }

  /** Merges the partial sum held in `b2` into the buffer `b1`. */
  override def merge(b1: MutableAggregationBuffer, b2: Row): Unit = {
    b1(0) = combine(vectorAt(b1, 0), vectorAt(b2, 0)).orNull
  }

  /** Returns the final sum, or null if no non-null vector was aggregated. */
  override def evaluate(b: Row): Vector = b.get(0).asInstanceOf[Vector]

  /** Reads slot `i` of `row` as an optional vector; SQL null becomes `None`. */
  private def vectorAt(row: Row, i: Int): Option[Vector] =
    Option(row.get(i).asInstanceOf[Vector])

  /**
   * Null-tolerant sum shared by `update` and `merge`.
   * Both present: their element-wise sum. One present: that one unchanged. Neither: `None`.
   */
  private def combine(x: Option[Vector], y: Option[Vector]): Option[Vector] = (x, y) match {
    case (Some(a), Some(c)) => Some(add(a, c))
    case _                  => x.orElse(y)
  }

  /**
   * Element-wise sum of two equally sized vectors, returned as a new dense vector.
   * Only active (stored) entries are visited, so sparse inputs are not densified first.
   *
   * @throws IllegalArgumentException if the vectors differ in size
   */
  private def add(x: Vector, y: Vector): Vector = {
    require(
      x.size == y.size,
      s"VectorSum: cannot add vectors of different sizes (${x.size} and ${y.size})")
    // Accumulate into a fresh array so neither input vector is mutated.
    val acc = new Array[Double](x.size)
    x.foreachActive((i, v) => acc(i) += v)
    y.foreachActive((i, v) => acc(i) += v)
    Vectors.dense(acc)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.