Skip to content

Instantly share code, notes, and snippets.

@MLnick
Last active March 16, 2022 05:31
Show Gist options
  • Save MLnick/eca566604f2e4e3c6141 to your computer and use it in GitHub Desktop.
Save MLnick/eca566604f2e4e3c6141 to your computer and use it in GitHub Desktop.
Experimenting with Spark SQL UDAF - HyperLogLog UDAF for distinct counts, that stores the actual HLL for each row to allow further aggregation
/**
 * Spark SQL aggregate function that folds the values of one column into a
 * stream-lib `HyperLogLog` sketch and returns both the estimated distinct count
 * and the sketch itself, so the sketch can be aggregated further downstream.
 *
 * @param relativeSD relative standard deviation handed to the `HyperLogLog`
 *                   constructor for every new sketch; defaults to 0.05.
 */
class HyperLogLogStoreUDAF(relativeSD: Double = 0.05) extends UserDefinedAggregateFunction {

  /** One input column, declared as binary data. */
  override def inputSchema: StructType = new StructType()
    .add("stringInput", BinaryType)

  /** One buffer slot holding the running sketch; it stays null until a non-null input arrives. */
  override def bufferSchema: StructType = new StructType().add("hll", MyHyperLogLogUDT)

  /** Result is a struct of the estimated cardinality and the sketch that produced it. */
  override def dataType: StructType = new StructType()
    .add("cardinality", LongType)
    .add("hll", MyHyperLogLogUDT)

  override def deterministic: Boolean = true

  /** Starts every aggregation buffer with no sketch (null in slot 0). */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, null)
  }

  /**
   * Offers one input value to the buffer's sketch.
   *
   * Null inputs are ignored. The sketch is created lazily on the first
   * non-null value, and is written back to the buffer after every offer.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val hll =
        if (buffer.isNullAt(0)) new HyperLogLog(relativeSD)
        else buffer.getAs[HyperLogLog](0)
      hll.offer(input.get(0))
      buffer.update(0, hll)
    }
  }

  /**
   * Merges the sketch of `buffer2` into `buffer1`.
   *
   * `buffer1` is only touched when `buffer2` actually holds a sketch. If
   * `buffer1` is still empty it simply adopts the sketch from `buffer2`.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (!buffer2.isNullAt(0)) {
      val right = buffer2.getAs[HyperLogLog](0)
      if (buffer1.isNullAt(0)) {
        buffer1.update(0, right)
      } else {
        val left = buffer1.getAs[HyperLogLog](0)
        left.addAll(right)
        buffer1.update(0, left)
      }
    }
  }

  /**
   * Produces the final `(cardinality, hll)` struct, or null when no non-null
   * input was ever seen.
   *
   * The result is an external `Row` carrying the `HyperLogLog` object itself,
   * which is the user-facing value for the `MyHyperLogLogUDT` field declared in
   * `dataType`. Serialization to bytes is left to the UDT.
   */
  override def evaluate(buffer: Row): Any = {
    if (buffer.isNullAt(0)) {
      null
    } else {
      val hll = buffer.getAs[HyperLogLog](0)
      Row(hll.cardinality(), hll)
    }
  }
}
/**
 * User-defined SQL type that stores a stream-lib `HyperLogLog` as binary data.
 *
 * This is a local copy of Spark's own HyperLogLogUDT, which is `private[sql]`
 * and therefore cannot be referenced from user code.
 */
case object MyHyperLogLogUDT extends UserDefinedType[HyperLogLog] {

  /** The sketch is represented in SQL as a raw byte array. */
  override def sqlType: DataType = BinaryType

  /** Casts `obj` to `HyperLogLog` and returns its byte representation. */
  override def serialize(obj: Any): Array[Byte] = {
    val sketch = obj.asInstanceOf[HyperLogLog]
    sketch.getBytes
  }

  /** Casts `datum` to a byte array and rebuilds the `HyperLogLog` from it. */
  override def deserialize(datum: Any): HyperLogLog = {
    val bytes = datum.asInstanceOf[Array[Byte]]
    HyperLogLog.Builder.build(bytes)
  }

  /** The JVM class this type maps to. */
  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
}
/**
 * Local smoke test: registers the UDAF as `hllcount`, runs it over a six-row
 * string column containing four distinct values, and prints the result table.
 */
object TestHLL {

  // An explicit main method is used instead of `extends App` to avoid
  // delayed-initialization surprises with fields captured by Spark closures.
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[4]")
      .setAppName("test")
    val sc = new SparkContext(conf)
    try {
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      sqlContext.udf.register("hllcount", new HyperLogLogStoreUDAF)

      val data = sc.parallelize(Seq("a", "b", "c", "d", "a", "b"), numSlices = 2).toDF("col1")
      data.registerTempTable("test")

      val res = sqlContext.sql("select hllcount(col1) from test")
      // show() prints the table itself and returns Unit; wrapping it in println
      // would only add a stray "()" line.
      res.show()
    } finally {
      // Always release the local Spark context, even if the query fails.
      sc.stop()
    }
  }
}
/*
+--------------------+
| _c0|
+--------------------+
|[4,com.clearsprin...|
+--------------------+
*/
@MLnick
Copy link
Author

MLnick commented Sep 12, 2015

Currently, returns a StructType with (cardinality, hll). Still looking into how to get the UDAF to accept input arguments of different types without having to write a whole new UDAF...

@itismewxg
Copy link

Now that all of these functions are private, how can one do this kind of HLL aggregation with the "HLL" itself as the output column?

@donghanz
Copy link

donghanz commented Mar 16, 2022

In AbstractGenericUDAFResolver interface, you can override getEvaluator() to accept different types of input and even assign different Resolvers to the corresponding input.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment