-
-
Save AoJ/8ca6a6b42b3171c2ba11404f0066a4ec to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.apache.spark.countSerDe | |
import org.apache.spark.sql.catalyst.util._ | |
import org.apache.spark.sql.types._ | |
import org.apache.spark.sql.Row | |
import org.apache.spark.sql.catalyst.InternalRow | |
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow | |
import org.apache.spark.sql.expressions.MutableAggregationBuffer | |
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction | |
// Value carried through SQL; nSer / nDeSer record how many times the value
// has been serialized / deserialized by CountSerDeUDT (each SerDe trip
// increments the corresponding counter).
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int)
/**
 * User-defined type that instruments its own SerDe path: every call to
 * `serialize` bumps `nSer` by one and every call to `deserialize` bumps
 * `nDeSer` by one, so the counters measure how often Spark (de)serialized
 * the value (used to demonstrate SPARK-27296).
 */
class CountSerDeUDT extends UserDefinedType[CountSerDeSQL] {
  def userClass: Class[CountSerDeSQL] = classOf[CountSerDeSQL]

  override def typeName: String = "count-ser-de"

  private[spark] override def asNullable: CountSerDeUDT = this

  // Physical layout: the two counters as non-nullable int columns.
  def sqlType: DataType = StructType(Seq(
    StructField("nSer", IntegerType, false),
    StructField("nDeSer", IntegerType, false)))

  def serialize(sql: CountSerDeSQL): Any = {
    val out = new GenericInternalRow(2)
    // Count this serialization trip.
    out.setInt(0, sql.nSer + 1)
    out.setInt(1, sql.nDeSer)
    out
  }

  def deserialize(any: Any): CountSerDeSQL = any match {
    // Count this deserialization trip.
    case r: InternalRow if r.numFields == 2 =>
      CountSerDeSQL(r.getInt(0), r.getInt(1) + 1)
    case u =>
      throw new Exception(s"failed to deserialize: $u")
  }

  // All instances of this UDT are interchangeable: equality is purely by type.
  override def equals(obj: Any): Boolean = obj.isInstanceOf[CountSerDeUDT]

  override def hashCode(): Int = classOf[CountSerDeUDT].getName.hashCode()
}
// Singleton instance, handy wherever a DataType value is required (e.g. schemas).
case object CountSerDeUDT extends CountSerDeUDT
/**
 * Aggregate that exists only to exercise the UDT SerDe path: it ignores its
 * double-typed input entirely and just shuttles a CountSerDeSQL value through
 * the aggregation buffer, so the UDT's ser/de counters accumulate per row.
 */
case object CountSerDeUDAF extends UserDefinedAggregateFunction {
  def deterministic: Boolean = true

  def inputSchema: StructType = StructType(Seq(StructField("x", DoubleType)))

  def bufferSchema: StructType = StructType(Seq(StructField("count-ser-de", CountSerDeUDT)))

  def dataType: DataType = CountSerDeUDT

  def initialize(buf: MutableAggregationBuffer): Unit =
    buf.update(0, CountSerDeSQL(0, 0))

  // Deliberately a no-op with respect to `input`: reading the buffer slot and
  // writing it straight back forces one deserialize + one serialize per row.
  def update(buf: MutableAggregationBuffer, input: Row): Unit =
    buf.update(0, buf.getAs[CountSerDeSQL](0))

  def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
    val left = buf1.getAs[CountSerDeSQL](0)
    val right = buf2.getAs[CountSerDeSQL](0)
    buf1.update(0, CountSerDeSQL(left.nSer + right.nSer, left.nDeSer + right.nDeSer))
  }

  def evaluate(buf: Row): Any = buf.getAs[CountSerDeSQL](0)
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
scala> import scala.util.Random.nextGaussian, org.apache.spark.countSerDe._
import scala.util.Random.nextGaussian
import org.apache.spark.countSerDe._

scala> val data = sc.parallelize(Vector.fill(1000){(nextGaussian, nextGaussian)}).toDF.as[(Double, Double)]
data: org.apache.spark.sql.Dataset[(Double, Double)] = [_1: double, _2: double]

scala> val udaf = CountSerDeUDAF
udaf: org.apache.spark.countSerDe.CountSerDeUDAF.type = CountSerDeUDAF

scala> val agg = data.agg(udaf($"_1"))
agg: org.apache.spark.sql.DataFrame = [countserdeudaf$(_1): count-ser-de]

scala> agg.first.getAs[CountSerDeSQL](0)
res4: org.apache.spark.countSerDe.CountSerDeSQL = CountSerDeSQL(1006,1006)

scala> spark.udf.register("countserde", udaf)
res1: org.apache.spark.sql.expressions.UserDefinedAggregateFunction = CountSerDeUDAF

scala> val agg = data.agg(expr("countserde(_1)"))
agg: org.apache.spark.sql.DataFrame = [countserde(_1): count-ser-de]

scala> agg.first.getAs[CountSerDeSQL](0)
res2: org.apache.spark.countSerDe.CountSerDeSQL = CountSerDeSQL(1006,1006)

scala>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
https://issues.apache.org/jira/browse/SPARK-27296