Spark UDT and UDAF with custom buffer type
aggregate.scala
package org.apache.spark

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData => GenericData}
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types._

package object aggregate {
  case class Point(mac: String, start: Long, end: Long) {
    override def hashCode(): Int = {
      31 * (31 * mac.hashCode) + start.hashCode
    }

    override def equals(other: Any): Boolean = other match {
      case that: Point => this.mac == that.mac && this.start == that.start && this.end == that.end
      case other => false
    }

    override def toString(): String = {
      s"${getClass.getSimpleName}($mac, start=$start, end=$end)"
    }
  }

  @SQLUserDefinedType(udt = classOf[BufferType])
  type Buffer = ArrayBuffer[Point]

  // UDT that maps a Buffer (ArrayBuffer[Point]) to/from an array of structs,
  // so it can be used as the aggregation buffer type of the UDAF below.
  private[spark] class BufferType extends UserDefinedType[Buffer] {
    def sqlType: DataType = ArrayType(StructType(
      StructField("mac", StringType, false) ::
      StructField("start", LongType, false) ::
      StructField("end", LongType, false) :: Nil))

    def serialize(obj: Any): Any = obj match {
      case buffer: ArrayBuffer[_] =>
        val data = buffer.asInstanceOf[Buffer].map { point =>
          val arr = new Array[Any](3)
          // Strings are stored internally as UTF8String.
          arr(0) = UTF8String.fromString(point.mac)
          arr(1) = point.start
          arr(2) = point.end
          new GenericInternalRow(arr)
        }
        new GenericData(data)
      case other => sys.error(s"Failed to serialize: $other")
    }

    def deserialize(datum: Any): Buffer = datum match {
      case data: ArrayData =>
        val buf = new Buffer()
        // Iterate with numElements/getStruct rather than data.array, because
        // UnsafeArrayData does not support the array method.
        var i = 0
        while (i < data.numElements()) {
          val next = data.getStruct(i, 3)
          buf.append(Point(next.getString(0), next.getLong(1), next.getLong(2)))
          i += 1
        }
        buf
      case other => sys.error(s"Failed to deserialize: $other")
    }

    def userClass: Class[Buffer] = classOf[Buffer]
  }

  case object BufferType extends BufferType

  // == UDAF ==
  // Collects every input (mac, start, end) row of a group into one Buffer of Points.
  class SimpleAggregate extends UserDefinedAggregateFunction {
    override def inputSchema: StructType = StructType(
      StructField("mac", StringType, true) ::
      StructField("start", LongType, true) ::
      StructField("end", LongType, true) :: Nil)

    override def bufferSchema: StructType = StructType(
      StructField("buffer", BufferType, true) :: Nil)

    override def dataType: DataType = BufferType

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = new Buffer()
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val buf = buffer(0).asInstanceOf[Buffer]
      buf.append(Point(input.getString(0), input.getLong(1), input.getLong(2)))
      // Reassign so the updated buffer is serialized back through the UDT.
      buffer(0) = buf
    }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      val buf1 = buffer1(0).asInstanceOf[Buffer]
      val buf2 = buffer2(0).asInstanceOf[Buffer]
      buf1.appendAll(buf2)
      buffer1(0) = buf1
    }

    override def evaluate(buffer: Row): Any = {
      buffer(0).asInstanceOf[Buffer]
    }
  }

  // Sort Points by start, then by end.
  implicit val ordering: Ordering[Point] = new Ordering[Point] {
    override def compare(x: Point, y: Point): Int = {
      if (x.start == y.start) {
        if (x.end == y.end) 0 else if (x.end < y.end) -1 else 1
      } else {
        if (x.start < y.start) -1 else 1
      }
    }
  }
}
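For reference, a minimal usage sketch of the UDAF above, assuming Spark 2.x with aggregate.scala compiled on the classpath; the object name, sample rows, and output column name are illustrative only and not part of the gist.

import org.apache.spark.aggregate._
import org.apache.spark.sql.SparkSession

object UdafExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("buffer-udaf-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Illustrative rows: (mac, start, end).
    val df = Seq(
      ("00:11:22:33:44:55", 1L, 5L),
      ("00:11:22:33:44:55", 6L, 9L),
      ("aa:bb:cc:dd:ee:ff", 2L, 3L)
    ).toDF("mac", "start", "end")

    // Collect all (mac, start, end) triples of each group into one Buffer value.
    val collectPoints = new SimpleAggregate()
    df.groupBy($"mac")
      .agg(collectPoints($"mac", $"start", $"end").as("points"))
      .show(truncate = false)

    spark.stop()
  }
}

Note that reading buffer(0) and assigning it back in update/merge goes through BufferType.serialize/deserialize on every call, which this setup makes easy to observe but which is also why a UDT-backed buffer like this can be slow on large groups.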
A second file in the gist defines Point as a standalone UDT, outside the aggregate package:
package org.apache.spark

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@SQLUserDefinedType(udt = classOf[PointType])
case class Point(mac: String, start: Long, end: Long) {
  override def hashCode(): Int = {
    31 * (31 * mac.hashCode) + start.hashCode
  }

  override def equals(other: Any): Boolean = other match {
    case that: Point => this.mac == that.mac && this.start == that.start && this.end == that.end
    case other => false
  }

  override def toString(): String = {
    s"${getClass.getSimpleName}($mac, start=$start, end=$end)"
  }
}

// UDT that maps Point to/from its internal representation. Because sqlType is
// a StructType, the Catalyst form is an InternalRow rather than ArrayData.
class PointType extends UserDefinedType[Point] {
  def sqlType: DataType = StructType(
    StructField("mac", StringType, false) ::
    StructField("start", LongType, false) ::
    StructField("end", LongType, false) :: Nil)

  def serialize(obj: Any): Any = obj match {
    case c @ Point(mac, start, end) =>
      println(s"Serialize: $c") // debug output to show when serialization happens
      val arr = new Array[Any](3)
      // Strings are stored internally as UTF8String.
      arr(0) = UTF8String.fromString(mac)
      arr(1) = start
      arr(2) = end
      new GenericInternalRow(arr)
    case other => sys.error(s"Failed to serialize: $other")
  }

  def deserialize(datum: Any): Point = datum match {
    case row: InternalRow =>
      println(s"Deserialize: $datum") // debug output to show when deserialization happens
      Point(row.getString(0), row.getLong(1), row.getLong(2))
    case other => sys.error(s"Failed to deserialize: $other")
  }

  def userClass: Class[Point] = classOf[Point]
}

case object PointType extends PointType
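A minimal sketch of using this Point UDT, again assuming Spark 2.x; wrapping Point in Tuple1 is just a convenient way to get a single UDT column, and the object name and sample values are illustrative.

import org.apache.spark.Point
import org.apache.spark.sql.SparkSession

object PointUdtExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("point-udt-example")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // The product encoder for Tuple1[Point] picks up PointType from the
    // @SQLUserDefinedType annotation on the Point case class.
    val df = Seq(
      Tuple1(Point("00:11:22:33:44:55", 1L, 5L)),
      Tuple1(Point("aa:bb:cc:dd:ee:ff", 2L, 3L))
    ).toDF("point")

    df.printSchema() // the column's type is the Point UDT
    df.show(truncate = false) // triggers the serialize/deserialize println calls

    spark.stop()
  }
}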
Hello,
What version of Spark is aggregate.scala written for? I'm trying to run the code against Spark 2.1 and I'm getting a compilation error.