Last active
September 15, 2025 09:36
-
-
Save mzaks/6130612af6d7c5fd18ac92bb7f750771 to your computer and use it in GitHub Desktop.
Mojo HyperLogLog performance optimisation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from benchmark import benchmark, Unit, keep, Bencher, Bench, BenchConfig, BenchId | |
from hyperloglog import HyperLogLog | |
fn hash_int(x: Int) -> Int:
    """Scramble an integer with two multiply/xor-shift rounds plus a final xor-shift.

    Each round folds the bits above position 16 back into the value and then
    multiplies by 0x45D9F3B; a last `(v >> 16) ^ v` finishes the mix.

    NOTE(review): 0x45D9F3B comes from a 32-bit integer mixer. HyperLogLog
    picks its bucket from the top bits of a 64-bit value, which this mixer may
    not spread well for small inputs. Confirm this is acceptable for the
    benchmarks.
    """
    alias MULTIPLIER = 0x45D9F3B
    var v = x
    for _ in range(2):
        v = (v ^ (v >> 16)) * MULTIPLIER
    return v ^ (v >> 16)
fn benchmark_add_sparse() raises:
    """Time inserting 1,000 hashed integers into a fresh P=14 sketch.

    HyperLogLog[14] leaves sparse mode only once its set holds 2,048 hashes,
    so every insert here goes into the sparse set.

    NOTE(review): the finished sketch is never observed (no `keep`), unlike
    the cardinality benchmarks. Confirm the work is not being optimised away.
    """
    var sketch = HyperLogLog[14]()
    for value in range(1000):
        sketch.add_hash(hash_int(value))
fn benchmark_add_dense() raises:
    """Time inserting 100,000 hashed integers into a fresh P=14 sketch.

    The sketch starts sparse and switches to registers after its sparse set
    reaches 2,048 hashes, so most inserts take the dense path.

    NOTE(review): the finished sketch is never observed (no `keep`).
    """
    var sketch = HyperLogLog[14]()
    for value in range(100_000):
        sketch.add_hash(hash_int(value))
fn benchmark_cardinality_sparse() raises:
    """Estimate the cardinality of a sparse P=14 sketch holding 1,000 hashes.

    NOTE(review): the timed region also includes the 1,000 `add_hash` calls
    that build the sketch. The Bencher-based `benchmark_cardinality_dense`
    shows how to time the estimate alone.
    """
    var sketch = HyperLogLog[14]()
    for value in range(1000):
        sketch.add_hash(hash_int(value))
    # keep() stops the compiler from discarding the unused estimate.
    keep(sketch.cardinality())
fn benchmark_cardinality_dense_() raises:
    """Estimate the cardinality of a dense P=14 sketch holding 100,000 hashes.

    This is the `benchmark.run`-style variant (trailing underscore) of
    `benchmark_cardinality_dense`.

    NOTE(review): the timed region includes the 100,000 `add_hash` calls,
    which likely dominate the single `cardinality()` call being measured.
    """
    var sketch = HyperLogLog[14]()
    for value in range(100_000):
        sketch.add_hash(hash_int(value))
    # keep() stops the compiler from discarding the unused estimate.
    keep(sketch.cardinality())
@parameter
fn benchmark_cardinality_dense(mut b: Bencher) raises:
    """Time `cardinality()` alone on a dense P=14 sketch.

    The sketch is filled with 100,000 hashed integers before timing starts.
    Only the closure handed to `b.iter` is measured.
    """
    var dense_sketch = HyperLogLog[14]()
    for value in range(100_000):
        dense_sketch.add_hash(hash_int(value))

    @always_inline
    @parameter
    fn estimate():
        # keep() stops the compiler from discarding the unused estimate.
        keep(dense_sketch.cardinality())

    b.iter[estimate]()
fn benchmark_merge_sparse() raises:
    """Build two sparse P=14 sketches and merge the second into the first.

    The sketches hash the disjoint input ranges [0, 1000) and [1000, 2000).
    NOTE(review): the merged sketch is never observed (no `keep`).
    """
    var left = HyperLogLog[14]()
    var right = HyperLogLog[14]()
    for k in range(1000):
        left.add_hash(hash_int(k))
    for k in range(1000, 2000):
        right.add_hash(hash_int(k))
    left.merge(right)
