Skip to content

Instantly share code, notes, and snippets.

@mzaks
Last active September 15, 2025 09:36
Show Gist options
  • Save mzaks/6130612af6d7c5fd18ac92bb7f750771 to your computer and use it in GitHub Desktop.
Save mzaks/6130612af6d7c5fd18ac92bb7f750771 to your computer and use it in GitHub Desktop.
Mojo HyperLogLog performance optimisation
from benchmark import benchmark, Unit, keep, Bencher, Bench, BenchConfig, BenchId
from hyperloglog import HyperLogLog
fn hash_int(x: Int) -> Int:
    """Scramble an integer key into a pseudo-random `Int` for benchmark input.

    Performs two rounds of "xor the value with itself shifted right by 16,
    then multiply by 0x45D9F3B", followed by one final xor-shift. This is a
    cheap mixer for generating benchmark keys, not a vetted 64-bit hash.
    """
    alias MULTIPLIER = 0x45D9F3B
    var mixed = x
    for _ in range(2):
        mixed = (mixed ^ (mixed >> 16)) * MULTIPLIER
    return mixed ^ (mixed >> 16)
fn benchmark_add_sparse() raises:
    """Benchmark adding elements while HLL remains in sparse mode.

    Inserts 1000 hashed keys into a precision-14 sketch. That is few enough
    that the sketch never converts to dense registers.
    """
    var hll = HyperLogLog[14]()
    for i in range(1000):  # Small enough to stay sparse
        hll.add_hash(hash_int(i))
    # Observe the result; otherwise the optimizer may discard the whole loop.
    keep(len(hll.sparse_set))
fn benchmark_add_dense() raises:
    """Benchmark adding elements in dense mode.

    Inserts 100_000 hashed keys into a precision-14 sketch. When timed as a
    whole function, the work includes the initial sparse phase and the
    one-time conversion to dense registers.
    """
    var hll = HyperLogLog[14]()
    for i in range(100_000):  # Large enough to trigger dense mode
        hll.add_hash(hash_int(i))
    # Hand the register buffer to `keep` so the register writes are treated as
    # observed and cannot be optimized away.
    keep(hll.registers.unsafe_ptr())
fn benchmark_cardinality_sparse() raises:
    """Benchmark cardinality estimation in sparse mode.

    Fills a precision-14 sketch with 1000 hashed keys (it stays sparse), then
    estimates its cardinality once. When timed as a whole function, the
    insertions are part of the measured work.
    """
    var sketch = HyperLogLog[14]()
    for key in range(1000):
        sketch.add_hash(hash_int(key))
    var estimate = sketch.cardinality()
    keep(estimate)  # Sink the estimate so the computation is not elided
fn benchmark_cardinality_dense_() raises:
    """Benchmark cardinality estimation in dense mode (whole-function variant).

    Fills a precision-14 sketch with 100_000 hashed keys, enough to force
    dense mode, then estimates its cardinality once. Unlike the `Bencher`-based
    `benchmark_cardinality_dense`, the insertions are part of the timed work.
    """
    var sketch = HyperLogLog[14]()
    for key in range(100_000):
        sketch.add_hash(hash_int(key))
    var estimate = sketch.cardinality()
    keep(estimate)  # Sink the estimate so the computation is not elided
@parameter
fn benchmark_cardinality_dense(mut b: Bencher) raises:
    """Time only `cardinality()` on a sketch that has already gone dense.

    The 100_000 insertions run once, before timing starts. `b.iter` then
    repeatedly calls the closure, which performs just the estimate.
    """
    var sketch = HyperLogLog[14]()
    for key in range(100_000):
        sketch.add_hash(hash_int(key))

    @always_inline
    @parameter
    fn estimate_once():
        var estimate = sketch.cardinality()
        keep(estimate)  # Sink the estimate so the call is not optimized away

    b.iter[estimate_once]()
fn benchmark_merge_sparse() raises:
    """Benchmark merging two sparse HLLs.

    Builds two precision-14 sketches from disjoint key ranges of 1000 keys
    each, then merges the second into the first. When timed as a whole
    function, building both sketches is part of the measured work.
    """
    var hll1 = HyperLogLog[14]()
    var hll2 = HyperLogLog[14]()
    for i in range(1000):
        hll1.add_hash(hash_int(i))
        hll2.add_hash(hash_int(i + 1000))
    hll1.merge(hll2)
    # Observe the merged sketch so the optimizer cannot drop the merge.
    keep(len(hll1.sparse_set))
