//! gist by @lassade, last active August 31, 2024
//! A fixed (no allocations) hash map that follows the Google Swiss Table design principles.

const std = @import("std");
const builtin = @import("builtin");

pub const StringContext = struct {
    pub inline fn hash(_: @This(), s: []const u8) u64 {
        return std.hash_map.hashString(s);
    }

    pub inline fn eql(_: @This(), a: []const u8, b: []const u8) bool {
        return std.mem.eql(u8, a, b);
    }
};
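
// A hypothetical context for integer keys, shown only to illustrate the hash/eql
// interface that FixedHashMap expects from its Context parameter; it is not part of
// the original gist, and Wyhash is just one possible choice of hash function.
pub const U64Context = struct {
    pub inline fn hash(_: @This(), k: u64) u64 {
        // hash the integer's raw bytes with a seed of 0
        return std.hash.Wyhash.hash(0, std.mem.asBytes(&k));
    }

    pub inline fn eql(_: @This(), a: u64, b: u64) bool {
        return a == b;
    }
};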

pub fn FixedHashMap(comptime K: type, comptime V: type, comptime Context: type, comptime size: usize) type {
    return struct {
        data: *Data,
        debug: Debug = .{},

        // note: the current configuration allows for only 128 different fingerprints (the 7-bit h2)
        const empty_slot = 0x80;
        const tombstone_slot = 0xfe;

        const Bucket = @Vector(16, u8);
        const bit: Bucket = @splat(0x80);

        pub const bucket_count = @divFloor(size - 1, 16) + 1;
        pub const len = bucket_count * 16;

        pub const Data = struct {
            bucket: [bucket_count]Bucket, // using @Vector forces a higher alignment
            key: [len]K,
            value: [len]V,
        };
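        // Data is laid out as a struct-of-arrays, so a lookup only touches the densely
        // packed control bytes until a fingerprint actually matches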

        const DebugInt = if (builtin.mode == .Debug) u32 else u0;

        const Debug = struct {
            /// number of inserted elements; as a rule of thumb, if `load > (len * 4 / 5)` (80% load) the map needs to grow
            load: DebugInt = 0,
            /// maximum bucket distance probed to find an element; when `distance == (bucket_count - 1)`
            /// you definitely need to rehash, unless `bucket_count` is small
            distance: DebugInt = 0,
        };
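
        // Note (sketch, not from the original gist): in Debug builds a caller could poll
        // these metrics to decide when to rebuild into a larger table, e.g. treat
        // `table.debug.load > (Map.len * 4) / 5` as the signal to allocate a bigger map
        // and reinsert every key (`Map` and `table` are hypothetical names).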

        pub fn clear(self: *@This()) void {
            @memset(std.mem.asBytes(&self.data.bucket), empty_slot);
            self.debug = .{};
        }

        pub const GetOrPutResult = struct {
            key_ptr: *K,
            value_ptr: *V,
            found_existing: bool,
        };

        pub fn getOrPut(self: *@This(), key: K) !GetOrPutResult {
            const context = Context{};
            const hash = context.hash(key);
            const h2 = fingerprint(hash);
            const h2v: Bucket = @splat(h2);
            var h1 = @mod(hash, bucket_count);
            var tombstone: usize = len;
            var result: GetOrPutResult = undefined;
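
            // probe buckets linearly starting from the home bucket h1; each bucket holds
            // 16 control bytes, and one SIMD compare against h2v finds every candidate slot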
            for (0..bucket_count) |d| {
                const hv = self.data.bucket[h1];

                // search each slot whose fingerprint matches
                var mask: u16 = @bitCast(hv == h2v);
                while (mask != 0) : (mask &= mask - 1) {
                    const i = (h1 << 4) | @ctz(mask);
                    if (context.eql(self.data.key[i], key)) {
                        result.found_existing = true;
                        result.key_ptr = &self.data.key[i];
                        result.value_ptr = &self.data.value[i];
                        self.debug.distance = @max(self.debug.distance, @as(DebugInt, @truncate(d)));
                        return result;
                    }
                }
                // process tombstone and empty slots at the same time, which is better for high loads
                mask = @bitCast((hv & bit) == bit); // note: this translates to a single SSE instruction
                while (mask != 0) {
                    const j = @ctz(mask);
                    var i = (h1 << 4) | j;
                    if (hv[j] == empty_slot) {
                        if (tombstone != len) i = tombstone;
                        @as([*]u8, @ptrCast(&self.data.bucket))[i] = h2;
                        self.data.key[i] = key;
                        result.found_existing = false;
                        result.key_ptr = &self.data.key[i];
                        result.value_ptr = &self.data.value[i];
                        self.debug.load +%= @truncate(1);
                        self.debug.distance = @max(self.debug.distance, @as(DebugInt, @truncate(d)));
                        return result;
                    } else if (tombstone == len) {
                        tombstone = i;
                    }

                    // optimization for long up-time without rehashing under high load:
                    // clear all non-empty slots from the mask, so the next iteration is
                    // guaranteed to either return a newly added position or move on to the next bucket
                    mask &= ~@as(u16, @bitCast(hv != bit));
                }

                h1 = (h1 +% 1);
                if (h1 >= bucket_count) h1 -%= bucket_count;
            }
            // edge case, worst possible scenario; if we got this far you should just increase the table capacity
            if (tombstone != len) {
                @as([*]u8, @ptrCast(&self.data.bucket))[tombstone] = h2;
                self.data.key[tombstone] = key;
                result.found_existing = false;
                result.key_ptr = &self.data.key[tombstone];
                result.value_ptr = &self.data.value[tombstone];
                self.debug.load +%= @truncate(1);
                self.debug.distance = @truncate(bucket_count -% 1); // worst case: the maximum possible distance
                return result;
            }

            return error.OutOfMemory;
        }

        pub fn remove(self: *@This(), key: K) bool {
            const context = Context{};
            const hash = context.hash(key);
            const h2 = fingerprint(hash);
            const h2v: Bucket = @splat(h2);
            var h1 = @mod(hash, bucket_count);
            for (0..bucket_count) |_| {
                const hv = self.data.bucket[h1];
                const empty: u16 = @bitCast(hv == bit);
                var mask: u16 = @bitCast(hv == h2v);
                while (mask != 0) : (mask &= mask - 1) {
                    const j = @ctz(mask);
                    const i = (h1 << 4) | j;
                    if (context.eql(self.data.key[i], key)) {
                        // note: tombstones add overhead after the table has been reused for a long time
                        @as([*]u8, @ptrCast(&self.data.bucket))[i] = if (empty != 0)
                            // if the same bucket of 16 hashes has an empty position available,
                            // the slot can be flagged as empty instead of as a tombstone
                            //
                            // this works because all positions in the bucket are checked anyway, and the
                            // remaining empty slot will still stop the search from moving on to the next bucket
                            //
                            // this little trick comes at almost no cost; the `hv == bit` compare is
                            // executed one more time than usual, but that's it
                            //
                            // note: high loads will fill up buckets, and once a bucket is filled
                            // there is currently no way to get rid of its tombstones
                            empty_slot
                        else
                            tombstone_slot;
                        // todo: maybe it's possible to define a strategy that takes the neighbouring buckets into consideration
                        // todo: possible, but might not be worth it: clear all tombstones in the previous bucket when this one becomes fully empty
                        self.debug.load -%= @truncate(1);
                        return true;
                    }
                }

                if (empty != 0) break;
                h1 = (h1 +% 1);
                if (h1 >= bucket_count) h1 -%= bucket_count;
            }

            return false;
        }
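
        // the fingerprint keeps only the top 7 bits of the hash, so the high bit of an
        // occupied slot is always clear and stays free to mark the empty (0x80) and
        // tombstone (0xfe) control bytes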
        inline fn fingerprint(hash: u64) u8 {
            return @as(u8, @truncate(hash >> (64 - 7))) & 0x7f;
        }
    };
}
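
/// A tiny bump arena that copies strings into a caller-provided buffer and
/// null-terminates them, handy for giving hash map keys a stable address.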
pub const CStringArena = struct {
    buffer: []u8 = @constCast(""),
    usage: usize = 0,

    pub fn put(self: *@This(), str: []const u8) ![:0]const u8 {
        const e = self.usage + str.len + 1;
        if (e > self.buffer.len) return error.OutOfMemory;
        var key: [:0]u8 = undefined;
        key.ptr = @ptrCast(&self.buffer[self.usage]);
        key.len = str.len;
        self.usage = e;
        @memcpy(key, str);
        key[key.len] = 0;
        return key;
    }
};
test "cstringArenaAllocation" {
var buffer: [32]u8 = undefined;
var ss = CStringArena{ .buffer = &buffer };
const str: []const u8 = "Hello, World!";
const clone = try ss.put(str);
try std.testing.expect(std.mem.eql(u8, str, clone));
try std.testing.expectEqual(str.len, std.mem.span(clone.ptr).len);
}
test "hashMapBasicOperation" {
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 4);
var data: StringHashMap.Data = undefined;
var table: StringHashMap = .{ .data = &data };
table.clear();
const n = [4]StringHashMap.GetOrPutResult{
try table.getOrPut("0"),
try table.getOrPut("1"),
try table.getOrPut("2"),
try table.getOrPut("4"),
};
for (0..n.len) |i| {
std.debug.assert(!n[i].found_existing);
for (0..n.len) |j| {
if (i != j) {
std.debug.assert(n[i].value_ptr != n[j].value_ptr);
}
}
}
const m = [4]StringHashMap.GetOrPutResult{
try table.getOrPut("0"),
try table.getOrPut("1"),
try table.getOrPut("2"),
try table.getOrPut("4"),
};
for (0..n.len) |i| {
std.debug.assert(m[i].found_existing);
std.debug.assert(m[i].value_ptr == n[i].value_ptr);
}
try std.testing.expect(table.remove("1"));
try std.testing.expect(!table.remove("1"));
const r = try table.getOrPut("1");
try std.testing.expect(!r.found_existing);
}
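
test "hashMapHeapAllocatedData" {
    // Sketch, not part of the original gist: because the map itself never allocates,
    // `Data` can live anywhere, including on the heap; the table size, key/value types
    // and the use of std.testing.allocator here are illustrative assumptions only.
    const Map = FixedHashMap([]const u8, u32, StringContext, 64);
    const data = try std.testing.allocator.create(Map.Data);
    defer std.testing.allocator.destroy(data);
    var table: Map = .{ .data = data };
    table.clear(); // buckets are undefined until cleared

    const r = try table.getOrPut("answer");
    if (!r.found_existing) r.value_ptr.* = 42;
    try std.testing.expectEqual(@as(u32, 42), (try table.getOrPut("answer")).value_ptr.*);
}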
test "hashMapFootprintSize" {
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 8);
if (builtin.mode == .Debug) {
try std.testing.expect(@sizeOf(StringHashMap) > @sizeOf(*StringHashMap.Data));
} else {
// in release mode the HashMap.Debug should be zero
try std.testing.expect(@sizeOf(StringHashMap) == @sizeOf(*StringHashMap.Data));
}
}
test "hashMapHighUsage" {
const StringHashMap = FixedHashMap([:0]const u8, usize, StringContext, 100);
const count = StringHashMap.len;
var xor = std.Random.Xoroshiro128.init(7);
const rng = xor.random();
var buffer: [count * 32]u8 = undefined;
var arena = CStringArena{ .buffer = &buffer };
// var keys = std.StringHashMap(void).init(std.testing.allocator);
// defer keys.deinit();
// generate a random input enough to fill the HashMap to 100%
var inputs: [count][:0]const u8 = undefined;
for (&inputs) |*value| {
var blob: [32]u8 = undefined;
// while (true) {
const len = rng.intRangeLessThan(usize, 8, 32);
for (0..len) |i| blob[i] = rng.intRangeAtMost(u8, 0, 255);
// const u = arena.usage;
const key = try arena.put(blob[0..len]);
// const r = try keys.getOrPut(key);
// if (r.found_existing) {
// arena.usage = u; // free last input
// continue;
// }
value.* = key;
// break;
// }
}
    // randomized order of the inserted element indices
    var rk: [count]usize = undefined;

    var data: StringHashMap.Data = undefined;
    var table: StringHashMap = .{ .data = &data };

    const fill_rates = [_]usize{ 25, 50, 80, 90, 95, 100 };
    for (&fill_rates) |fill_rate| {
        const max = (count * fill_rate) / 100;
        try std.testing.expect(max <= count);
        // std.debug.print("fill rate: {d}%\n", .{@as(f32, @floatFromInt(max)) * 100.0 / @as(f32, @floatFromInt(count))});

        table.clear();

        // insert unique elements
        for (0..max) |i| {
            const r = try table.getOrPut(inputs[i]);
            try std.testing.expect(!r.found_existing);
            rk[i] = i;
        }
        if (builtin.mode == .Debug) try std.testing.expectEqual(max, table.debug.load);

        // randomize the order of the inserted elements
        for (0..max * 4) |_| {
            const a = rng.intRangeLessThan(usize, 0, max);
            const b = rng.intRangeLessThan(usize, 0, max);
            std.mem.swap(usize, &rk[a], &rk[b]);
        }

        // randomly remove 50% of the elements
        for (0..max / 2) |i| _ = table.remove(inputs[rk[i]]);
        if (builtin.mode == .Debug) try std.testing.expectEqual(max - (max / 2), table.debug.load);

        for (0..max / 2) |i| _ = try table.getOrPut(inputs[rk[i]]); // then add them back
        if (builtin.mode == .Debug) try std.testing.expectEqual(max, table.debug.load);
    }
}
test "benchmark" {
const count = 4000;
const StringHashMap = FixedHashMap([]const u8, usize, StringContext, 8192); //count + count / 4);
var data: StringHashMap.Data = undefined;
var table: StringHashMap = .{ .data = &data };
table.clear();
// std.debug.print("size: {}\n", .{hm.data.key.len});

    var xor = std.Random.Xoroshiro128.init(7);
    const rng = xor.random();

    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    var add: [count][]u8 = undefined;
    for (&add) |*value| {
        const len = rng.intRangeAtMost(usize, 8, 32);
        const blob = try allocator.alloc(u8, len);
        for (blob) |*char| char.* = rng.intRangeAtMost(u8, 0, 255);
        value.* = blob;
    }

    var rmv: [count / 2]usize = undefined;
    for (0..rmv.len) |i| {
        while (true) {
            rmv[i] = rng.intRangeAtMost(usize, 0, add.len - 1);
            if (std.mem.indexOfScalar(usize, rmv[0..i], rmv[i]) == null) break;
        }
    }

    var timer = try std.time.Timer.start();
    var samples: [10_000]u64 = undefined;
    for (&samples) |*sample| {
        table.clear();
        timer.reset();

        for (&add, 0..) |blob, id| {
            const r = try table.getOrPut(blob);
            if (!r.found_existing) r.value_ptr.* = id;
        }
        for (&rmv) |i| {
            if (!table.remove(add[i])) return error.KeyNotFound;
        }
        for (&add) |blob| {
            _ = try table.getOrPut(blob);
        }

        sample.* = timer.read();
    }

    var max: f64 = 0.0;
    var avg: f64 = 0.0;
    for (&samples) |sample| {
        const t: f64 = @floatFromInt(sample / 1000);
        max = @max(max, t);
        avg += t;
    }
    avg /= @floatFromInt(samples.len);
    std.debug.print("max: {d:.2} us, avg: {d:.2} us\n", .{ max, avg });
}
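
// note (assumption, not from the original gist): `zig test` builds in Debug mode by
// default, so the timings printed by the benchmark test are only meaningful with the
// optimizer enabled, e.g. `zig test -O ReleaseFast <file>.zig`.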