A fixed-capacity (no allocations) hash map that follows the Google Swiss Table design principles.
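For orientation, a minimal usage sketch (mirroring the tests at the bottom of the file; the capacity of 64 and the key "answer" are arbitrary examples): the caller owns the fixed-size `Data` backing storage, so the map itself never allocates.

test "usage sketch" {
    const Map = FixedHashMap([:0]const u8, usize, StringContext, 64); // capacity is rounded up to whole buckets of 16
    var storage: Map.Data = undefined; // caller-owned backing storage, on the stack here
    var map: Map = .{ .data = &storage };
    map.clear(); // marks every slot as empty; required before first use
    const entry = try map.getOrPut("answer"); // fails with error.OutOfMemory only when no free slot is left
    if (!entry.found_existing) entry.value_ptr.* = 42;
    try std.testing.expect(map.remove("answer"));
}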
const std = @import("std");
const builtin = @import("builtin");
pub const StringContext = struct {
    pub inline fn hash(_: @This(), s: []const u8) u64 {
        return std.hash_map.hashString(s);
    }
    pub inline fn eql(_: @This(), a: []const u8, b: []const u8) bool {
        return std.mem.eql(u8, a, b);
    }
};
pub fn FixedHashMap(comptime K: type, comptime V: type, comptime Context: type, comptime size: usize) type {
    return struct {
        data: *Data,
        debug: Debug = .{},
        // note: the current configuration allows for only 128 distinct hash fingerprints
        const empty_slot = 0x80;
        const tombstone_slot = 0xfe;
        const Bucket = @Vector(16, u8);
        const bit: Bucket = @splat(0x80);
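        // each slot has one control byte: with the high bit clear it holds the 7-bit
        // fingerprint of an occupied slot, with the high bit set it marks a special slot
        // (0x80 = empty, 0xfe = tombstone); `bit` is the mask used to test that high bit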
        pub const bucket_count = @divFloor(size - 1, 16) + 1;
        pub const len = bucket_count * 16;
        pub const Data = struct {
            bucket: [bucket_count]Bucket, // using @Vector will force a higher alignment
            key: [len]K,
            value: [len]V,
        };
        const DebugInt = if (builtin.mode == .Debug) u32 else u0;
        const Debug = struct {
            /// number of inserted elements; as a rule of thumb, when `load > (len * 4 / 5)` (80% load) the table needs to grow
            load: DebugInt = 0,
            /// max bucket distance probed to find an element; when `distance == (bucket_count - 1)`
            /// you definitely need to rehash, unless `bucket_count` is small
            distance: DebugInt = 0,
        };
        pub fn clear(self: *@This()) void {
            @memset(std.mem.asBytes(&self.data.bucket), empty_slot);
            self.debug = .{};
        }
        pub const GetOrPutResult = struct {
            key_ptr: *K,
            value_ptr: *V,
            found_existing: bool,
        };
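        /// Finds the slot for `key`, inserting it when missing. Probing is linear over
        /// buckets of 16 slots: the 7-bit fingerprint is matched against all 16 control
        /// bytes of a bucket at once, then empty/tombstone slots are scanned for a place
        /// to insert; returns `error.OutOfMemory` when no free slot can be found.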
        pub fn getOrPut(self: *@This(), key: K) !GetOrPutResult {
            const context = Context{};
            const hash = context.hash(key);
            const h2 = fingerprint(hash);
            const h2v: Bucket = @splat(h2);
            var h1 = @mod(hash, bucket_count);
            var tombstone: usize = len;
            var result: GetOrPutResult = undefined;
            for (0..bucket_count) |d| {
                const hv = self.data.bucket[h1];
                // search every slot whose fingerprint matches
                var mask: u16 = @bitCast(hv == h2v);
                while (mask != 0) : (mask &= mask - 1) {
                    const i = (h1 << 4) | @ctz(mask);
                    if (context.eql(self.data.key[i], key)) {
                        result.found_existing = true;
                        result.key_ptr = &self.data.key[i];
                        result.value_ptr = &self.data.value[i];
                        self.debug.distance = @max(self.debug.distance, @as(DebugInt, @truncate(d)));
                        return result;
                    }
                }
                // process tombstone and empty slots at the same time, better for high loads
                mask = @bitCast((hv & bit) == bit); // note: this translates to a single SSE instruction
                while (mask != 0) {
                    const j = @ctz(mask);
                    var i = (h1 << 4) | j;
                    if (hv[j] == empty_slot) {
                        if (tombstone != len) i = tombstone;
                        @as([*]u8, @ptrCast(&self.data.bucket))[i] = h2;
                        self.data.key[i] = key;
                        result.found_existing = false;
                        result.key_ptr = &self.data.key[i];
                        result.value_ptr = &self.data.value[i];
                        self.debug.load +%= @truncate(1);
                        self.debug.distance = @max(self.debug.distance, @as(DebugInt, @truncate(d)));
                        return result;
                    } else if (tombstone == len) {
                        tombstone = i;
                    }
                    // long uptime without rehashing, optimized for high loads:
                    // clear all non-empty slots from the bitfield; the next iteration is then
                    // guaranteed to either return a newly added position or move on to the next bucket
                    mask &= ~@as(u16, @bitCast(hv != bit));
                }
                h1 = (h1 +% 1);
                if (h1 >= bucket_count) h1 -%= bucket_count;
            }
            // edge case, worst possible scenario; if you got this far you should just increase the table capacity
            if (tombstone != len) {
                @as([*]u8, @ptrCast(&self.data.bucket))[tombstone] = h2;
                self.data.key[tombstone] = key;
                result.found_existing = false;
                result.key_ptr = &self.data.key[tombstone];
                result.value_ptr = &self.data.value[tombstone];
                self.debug.load +%= @truncate(1);
                self.debug.distance = @truncate(bucket_count -% 1); // worst case, maximum possible distance
                return result;
            }
            return error.OutOfMemory;
        }
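        /// Removes `key` and reports whether it was found. The freed slot is marked as a
        /// tombstone, or directly as empty when its bucket already contains an empty slot
        /// (see the note inside).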
        pub fn remove(self: *@This(), key: K) bool {
            const context = Context{};
            const hash = context.hash(key);
            const h2 = fingerprint(hash);
            const h2v: Bucket = @splat(h2);
            var h1 = @mod(hash, bucket_count);
            for (0..bucket_count) |_| {
                const hv = self.data.bucket[h1];
                const empty: u16 = @bitCast(hv == bit);
                var mask: u16 = @bitCast(hv == h2v);
                while (mask != 0) : (mask &= mask - 1) {
                    const j = @ctz(mask);
                    const i = (h1 << 4) | j;
                    if (context.eql(self.data.key[i], key)) {
                        // note: tombstones add overhead after reusing the table for a long time
                        @as([*]u8, @ptrCast(&self.data.bucket))[i] = if (empty != 0)
                            // if the same bucket of 16 slots has an empty position available,
                            // this slot can be flagged as empty instead of a tombstone
                            //
                            // this works because all positions in a bucket are checked, and the existing
                            // empty slot already stops the search from moving on to the next bucket
                            //
                            // this little trick should come at almost no cost: `hv == bit` is computed one
                            // extra time compared to the usual lookup, but that's it
                            //
                            // note: high loads will fill up buckets, and once a bucket is filled
                            // there's no way to get rid of its tombstones right now
                            empty_slot
                        else
                            tombstone_slot;
                        // todo: maybe a strategy that takes the neighbouring buckets into consideration is possible
                        // todo: possible but maybe not worth it: clear all tombstones in the previous bucket when this one becomes fully empty
                        self.debug.load -%= @truncate(1);
                        return true;
                    }
                }
                if (empty != 0) break;
                h1 = (h1 +% 1);
                if (h1 >= bucket_count) h1 -%= bucket_count;
            }
            return false;
        }
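        // the fingerprint is taken from the top 7 bits of the hash; keeping the high bit
        // clear means it can never collide with the empty or tombstone control bytes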
        inline fn fingerprint(hash: u64) u8 {
            return @as(u8, @truncate(hash >> (64 - 7))) & 0x7f;
        }
    };
}
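/// A tiny bump allocator that copies byte slices into a caller-provided buffer and
/// NUL-terminates them; the tests use it to build stable, sentinel-terminated keys.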
pub const CStringArena = struct {
    buffer: []u8 = @constCast(""),
    usage: usize = 0,
    pub fn put(self: *@This(), str: []const u8) ![:0]const u8 {
        const e = self.usage + str.len + 1;
        if (e > self.buffer.len) return error.OutOfMemory;
        var key: [:0]u8 = undefined;
        key.ptr = @ptrCast(&self.buffer[self.usage]);
        key.len = str.len;
        self.usage = e;
        @memcpy(key, str);
        key[key.len] = 0;
        return key;
    }
};
test "cstringArenaAllocation" {
    var buffer: [32]u8 = undefined;
    var ss = CStringArena{ .buffer = &buffer };
    const str: []const u8 = "Hello, World!";
    const clone = try ss.put(str);
    try std.testing.expect(std.mem.eql(u8, str, clone));
    try std.testing.expectEqual(str.len, std.mem.span(clone.ptr).len);
}
test "hashMapBasicOperation" { | |
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 4); | |
var data: StringHashMap.Data = undefined; | |
var table: StringHashMap = .{ .data = &data }; | |
table.clear(); | |
const n = [4]StringHashMap.GetOrPutResult{ | |
try table.getOrPut("0"), | |
try table.getOrPut("1"), | |
try table.getOrPut("2"), | |
try table.getOrPut("4"), | |
}; | |
for (0..n.len) |i| { | |
std.debug.assert(!n[i].found_existing); | |
for (0..n.len) |j| { | |
if (i != j) { | |
std.debug.assert(n[i].value_ptr != n[j].value_ptr); | |
} | |
} | |
} | |
const m = [4]StringHashMap.GetOrPutResult{ | |
try table.getOrPut("0"), | |
try table.getOrPut("1"), | |
try table.getOrPut("2"), | |
try table.getOrPut("4"), | |
}; | |
for (0..n.len) |i| { | |
std.debug.assert(m[i].found_existing); | |
std.debug.assert(m[i].value_ptr == n[i].value_ptr); | |
} | |
try std.testing.expect(table.remove("1")); | |
try std.testing.expect(!table.remove("1")); | |
const r = try table.getOrPut("1"); | |
try std.testing.expect(!r.found_existing); | |
} | |
test "hashMapFootprintSize" { | |
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 8); | |
if (builtin.mode == .Debug) { | |
try std.testing.expect(@sizeOf(StringHashMap) > @sizeOf(*StringHashMap.Data)); | |
} else { | |
// in release mode the HashMap.Debug should be zero | |
try std.testing.expect(@sizeOf(StringHashMap) == @sizeOf(*StringHashMap.Data)); | |
} | |
} | |
test "hashMapHighUsage" { | |
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 100); | |
const count = StringHashMap.len; | |
var xor = std.Random.Xoroshiro128.init(7); | |
const rng = xor.random(); | |
var buffer: [count * 32]u8 = undefined; | |
var arena = CStringArena{ .buffer = &buffer }; | |
// var keys = std.StringHashMap(void).init(std.testing.allocator); | |
// defer keys.deinit(); | |
// generate a random input enough to fill the HashMap to 100% | |
var inputs: [count][:0]const u8 = undefined; | |
for (&inputs) |*value| { | |
var blob: [32]u8 = undefined; | |
// while (true) { | |
const len = rng.intRangeLessThan(usize, 8, 32); | |
for (0..len) |i| blob[i] = rng.intRangeAtMost(u8, 0, 255); | |
// const u = arena.usage; | |
const key = try arena.put(blob[0..len]); | |
// const r = try keys.getOrPut(key); | |
// if (r.found_existing) { | |
// arena.usage = u; // free last input | |
// continue; | |
// } | |
value.* = key; | |
// break; | |
// } | |
} | |
// random stream of unique | |
var rk: [count]usize = undefined; | |
var data: StringHashMap.Data = undefined; | |
var table: StringHashMap = .{ .data = &data }; | |
const fill_rates = [_]usize{ 25, 50, 80, 90, 95, 100 }; | |
for (&fill_rates) |fill_rate| { | |
const max = (count * fill_rate) / 100; | |
try std.testing.expect(max <= count); | |
// std.debug.print("fill rate: {d}%\n", .{@as(f32, @floatFromInt(max)) * 100.0 / @as(f32, @floatFromInt(count))}); | |
table.clear(); | |
// insert unique elements | |
for (0..max) |i| { | |
const r = try table.getOrPut(inputs[i]); | |
try std.testing.expect(!r.found_existing); | |
rk[i] = i; | |
} | |
if (builtin.mode == .Debug) try std.testing.expectEqual(max, table.debug.load); | |
// radomize inserted elements order | |
for (0..max * 4) |_| { | |
const a = rng.intRangeLessThan(usize, 0, max); | |
const b = rng.intRangeLessThan(usize, 0, max); | |
std.mem.swap(usize, &rk[a], &rk[b]); | |
} | |
// randomly remove 50% of the elements | |
for (0..max / 2) |i| _ = table.remove(inputs[rk[i]]); | |
if (builtin.mode == .Debug) try std.testing.expectEqual(max - (max / 2), table.debug.load); | |
for (0..max / 2) |i| _ = try table.getOrPut(inputs[rk[i]]); // then add them back | |
if (builtin.mode == .Debug) try std.testing.expectEqual(max, table.debug.load); | |
} | |
} | |
test "benchmark" { | |
const count = 4000; | |
const StringHashMap = FixedHashMap([]const u8, usize, StringContext, 8192); //count + count / 4); | |
var data: StringHashMap.Data = undefined; | |
var table: StringHashMap = .{ .data = &data }; | |
table.clear(); | |
// std.debug.print("size: {}\n", .{hm.data.key.len}); | |
var xor = std.Random.Xoroshiro128.init(7); | |
const rng = xor.random(); | |
var arena = std.heap.ArenaAllocator.init(std.testing.allocator); | |
defer arena.deinit(); | |
const allocator = arena.allocator(); | |
var add: [count][]u8 = undefined; | |
for (&add) |*value| { | |
const len = rng.intRangeAtMost(usize, 8, 32); | |
const blob = try allocator.alloc(u8, len); | |
for (blob) |*char| char.* = rng.intRangeAtMost(u8, 0, 255); | |
value.* = blob; | |
} | |
var rmv: [count / 2]usize = undefined; | |
for (0..rmv.len) |i| { | |
while (true) { | |
rmv[i] = rng.intRangeAtMost(usize, 0, add.len - 1); | |
if (std.mem.indexOfScalar(usize, rmv[0..i], rmv[i]) == null) break; | |
} | |
} | |
var timer = try std.time.Timer.start(); | |
var samples: [10_000]u64 = undefined; | |
for (&samples) |*sample| { | |
table.clear(); | |
timer.reset(); | |
for (&add, 0..) |blob, id| { | |
const r = try table.getOrPut(blob); | |
if (!r.found_existing) r.value_ptr.* = id; | |
} | |
for (&rmv) |i| { | |
if (!table.remove(add[i])) return error.KeyNotFound; | |
} | |
for (&add) |blob| { | |
_ = try table.getOrPut(blob); | |
} | |
sample.* = timer.read(); | |
} | |
var max: f64 = 0.0; | |
var avg: f64 = 0.0; | |
for (&samples) |sample| { | |
const t: f64 = @floatFromInt(sample / 1000); | |
max = @max(max, t); | |
avg += t; | |
} | |
avg /= @floatFromInt(samples.len); | |
std.debug.print("max: {d:.2} us, avg: {d:.2} us\n", .{ max, avg }); | |
} |