fn benchmark_merge_dense() raises:
    """Benchmark merging two dense HLLs.

    Builds two precision-14 sketches from disjoint ranges of 100_000 keys
    each (both go dense), then merges the second into the first. When timed
    as a whole function, building both sketches is part of the measured work.
    """
    var hll1 = HyperLogLog[14]()
    var hll2 = HyperLogLog[14]()
    for i in range(100_000):
        hll1.add_hash(hash_int(i))
        hll2.add_hash(hash_int(i + 100_000))
    hll1.merge(hll2)
    # Hand the merged register buffer to `keep` so the writes count as observed.
    keep(hll1.registers.unsafe_ptr())
fn main() raises:
    """Run the dense-mode cardinality benchmark with `Bench` and print its report.

    Only the `Bencher`-based `benchmark_cardinality_dense` is registered here.
    The whole-function benchmarks in this file are not run:
    `benchmark_add_sparse`, `benchmark_cardinality_sparse`,
    `benchmark_merge_sparse`, `benchmark_add_dense` and `benchmark_merge_dense`.
    Each can be timed with `benchmark.run[fn]()` followed by
    `report.print(Unit.ms)`.
    """
    print("\nCardinality estimation (dense):")

    # Five repetitions; no CSV output file is configured.
    var bench = Bench(BenchConfig(num_repetitions=5))
    bench.bench_function[benchmark_cardinality_dense](
        BenchId("benchmark_cardinality_dense EN")
    )
    bench.dump_report()
from math import log2
# LogLog-Beta bias-correction term.
# get_beta[P](ez) returns beta(ez), where ez is the number of empty (zero)
# registers in a HyperLogLog sketch of precision P (2^P registers).
# For 4 <= P <= 16, beta(ez) is a per-precision polynomial: a linear term in ez
# plus terms in the first through seventh powers of log(ez + 1). It is evaluated
# with one 8-lane SIMD multiply and a horizontal add.
# The result is added to the harmonic sum of the registers in the cardinality
# estimate. Despite the name, it is unrelated to the beta probability distribution.
fn get_beta[P: Int](ez: Float32) -> Float32:
    """Return the LogLog-Beta bias-correction term beta(ez) for precision `P`.

    For 4 <= P <= 16 the term is the fitted polynomial

        beta(ez) = c0*ez + c1*zl + c2*zl^2 + ... + c7*zl^7,   zl = ln(ez + 1)

    with per-precision coefficients c0..c7 taken from `_beta_coefficients`.
    The eight terms are evaluated with one 8-lane SIMD multiply followed by a
    horizontal add.

    Parameters:
        P: Sketch precision (number of register-index bits).

    Args:
        ez: Number of empty (zero) registers, as a float.

    Returns:
        beta(ez). The LogLog-Beta estimator adds it to the harmonic register
        sum in the denominator of the cardinality estimate.
    """
    @parameter
    if 4 <= P <= 16:
        # LogLog-Beta defines zl with the natural logarithm; the coefficient
        # tables are fitted against ln(ez + 1), not log2(ez + 1).
        # ln(x) == log2(x) * ln(2).
        alias LN2: Float32 = 0.6931471805599453
        var zl = log2(ez + 1) * LN2
        var zl2 = zl * zl
        var zl3 = zl2 * zl
        var zl4 = zl3 * zl
        var zl5 = zl4 * zl
        var zl6 = zl5 * zl
        var zl7 = zl6 * zl
        # Lane 0 is the linear ez term; lanes 1..7 are zl^1..zl^7.
        var z = SIMD[DType.float32, 8](ez, zl, zl2, zl3, zl4, zl5, zl6, zl7)
        return (_beta_coefficients[P]() * z).reduce_add()
    else:
        # NOTE(review): this fallback is the HyperLogLog alpha_m constant, not
        # a beta(ez) polynomial: it ignores `ez` entirely. HyperLogLog restricts
        # P to [4, 16], so it is not reached from there. Confirm the intended
        # behaviour before using get_beta with other precisions.
        alias num_registers = 1 << P
        return 0.7213 / (1.0 + 1.079 / Float32(num_registers))


@always_inline
fn _beta_coefficients[P: Int]() -> SIMD[DType.float32, 8]:
    """Return the fitted LogLog-Beta coefficients c0..c7 for precision `P`.

    Lane 0 multiplies `ez`; lanes 1..7 multiply zl^1..zl^7 (see `get_beta`).
    """
    constrained[4 <= P <= 16, "beta coefficients exist only for 4 <= P <= 16"]()

    @parameter
    if P == 4:
        return SIMD[DType.float32, 8](
            -0.582581413904517,
            -1.935300357560050,
            11.079323758035073,
            -22.131357446444323,
            22.505391846630037,
            -12.000723834917984,
            3.220579408194167,
            -0.342225302271235,
        )
    elif P == 5:
        return SIMD[DType.float32, 8](
            -0.7518999460733967,
            -0.9590030077748760,
            5.5997371322141607,
            -8.2097636999765520,
            6.5091254894472037,
            -2.6830293734323729,
            0.5612891113138221,
            -0.0463331622196545,
        )
    elif P == 6:
        return SIMD[DType.float32, 8](
            29.8257900969619634,
            -31.3287083337725925,
            -10.5942523036582283,
            -11.5720125689099618,
            3.8188754373907492,
            -2.4160130328530811,
            0.4542208940970826,
            -0.0575155452020420,
        )
    elif P == 7:
        return SIMD[DType.float32, 8](
            2.8102921290820060,
            -3.9780498518175995,
            1.3162680041351582,
            -3.9252486335805901,
            2.0080835753946471,
            -0.7527151937556955,
            0.1265569894242751,
            -0.0109946438726240,
        )
    elif P == 8:
        return SIMD[DType.float32, 8](
            1.00633544887550519,
            -2.00580666405112407,
            1.64369749366514117,
            -2.70560809940566172,
            1.39209980244222598,
            -0.46470374272183190,
            0.07384282377269775,
            -0.00578554885254223,
        )
    elif P == 9:
        return SIMD[DType.float32, 8](
            -0.09415657458167959,
            -0.78130975924550528,
            1.71514946750712460,
            -1.73711250406516338,
            0.86441508489048924,
            -0.23819027465047218,
            0.03343448400269076,
            -0.00207858528178157,
        )
    elif P == 10:
        return SIMD[DType.float32, 8](
            -0.25935400670790054,
            -0.52598301999805808,
            1.48933034925876839,
            -1.29642714084993571,
            0.62284756217221615,
            -0.15672326770251041,
            0.02054415903878563,
            -0.00112488483925502,
        )
    elif P == 11:
        return SIMD[DType.float32, 8](
            -0.432325553856025,
            -0.108450736399632,
            0.609156550741120,
            -0.0165687801845180,
            -0.0795829341087617,
            0.0471830602102918,
            -0.00781372902346934,
            0.000584268708489995,
        )
    elif P == 12:
        return SIMD[DType.float32, 8](
            -0.384979202588598,
            0.183162233114364,
            0.130396688841854,
            0.0704838927629266,
            -0.0089589397146453,
            0.0113010036741605,
            -0.00194285569591290,
            0.000225435774024964,
        )
    elif P == 13:
        return SIMD[DType.float32, 8](
            -0.41655270946462997,
            -0.22146677040685156,
            0.38862131236999947,
            0.45340979746062371,
            -0.36264738324476375,
            0.12304650053558529,
            -0.01701540384555510,
            0.00102750367080838,
        )
    elif P == 14:
        return SIMD[DType.float32, 8](
            -0.371009760230692,
            0.00978811941207509,
            0.185796293324165,
            0.203015527328432,
            -0.116710521803686,
            0.0431106699492820,
            -0.00599583540511831,
            0.000449704299509437,
        )
    elif P == 15:
        return SIMD[DType.float32, 8](
            -0.38215145543875273,
            -0.89069400536090837,
            0.37602335774678869,
            0.99335977440682377,
            -0.65577441638318956,
            0.18332342129703610,
            -0.02241529633062872,
            0.00121399789330194,
        )
    else:  # P == 16, guaranteed by the constraint above
        return SIMD[DType.float32, 8](
            -0.37331876643753059,
            -1.41704077448122989,
            0.40729184796612533,
            1.56152033906584164,
            -0.99242233534286128,
            0.26064681399483092,
            -0.03053811369682807,
            0.00155770210179105,
        )
from math import log2, exp2
from collections import List, Set
from beta import get_beta
from bit import count_leading_zeros
struct HyperLogLog[P: Int](ImplicitlyCopyable):
    """
    HyperLogLog using the LogLog-Beta algorithm for cardinality estimation.
    Provides bias correction for improved accuracy.

    The sketch starts in sparse mode, storing raw hashes in a `Set[Int]`.
    The sparse set holds at most `sparse_threshold = 2^(P-3)` hashes. Beyond
    that, the sketch switches to `m = 2^P` dense 8-bit registers.
    Hashes are treated as 64-bit values: the top `P` bits select a register.
    """

    alias precision = P  # Number of bits used for register indexing (4-16)
    alias max_zeros = 64 - P  # Maximum possible leading zeros
    alias m = 1 << P  # Number of registers
    # Largest number of hashes kept in sparse mode before switching to dense.
    alias sparse_threshold = 1 << (P - 3)
    var registers: List[UInt8]  # Dense representation
    var sparse_set: Set[Int]  # Sparse representation for low cardinality
    var is_sparse: Bool  # Tracks current representation mode

    @staticmethod
    @parameter
    fn _get_alpha() -> Float32:
        """Get alpha constant at compile time based on precision."""

        @parameter
        if P == 4:
            return 0.673
        elif P == 5:
            return 0.697
        elif P == 6:
            return 0.709
        else:
            return 0.7213 / (1.0 + 1.079 / Float32(1 << P))

    fn __init__(out self) raises:
        """Initialize an empty sketch in sparse mode.

        The precision `P` is checked at compile time to lie in [4, 16].
        """
        constrained[P >= 4 and P <= 16, "Precision must be between 4 and 16"]()
        self.registers = List[UInt8]()
        self.sparse_set = Set[Int]()
        self.is_sparse = True

    fn __copyinit__(out self, existing: Self):
        """Copy-initialize from an existing HyperLogLog.

        Only the active representation is copied; the inactive one starts empty.
        """
        self.is_sparse = existing.is_sparse
        self.sparse_set = Set[Int]()
        self.registers = List[UInt8]()
        if self.is_sparse:
            for item in existing.sparse_set:
                self.sparse_set.add(item)
        else:
            for i in range(len(existing.registers)):
                self.registers.append(existing.registers[i])

    fn add_hash(mut self, hash: Int):
        """Incorporate a new hash value into the sketch.

        Args:
            hash: Hash of the element; its top `P` bits select the register.
        """
        if self.is_sparse:
            if len(self.sparse_set) >= Self.sparse_threshold:
                # The sparse set is full: switch to dense registers first.
                self._convert_to_dense()
                self._add_to_dense(hash)
            else:
                self.sparse_set.add(hash)
        else:
            self._add_to_dense(hash)

    fn _get_bucket_and_zeros(self, hash_int: Int) -> Tuple[Int, UInt8]:
        """Split a hash into its register index and its rank.

        The bucket is the top `P` bits of the (assumed 64-bit) hash. The rank
        is one plus the number of leading zeros in the remaining `64 - P` bits.
        The sentinel bit at position `P - 1` caps the rank at `65 - P`.
        """
        alias mask: Int = (1 << Self.precision) - 1
        var bucket: Int = (hash_int >> (64 - Self.precision)) & mask
        var pattern: Int = (hash_int << Self.precision) | (1 << (Self.precision - 1))
        var zeros: UInt8 = UInt8(count_leading_zeros(pattern) + 1)
        return bucket, zeros

    fn _add_to_dense(mut self, hash_int: Int):
        """Raise the hash's register to the hash's rank if the rank is larger."""
        if len(self.registers) == 0:
            self._convert_to_dense()
        var bucket: Int
        var zeros: UInt8
        bucket, zeros = self._get_bucket_and_zeros(hash_int)
        if self.registers[bucket] < zeros:
            self.registers[bucket] = zeros

    fn _convert_to_dense(mut self):
        """Switch from sparse to dense representation.

        Allocates `m` zeroed registers, folds every stored hash into them,
        then clears the sparse set.
        """
        self.registers = List[UInt8]()
        for _ in range(Self.m):
            self.registers.append(0)
        for h in self.sparse_set:
            var value = h
            var bucket: Int
            var zeros: UInt8
            bucket, zeros = self._get_bucket_and_zeros(value)
            if self.registers[bucket] < zeros:
                self.registers[bucket] = zeros
        self.is_sparse = False
        self.sparse_set.clear()

    fn cardinality(self) -> Int:
        """Estimate number of unique elements.

        Sparse mode returns the exact number of distinct stored hashes. Dense
        mode applies the LogLog-Beta estimator

            alpha_m * m * (m - ez) / (beta(ez) + sum_j 2^-M[j])

        where `ez` is the number of empty registers. The sum runs over all `m`
        registers, so every empty register (M[j] == 0) contributes 2^0 == 1.
        The result is truncated to `Int`.
        """
        if self.is_sparse:
            return len(self.sparse_set)

        var sum: Float32 = 0.0
        var ez: Float32 = 0.0  # Count of empty registers
        alias vector_width = min(Self.m, 64)
        var regs = self.registers.unsafe_ptr()
        for i in range(0, Self.m, vector_width):
            var reg = regs.load[width=vector_width](i)
            ez += Float32(reg.eq(0).cast[DType.uint8]().reduce_add())
            # 2^-M[j] for every register, including empty ones. This is
            # computed in float16 for throughput; each chunk's partial sum is
            # widened to float32 before accumulation.
            # NOTE(review): float16 exp2(M) overflows for M >= 16, so those
            # registers contribute 0 instead of <= 2^-16. Negligible, but confirm.
            var inv_pow = 1 / exp2(reg.cast[DType.float16]())
            sum += inv_pow.reduce_add().cast[DType.float32]()

        # Apply LogLog-Beta bias correction
        return Int(
            Self._get_alpha()
            * Self.m
            * (Self.m - ez)
            / (get_beta[Self.precision](ez) + sum)
        )

    fn merge(mut self, other: Self) raises:
        """Merge another sketch into this one.

        Dense merging keeps the larger rank per register; two sparse sketches
        are merged by set union. The precision is part of the type, so both
        sketches always match.
        """
        if not self.is_sparse or not other.is_sparse:
            if self.is_sparse:
                self._convert_to_dense()
            if other.is_sparse:
                # Fold each of other's sparse hashes into our dense registers.
                for h in other.sparse_set:
                    var value = h
                    self._add_to_dense(value)
            else:
                # Dense into dense: keep the larger rank per register.
                for i in range(Self.m):
                    if self.registers[i] < other.registers[i]:
                        self.registers[i] = other.registers[i]
        else:
            # Both sparse: union the hash sets, then enforce the same size bound
            # that add_hash maintains. Otherwise repeated merges could grow the
            # sparse set without limit.
            for h in other.sparse_set:
                self.sparse_set.add(h)
            if len(self.sparse_set) > Self.sparse_threshold:
                self._convert_to_dense()

    fn serialize(mut self) -> List[UInt8]:
        """Serialize the sketch into a byte list.

        Layout: byte 0 holds the precision and byte 1 the mode (1 = sparse,
        0 = dense). In sparse mode a 4-byte big-endian hash count follows,
        then each hash as 8 big-endian bytes. In dense mode the `m` register
        bytes follow in index order.
        """
        var buffer = List[UInt8]()
        buffer.append(UInt8(Self.precision))
        buffer.append(UInt8(1 if self.is_sparse else 0))
        if self.is_sparse:
            var count = len(self.sparse_set)
            for count_shift in range(24, -8, -8):
                buffer.append(UInt8((count >> count_shift) & 0xFF))
            for h in self.sparse_set:
                var value = h
                for shift in range(56, -8, -8):
                    buffer.append(UInt8((value >> shift) & 0xFF))
        else:
            # Ensure the registers exist before writing them out.
            if len(self.registers) == 0:
                self._convert_to_dense()
            for i in range(Self.m):
                buffer.append(self.registers[i])
        return buffer^

    @staticmethod
    fn deserialize[TargetP: Int](buffer: List[UInt8]) raises -> HyperLogLog[TargetP]:
        """Deserialize a sketch produced by `serialize`.

        Raises:
            Error: If the buffer is truncated, has trailing bytes, has an
                unknown mode byte, or was written with a different precision.
        """
        if len(buffer) < 2:
            raise Error("Invalid serialized data: buffer too short")
        var stored_precision = Int(buffer[0])
        if stored_precision != TargetP:
            raise Error("Stored precision does not match expected precision")
        var mode = buffer[1]
        if mode != 0 and mode != 1:
            raise Error("Invalid serialized data: unknown mode byte")
        var is_sparse = mode == 1

        var hll = HyperLogLog[TargetP]()
        hll.is_sparse = is_sparse
        if is_sparse:
            if len(buffer) < 6:
                raise Error("Invalid serialized data: sparse buffer too short")
            # 4-byte big-endian item count.
            var count = (
                (Int(buffer[2]) << 24)
                | (Int(buffer[3]) << 16)
                | (Int(buffer[4]) << 8)
                | Int(buffer[5])
            )
            var pos = 6
            for _ in range(count):
                if pos + 8 > len(buffer):
                    raise Error("Invalid serialized data: incomplete sparse item")
                var value: Int = 0
                for j in range(8):
                    value = (value << 8) | Int(buffer[pos + j])
                hll.sparse_set.add(value)
                pos += 8
            if pos != len(buffer):
                raise Error("Invalid serialized data: trailing bytes after sparse items")
        else:
            alias expected_size = (1 << TargetP) + 2
            if len(buffer) != expected_size:
                raise Error("Invalid serialized data: wrong buffer length")
            alias num_registers = 1 << TargetP
            hll.registers = List[UInt8]()
            for i in range(num_registers):
                hll.registers.append(buffer[i + 2])
        return hll^
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment