Last active
March 16, 2022 05:31
-
-
Save MLnick/eca566604f2e4e3c6141 to your computer and use it in GitHub Desktop.
Experimenting with Spark SQL UDAF - HyperLogLog UDAF for distinct counts, that stores the actual HLL for each row to allow further aggregation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Spark SQL UDAF that aggregates a column into a HyperLogLog sketch.
 *
 * Unlike a plain approximate-count UDAF, the result struct carries the
 * serialized sketch itself alongside the cardinality, so downstream jobs
 * can keep merging sketches instead of only seeing a final number.
 *
 * @param rsd relative standard deviation of the HLL estimate; smaller
 *            values cost more memory. Defaults to 0.05, the value the
 *            original implementation hard-coded.
 */
class HyperLogLogStoreUDAF(rsd: Double = 0.05) extends UserDefinedAggregateFunction {

  // Single input column holding the raw value offered to the sketch.
  override def inputSchema: StructType = new StructType()
    .add("stringInput", BinaryType)

  // The aggregation buffer holds one HLL; null until the first non-null input.
  override def bufferSchema: StructType = new StructType().add("hll", MyHyperLogLogUDT)

  // Output struct: the approximate distinct count plus the sketch itself.
  override def dataType: DataType = new StructType()
    .add("cardinality", LongType)
    .add("hll", MyHyperLogLogUDT)

  // Same inputs always produce the same sketch state.
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Start with no sketch; one is created lazily on the first non-null input.
    buffer.update(0, null)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The input Row has a single column; null inputs are ignored entirely.
    if (!input.isNullAt(0)) {
      if (buffer.isNullAt(0)) {
        // First value seen by this buffer: create the sketch.
        val newHLL = new HyperLogLog(rsd)
        newHLL.offer(input.get(0))
        buffer.update(0, newHLL)
      }
      else {
        // Mutate the existing sketch, then write it back so Spark
        // re-serializes the updated state.
        val updated = buffer.get(0).asInstanceOf[HyperLogLog]
        updated.offer(input.get(0))
        buffer.update(0, updated)
      }
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // buffer1 and buffer2 have the same structure; only merge when the
    // right-hand side actually holds a sketch.
    if (!buffer2.isNullAt(0)) {
      if (buffer1.isNullAt(0)) {
        // Left side is empty: adopt the right-hand sketch wholesale.
        val hll = buffer2.get(0).asInstanceOf[HyperLogLog]
        buffer1.update(0, hll)
      }
      else {
        // Both present: fold the right sketch into the left in place.
        val left = buffer1.get(0).asInstanceOf[HyperLogLog]
        val right = buffer2.get(0).asInstanceOf[HyperLogLog]
        left.addAll(right)
        buffer1.update(0, left)
      }
    }
  }

  override def evaluate(buffer: Row): Any = {
    if (buffer.isNullAt(0)) {
      // No non-null input was ever seen.
      null
    }
    else {
      val hll = buffer.getAs[HyperLogLog](0)
      // NOTE(review): returning an InternalRow from evaluate() worked on the
      // Spark version this was written against (see the sample output below),
      // but the external-facing API normally expects a Row — confirm against
      // the target Spark version before reusing.
      InternalRow(hll.cardinality(), hll.getBytes)
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copy-and-paste of Spark's internal HyperLogLogUDT, which is private[sql]
// and therefore unusable from user code. Pattern matching is used instead of
// asInstanceOf so a wrong runtime type fails with an explicit MatchError at
// the conversion boundary.
case object MyHyperLogLogUDT extends UserDefinedType[HyperLogLog] {

  // Wire/storage representation: the sketch's serialized bytes.
  override def sqlType: DataType = BinaryType

  /** Since we are using HyperLogLog internally, usually it will not be called. */
  override def serialize(obj: Any): Array[Byte] = obj match {
    case hll: HyperLogLog => hll.getBytes
  }

  /** Since we are using HyperLogLog internally, usually it will not be called. */
  override def deserialize(datum: Any): HyperLogLog = datum match {
    case bytes: Array[Byte] => HyperLogLog.Builder.build(bytes)
  }

  override def userClass: Class[HyperLogLog] = classOf[HyperLogLog]
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Small end-to-end driver for HyperLogLogStoreUDAF on a local SparkContext.
 *
 * Uses an explicit main() rather than the App trait, whose delayed-init
 * semantics make field initialization order surprising.
 */
object TestHLL {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[4]")
      .setAppName("test")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    sqlContext.udf.register("hllcount", new HyperLogLogStoreUDAF)

    // Six values, four distinct — expect a cardinality of 4.
    val data = sc.parallelize(Seq("a", "b", "c", "d", "a", "b"), numSlices = 2).toDF("col1")
    data.registerTempTable("test")

    val res = sqlContext.sql("select hllcount(col1) from test")
    // show() returns Unit; wrapping it in println (as the original did)
    // additionally prints "()". Call it directly.
    res.show()

    sc.stop()
  }
}
/* | |
+--------------------+ | |
| _c0| | |
+--------------------+ | |
|[4,com.clearsprin...| | |
+--------------------+ | |
*/ |
Now that all of these functions are private, how can one do such HLL aggregation with "HLL" as the output column?
In the AbstractGenericUDAFResolver
interface, you can override getEvaluator()
to accept different input types, and even assign a different resolver to each kind of input.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Currently, this returns a
StructType
of (cardinality, hll)
. Still looking into how to get the UDAF to accept input arguments of different types without having to write a whole new UDAF...