fn benchmark_merge_dense() raises:
    """Build two dense P=14 sketches and merge the second into the first.

    The sketches hash the disjoint input ranges [0, 100000) and
    [100000, 200000). Both exceed the sparse threshold and so end up dense.
    NOTE(review): the merged sketch is never observed (no `keep`).
    """
    var left = HyperLogLog[14]()
    var right = HyperLogLog[14]()
    for k in range(100_000):
        left.add_hash(hash_int(k))
    for k in range(100_000, 200_000):
        right.add_hash(hash_int(k))
    left.merge(right)
fn main() raises:
    """Run the dense-mode cardinality benchmark through the `Bench` harness.

    The other benchmark_* functions in this file (sparse/dense add, sparse
    cardinality, sparse/dense merge) can be run on their own with
    `benchmark.run[fn_name]()` followed by `report.print(Unit.ms)`.
    """
    print("\nCardinality estimation (dense):")
    var bench = Bench(BenchConfig(num_repetitions=5))
    var bench_id = BenchId("benchmark_cardinality_dense EN")
    bench.bench_function[benchmark_cardinality_dense](bench_id)
    bench.dump_report()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import log, log2
# LogLog-Beta bias-correction term (Qin et al., "LogLog-Beta and More", 2016).
# The dense estimator is
#     E = alpha * m * (m - ez) / (beta(ez) + sum_j 2**-M[j])
# where ez is the number of empty registers. beta is a fitted polynomial in
# zl = ln(ez + 1) plus a linear ez term, with one coefficient set per
# precision P. It is evaluated as an 8-lane SIMD dot product of the
# coefficients with (ez, zl, zl**2, ..., zl**7).


@always_inline
fn _beta_coefficients[P: Int]() -> SIMD[DType.float32, 8]:
    """Return the fitted LogLog-Beta coefficients for precision P.

    Lane 0 multiplies ez. Lane k (k >= 1) multiplies ln(ez + 1)**k.
    Precisions with no fitted polynomial get all-zero coefficients.
    """

    @parameter
    if P == 4:
        return SIMD[DType.float32, 8](
            -0.582581413904517,
            -1.935300357560050,
            11.079323758035073,
            -22.131357446444323,
            22.505391846630037,
            -12.000723834917984,
            3.220579408194167,
            -0.342225302271235,
        )
    elif P == 5:
        return SIMD[DType.float32, 8](
            -0.7518999460733967,
            -0.9590030077748760,
            5.5997371322141607,
            -8.2097636999765520,
            6.5091254894472037,
            -2.6830293734323729,
            0.5612891113138221,
            -0.0463331622196545,
        )
    elif P == 6:
        return SIMD[DType.float32, 8](
            29.8257900969619634,
            -31.3287083337725925,
            -10.5942523036582283,
            -11.5720125689099618,
            3.8188754373907492,
            -2.4160130328530811,
            0.4542208940970826,
            -0.0575155452020420,
        )
    elif P == 7:
        return SIMD[DType.float32, 8](
            2.8102921290820060,
            -3.9780498518175995,
            1.3162680041351582,
            -3.9252486335805901,
            2.0080835753946471,
            -0.7527151937556955,
            0.1265569894242751,
            -0.0109946438726240,
        )
    elif P == 8:
        return SIMD[DType.float32, 8](
            1.00633544887550519,
            -2.00580666405112407,
            1.64369749366514117,
            -2.70560809940566172,
            1.39209980244222598,
            -0.46470374272183190,
            0.07384282377269775,
            -0.00578554885254223,
        )
    elif P == 9:
        return SIMD[DType.float32, 8](
            -0.09415657458167959,
            -0.78130975924550528,
            1.71514946750712460,
            -1.73711250406516338,
            0.86441508489048924,
            -0.23819027465047218,
            0.03343448400269076,
            -0.00207858528178157,
        )
    elif P == 10:
        return SIMD[DType.float32, 8](
            -0.25935400670790054,
            -0.52598301999805808,
            1.48933034925876839,
            -1.29642714084993571,
            0.62284756217221615,
            -0.15672326770251041,
            0.02054415903878563,
            -0.00112488483925502,
        )
    elif P == 11:
        return SIMD[DType.float32, 8](
            -0.432325553856025,
            -0.108450736399632,
            0.609156550741120,
            -0.0165687801845180,
            -0.0795829341087617,
            0.0471830602102918,
            -0.00781372902346934,
            0.000584268708489995,
        )
    elif P == 12:
        return SIMD[DType.float32, 8](
            -0.384979202588598,
            0.183162233114364,
            0.130396688841854,
            0.0704838927629266,
            -0.0089589397146453,
            0.0113010036741605,
            -0.00194285569591290,
            0.000225435774024964,
        )
    elif P == 13:
        return SIMD[DType.float32, 8](
            -0.41655270946462997,
            -0.22146677040685156,
            0.38862131236999947,
            0.45340979746062371,
            -0.36264738324476375,
            0.12304650053558529,
            -0.01701540384555510,
            0.00102750367080838,
        )
    elif P == 14:
        return SIMD[DType.float32, 8](
            -0.371009760230692,
            0.00978811941207509,
            0.185796293324165,
            0.203015527328432,
            -0.116710521803686,
            0.0431106699492820,
            -0.00599583540511831,
            0.000449704299509437,
        )
    elif P == 15:
        return SIMD[DType.float32, 8](
            -0.38215145543875273,
            -0.89069400536090837,
            0.37602335774678869,
            0.99335977440682377,
            -0.65577441638318956,
            0.18332342129703610,
            -0.02241529633062872,
            0.00121399789330194,
        )
    elif P == 16:
        return SIMD[DType.float32, 8](
            -0.37331876643753059,
            -1.41704077448122989,
            0.40729184796612533,
            1.56152033906584164,
            -0.99242233534286128,
            0.26064681399483092,
            -0.03053811369682807,
            0.00155770210179105,
        )
    else:
        return SIMD[DType.float32, 8](0)


fn get_beta[P: Int, sum_includes_empty: Bool = False](ez: Float32) -> Float32:
    """Evaluate the LogLog-Beta correction term for `ez` empty registers.

    Parameters:
        P: Sketch precision (the sketch has 2**P registers).
        sum_includes_empty: Whether the caller's harmonic sum already counts
            each empty register as 2**0 = 1, as in the textbook estimator.
            When False (the default), `ez` is added to the result. Callers
            that sum 2**-M only over non-empty registers then get the same
            denominator as the textbook formula.

    Args:
        ez: Number of registers that are still zero.

    Returns:
        beta(ez), plus `ez` when `sum_includes_empty` is False.
    """
    var beta = Float32(0)

    @parameter
    if 4 <= P <= 16:
        # The coefficients are fitted against the natural log of (ez + 1).
        # Using log2 inflates every higher-order term and biases the estimate.
        var zl = log(ez + 1)
        var zl2 = zl * zl
        var zl3 = zl2 * zl
        var zl4 = zl3 * zl
        var zl5 = zl4 * zl
        var zl6 = zl5 * zl
        var zl7 = zl6 * zl
        var z = SIMD[DType.float32, 8](ez, zl, zl2, zl3, zl4, zl5, zl6, zl7)
        beta = (_beta_coefficients[P]() * z).reduce_add()
    else:
        # No fitted polynomial exists for this precision, so the original
        # placeholder is kept. NOTE(review): this value is the HLL alpha
        # constant, not a fitted beta. HyperLogLog constrains P to [4, 16],
        # so it never reaches this branch.
        alias num_registers = 1 << P
        beta = 0.7213 / (1.0 + 1.079 / Float32(num_registers))

    @parameter
    if not sum_includes_empty:
        # Account for the 2**0 = 1 that each empty register contributes to
        # the textbook harmonic sum.
        beta += ez
    return beta
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import log2, exp2 | |
from collections import List, Set | |
from beta import get_beta | |
from bit import count_leading_zeros | |
struct HyperLogLog[P: Int](ImplicitlyCopyable):
    """
    HyperLogLog using the LogLog-Beta algorithm for cardinality estimation.
    Provides bias correction for improved accuracy.

    Small sketches keep the raw hashes in a `Set[Int]` (sparse mode) and
    report its exact size. Once the set holds `1 << (P - 3)` hashes, the
    sketch switches to `2**P` one-byte registers (dense mode).

    Parameters:
        P: Number of hash bits used to select a register; must be in [4, 16].
    """

    alias precision = P  # Number of bits used for register indexing (4-16)
    alias max_zeros = 64 - P  # Maximum possible leading-zero count
    alias m = 1 << P  # Number of registers
    var registers: List[UInt8]  # Dense representation (empty while sparse)
    var sparse_set: Set[Int]  # Sparse representation for low cardinality
    var is_sparse: Bool  # Tracks current representation mode

    @staticmethod
    @parameter
    fn _get_alpha() -> Float32:
        """Get the alpha constant at compile time based on precision."""

        @parameter
        if P == 4:
            return 0.673
        elif P == 5:
            return 0.697
        elif P == 6:
            return 0.709
        else:
            return 0.7213 / (1.0 + 1.079 / Float32(1 << P))

    fn __init__(out self) raises:
        """Initialize an empty sketch in sparse mode."""
        # Compile-time validation
        constrained[P >= 4 and P <= 16, "Precision must be between 4 and 16"]()
        self.registers = List[UInt8]()
        self.sparse_set = Set[Int]()
        self.is_sparse = True

    fn __copyinit__(out self, existing: Self):
        """Deep-copy whichever representation `existing` is currently using."""
        self.is_sparse = existing.is_sparse
        self.sparse_set = Set[Int]()
        self.registers = List[UInt8]()
        if self.is_sparse:
            for item in existing.sparse_set:
                self.sparse_set.add(item)
        else:
            for i in range(len(existing.registers)):
                self.registers.append(existing.registers[i])

    fn add_hash(mut self, hash: Int):
        """Incorporate a new hash value into the sketch.

        Args:
            hash: A 64-bit hash of the item. The top `P` bits choose the
                register; the remaining bits feed the leading-zero count.
        """
        if self.is_sparse:
            # Bound the sparse set's size: once it holds `1 << (P - 3)`
            # hashes, fold it into registers and continue in dense mode.
            alias threshold = 1 << (Self.precision - 3)
            if len(self.sparse_set) >= threshold:
                self._convert_to_dense()
                self._add_to_dense(hash)
            else:
                self.sparse_set.add(hash)
        else:
            self._add_to_dense(hash)

    fn _get_bucket_and_zeros(self, hash_int: Int) -> Tuple[Int, UInt8]:
        """Split a hash into its register index and candidate register value.

        This does not modify the sketch, so it takes `self` read-only. That
        makes it safe to call while iterating over `self.sparse_set`.

        Returns:
            `(bucket, zeros)`. `bucket` is the top `P` bits of the hash.
            `zeros` is one plus the number of leading zeros in the remaining
            `64 - P` bits, so it lies in [1, 65 - P].
        """
        alias mask: Int = (1 << Self.precision) - 1
        # The mask keeps exactly P bits even if the shift sign-extends.
        var bucket: Int = (hash_int >> (64 - Self.precision)) & mask
        # Move the low 64 - P bits to the top and set a sentinel bit just
        # below them. This caps the leading-zero count at 64 - P.
        var pattern: Int = (hash_int << Self.precision) | (
            1 << (Self.precision - 1)
        )
        var zeros: UInt8 = UInt8(count_leading_zeros(pattern) + 1)
        return bucket, zeros

    fn _add_to_dense(mut self, hash_int: Int):
        """Raise the hash's register to its leading-zero rank if that is larger."""
        if len(self.registers) == 0:
            self._convert_to_dense()
        var bucket: Int
        var zeros: UInt8
        bucket, zeros = self._get_bucket_and_zeros(hash_int)
        if self.registers[bucket] < zeros:
            self.registers[bucket] = zeros

    fn _convert_to_dense(mut self):
        """Switch from sparse to dense, replaying every stored hash into registers."""
        self.registers = List[UInt8](capacity=Self.m)
        for _ in range(Self.m):
            self.registers.append(0)
        for h in self.sparse_set:
            var value = h
            var bucket: Int
            var zeros: UInt8
            bucket, zeros = self._get_bucket_and_zeros(value)
            if self.registers[bucket] < zeros:
                self.registers[bucket] = zeros
        self.is_sparse = False
        self.sparse_set.clear()

    fn cardinality(self) -> Int:
        """Estimate the number of unique elements added so far.

        In sparse mode this is the exact number of distinct hashes stored.
        In dense mode it is the LogLog-Beta estimate
        `alpha * m * (m - ez) / (get_beta(ez) + sum)`. Here `ez` is the number
        of empty registers and `sum` adds up `2**-register` over the
        non-empty registers.
        """
        if self.is_sparse:
            return len(self.sparse_set)

        var sum: Float32 = 0.0
        var ez: Float32 = 0.0  # Count of empty registers
        # m is a power of two >= 16, so it is a multiple of vector_width.
        alias vector_width = min(Self.m, 64)
        for i in range(0, Self.m, vector_width):
            var reg = self.registers.unsafe_ptr().load[width=vector_width](i)
            var reg_flag = reg.eq(0)
            # At most 64 lanes are set, so the UInt8 lane sum cannot overflow.
            ez += Float32(reg_flag.cast[DType.uint8]().reduce_add())
            # Compute 2**-r in Float32. The previous Float16 version computed
            # exp2(r), which overflows to inf for r >= 16. Such registers then
            # contributed 0 and inflated estimates at large cardinalities.
            var inv_pow = exp2(-reg.cast[DType.float32]())
            sum += reg_flag.select(
                SIMD[DType.float32, vector_width](0), inv_pow
            ).reduce_add()
        # NOTE(review): empty registers are left out of `sum`. The textbook
        # harmonic sum counts each one as 2**0 = 1, so the term returned by
        # get_beta must account for them. Confirm against get_beta.
        return Int(
            Self._get_alpha()
            * Self.m
            * (Self.m - ez)
            / (get_beta[Self.precision](ez) + sum)
        )

    fn merge(mut self, other: Self) raises:
        """Merge another sketch of the same precision into this one.

        If either side is dense, the result is dense: each register takes the
        maximum of the two. If both are sparse, the hash sets are unioned.
        """
        if not self.is_sparse or not other.is_sparse:
            if self.is_sparse:
                self._convert_to_dense()
            if other.is_sparse:
                # Replay the other sketch's raw hashes into our registers.
                for h in other.sparse_set:
                    var value = h
                    var bucket: Int
                    var zeros: UInt8
                    bucket, zeros = self._get_bucket_and_zeros(value)
                    if self.registers[bucket] < zeros:
                        self.registers[bucket] = zeros
            else:
                for i in range(Self.m):
                    if self.registers[i] < other.registers[i]:
                        self.registers[i] = other.registers[i]
        else:
            for h in other.sparse_set:
                self.sparse_set.add(h)

    fn serialize(mut self) -> List[UInt8]:
        """Serialize the sketch into a byte list.

        Layout:
        - byte 0: precision P.
        - byte 1: 1 if sparse, 0 if dense.
        - sparse: a 4-byte big-endian hash count, then each hash as 8
          big-endian bytes.
        - dense: `2**P` register bytes.
        """
        var buffer = List[UInt8]()
        buffer.append(UInt8(Self.precision))
        buffer.append(UInt8(1 if self.is_sparse else 0))
        if self.is_sparse:
            var count = len(self.sparse_set)
            buffer.append(UInt8((count >> 24) & 0xFF))
            buffer.append(UInt8((count >> 16) & 0xFF))
            buffer.append(UInt8((count >> 8) & 0xFF))
            buffer.append(UInt8(count & 0xFF))
            for h in self.sparse_set:
                var value = h
                for shift in range(56, -8, -8):
                    buffer.append(UInt8((value >> shift) & 0xFF))
        else:
            # Defensive: make sure the registers exist before writing them.
            if len(self.registers) == 0:
                self._convert_to_dense()
            for i in range(Self.m):
                buffer.append(self.registers[i])
        return buffer^

    @staticmethod
    fn deserialize[
        TargetP: Int
    ](buffer: List[UInt8]) raises -> HyperLogLog[TargetP]:
        """Rebuild a sketch from bytes produced by `serialize`.

        Raises:
            Error: if the buffer is truncated, has the wrong length, has an
                unknown mode byte, or was written with a different precision.
        """
        if len(buffer) < 2:
            raise Error("Invalid serialized data: buffer too short")
        var stored_precision = Int(buffer[0])
        if stored_precision != TargetP:
            raise Error("Stored precision does not match expected precision")
        # serialize only writes 0 (dense) or 1 (sparse). Reject anything
        # else instead of silently treating it as dense.
        var mode = buffer[1]
        if mode > 1:
            raise Error("Invalid serialized data: unknown representation flag")
        var is_sparse = mode == 1
        var hll = HyperLogLog[TargetP]()
        hll.is_sparse = is_sparse
        if is_sparse:
            if len(buffer) < 6:
                raise Error("Invalid serialized data: sparse buffer too short")
            var count = (
                (Int(buffer[2]) << 24)
                | (Int(buffer[3]) << 16)
                | (Int(buffer[4]) << 8)
                | Int(buffer[5])
            )
            var pos = 6
            for _ in range(count):
                if pos + 8 > len(buffer):
                    raise Error("Invalid serialized data: incomplete sparse item")
                var value: Int = 0
                for j in range(8):
                    value = (value << 8) | Int(buffer[pos + j])
                hll.sparse_set.add(value)
                pos += 8
        else:
            alias expected_size = (1 << TargetP) + 2
            if len(buffer) != expected_size:
                raise Error("Invalid serialized data: wrong buffer length")
            alias num_registers = 1 << TargetP
            hll.registers = List[UInt8](capacity=num_registers)
            for i in range(num_registers):
                hll.registers.append(buffer[i + 2])
        return hll
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment