// This is a hash map implementation for maps with string keys only, designed to maximize performance. Similar in some ways to LLVM's StringMap
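//
// Layout overview: slots are grouped into groups of 8. Each group stores one
// byte-sized "summary" per slot (derived from the hash; 0 = empty, 1 = tombstone),
// packed together for SIMD-friendly scanning, plus a pointer per slot to its
// bucket. Buckets live in one contiguous arena, with each key's bytes stored
// inline immediately after its bucket header, so comparing a candidate key reads
// memory adjacent to the hash code and value it belongs to.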
#include <algorithm> // std::sort
#include <bit> // std::bit_ceil (this file requires C++20)
#include <cstdint> // uint8_t, uint64_t
#include <cstdlib> // std::aligned_alloc, free, rand, srand, abort
#include <cstring> // memcpy
#include <ctime> // clock_gettime (POSIX)
#include <iostream>
#include <memory>
#include <stdexcept> // std::runtime_error, std::invalid_argument
#include <string>
#include <tuple> // std::tie
#include <unordered_map>
#include <utility> // std::pair, std::forward, std::move
#include <absl/container/flat_hash_map.h>
#ifndef HASH_H
#define HASH_H
#include <string_view>
#include <vector>
using Value = size_t;
// template <class Value, class Hasher = std::hash<std::string_view>>
// template <class Hasher = std::hash<std::string_view>>
class Hash {
public:
using HashCode = uint64_t;
using Summary = uint8_t;
static const size_t kGroupSize = 8;
static const Summary kTombstoneSummary = 1;
static const Summary kEmptySummary = 0; // 0-initialization leaves everything empty, as it should
template <typename T>
static T *ptr_add(T *ptr, size_t offset) {
return reinterpret_cast<T *>(reinterpret_cast<uint8_t *>(ptr) + offset);
}
static const void *ptr_add_void(const void *ptr, size_t offset) {
return static_cast<const void *>(static_cast<const uint8_t *>(ptr) + offset);
}
static void *ptr_add_void(void *ptr, size_t offset) {
return static_cast<void *>(static_cast<uint8_t *>(ptr) + offset);
}
class Iterator {
public:
size_t index_;
Hash &hash_;
Iterator(size_t index, Hash &hash) : index_(index), hash_(hash) {
skip_to_next_filled_or_end();
}
void skip_to_next_filled_or_end() {
size_t end_index = hash_.capacity();
while (index_ != end_index) {
auto &group = hash_.metadata_groups_[index_ / kGroupSize];
auto offset = index_ % kGroupSize;
if (group.summaries[offset] != kEmptySummary && group.summaries[offset] != kTombstoneSummary) {
return;
}
index_++;
}
}
Iterator &operator++() {
index_++;
skip_to_next_filled_or_end();
return *this;
}
Iterator operator++(int) {
Iterator temp = *this;
++(*this);
return temp;
}
bool operator==(const Iterator& other) const {
return &(hash_) == &(other.hash_) && index_ == other.index_;
}
std::pair<const std::string_view, Value &> operator *() {
// Note: key() returns its string_view by value, so the pair must hold a copy of it;
// a reference member here would dangle
auto &group = hash_.metadata_groups_[index_ / kGroupSize];
auto &bucket = *group.bucket_ptrs[index_ % kGroupSize];
return {bucket.key(), bucket.value};
}
};
Iterator begin() {
return {0, *this};
}
Iterator end() {
return {capacity(), *this};
}
struct KeyValueBucket {
HashCode hash_code;
size_t key_length;
Value value;
template <class... Args>
KeyValueBucket(HashCode hash_code, size_t key_length, Args&&... value_args) : hash_code(hash_code), key_length(key_length), value(std::forward<Args>(value_args)...) {}
// We could also do this by adding a "char str[0]" member at the end of the class, but that
// seems tricky with respect to alignment (what if the struct's size exceeds the location of str?)
std::string_view key() const {
const char *ptr = static_cast<const char *>(ptr_add_void(this, sizeof(*this)));
return std::string_view(ptr, key_length);
}
};
// TODO: Try to align this to the cache line size
template <size_t N>
struct VariableMetadataGroup {
// The summaries are packed together to facilitate SIMD operations
Summary summaries[N] = {};
KeyValueBucket *bucket_ptrs[N] = {};
VariableMetadataGroup() {
}
};
// static_assert(std::is_pod<VariableMetadataGroup<1>>::value); // For speed
using MetadataGroup = VariableMetadataGroup<kGroupSize>;
Hash(size_t capacity) : size_(0) {
metadata_groups_.resize(capacity / kGroupSize);
}
Hash() : Hash(0) {
}
std::vector<MetadataGroup> metadata_groups_;
// We can't, for example, use a deque here, because we need the whole thing to be contiguous
// for the sake of the strings
KeyValueBucket *storage_ = nullptr;
size_t storage_capacity_ = 0;
size_t storage_size_ = 0;
size_t size_;
// Power of 2:
/*size_t align_to(size_t number, size_t alignment) {
if (alignment == 0) {
throw std::invalid_argument("Alignment must be greater than 0");
}
return (number + alignment - 1) & ~(alignment - 1);
}*/
size_t align_to(size_t number, size_t alignment) {
if (alignment == 0) {
throw std::invalid_argument("Alignment must be greater than 0");
}
return ((number + alignment - 1) / alignment) * alignment;
}
size_t allocation_size(size_t key_length) {
size_t data_size = sizeof(KeyValueBucket) + key_length;
return align_to(data_size, alignof(KeyValueBucket));
}
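// Worked example: with HashCode, key_length, and Value all 8 bytes, as here,
// sizeof(KeyValueBucket) is 24, so a 25-byte key needs 24 + 25 = 49 bytes,
// rounded up to 56 so the next bucket stays aligned to alignof(KeyValueBucket).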
void ensure_storage_capacity(size_t ensured_capacity) {
if (ensured_capacity <= storage_capacity_) {
return;
}
size_t new_capacity = align_to(std::bit_ceil(ensured_capacity), alignof(KeyValueBucket)); // Next power of 2
KeyValueBucket *new_storage = static_cast<KeyValueBucket *>(std::aligned_alloc(alignof(KeyValueBucket), new_capacity));
KeyValueBucket *curr_ptr = new_storage;
for (auto iter = begin(); iter != end(); iter++) {
auto &group = metadata_groups_[iter.index_ / kGroupSize];
size_t group_offset = iter.index_ % kGroupSize;
auto pair = *iter;
KeyValueBucket *old_bucket = group.bucket_ptrs[group_offset];
group.bucket_ptrs[group_offset] = curr_ptr;
new (curr_ptr) KeyValueBucket(old_bucket->hash_code, old_bucket->key_length, std::move(pair.second));
KeyValueBucket &new_bucket = *curr_ptr;
std::string_view view = old_bucket->key();
memcpy(const_cast<char *>(new_bucket.key().data()), view.data(), view.size());
old_bucket->~KeyValueBucket();
curr_ptr = ptr_add(curr_ptr, allocation_size(pair.first.size()));
}
// The old buckets were already destroyed in the loop above; now release their storage
free(storage_);
storage_ = new_storage;
storage_capacity_ = new_capacity;
// storage_size_ shouldn't change, even if things got reordered
}
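// Growing the arena relocates every bucket and rewrites all bucket_ptrs, so no
// bucket pointer obtained before a call to ensure_storage_capacity may be used
// after it (try_emplace below is careful about this).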
template <class... Args>
std::pair<Value &, bool> try_emplace(const std::string_view &key, Args&&... args){
// std::cout << "try emplace" << std::endl;
if (size_ == 0) {
rehash(16); // Must be a multiple of kGroupSize. Also, it has to be decently bigger than 8
// because of the way wraparound logic works. TODO: support low capacities
ensure_storage_capacity(1024);
} else if (size_ >= (size_t)(capacity() * 0.75)) {
rehash(capacity() * 2);
}
HashCode hash_code = std::hash<std::string_view>()(key);
auto search_result = next_bucket_or_empty(hash_code, key);
MetadataGroup &group = metadata_groups_[search_result.group_index];
if (search_result.is_empty) {
// std::cout << "empty" << std::endl;
// Make sure not to get a pointer to the bucket before calling resize, because resize
// could invalidate the pointer
size_++;
size_t bucket_size = allocation_size(key.size());
ensure_storage_capacity(storage_size_ + bucket_size);
KeyValueBucket *bucket_ptr = ptr_add(storage_, storage_size_);
new (bucket_ptr) KeyValueBucket(hash_code, key.size(), std::forward<Args>(args)...);
group.bucket_ptrs[search_result.group_offset] = bucket_ptr;
group.summaries[search_result.group_offset] = summary_from_hash_code(hash_code);
storage_size_ += bucket_size;
memcpy(const_cast<char *>(bucket_ptr->key().data()), key.data(), key.length());
// The value itself was already constructed from the forwarded args in the placement new above
return {bucket_ptr->value, true};
} else {
// Only dereference the bucket pointer once we know the slot is occupied; doing it
// earlier would dereference a null pointer on the empty path
auto &bucket = *group.bucket_ptrs[search_result.group_offset];
return {bucket.value, false};
}
}
/*inline size_t alignment_addition(const void *ptr, size_t size, size_t alignment) {
uintptr_t pointer = reinterpret_cast<uintptr_t>(ptr_add(ptr, size));
size_t remainder = pointer % alignment;
if (remainder == 0) {
return 0;
} else {
return alignment - remainder;
}
}*/
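// The summary is bits 0-6 of the hash shifted into bits 1-7 of a byte, so it is
// always even and can never collide with kTombstoneSummary (1). A raw summary of
// 0 would look like an empty slot, so it gets bumped to 2. This lets lookups
// reject most non-matching slots with a one-byte compare, without touching the
// bucket itself.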
inline Summary summary_from_hash_code(HashCode hash_code) const {
Summary summary = (Summary)(hash_code << 1);
if (summary <= kTombstoneSummary) {
summary = kTombstoneSummary + 1;
}
return summary;
}
struct NextBucketOrEmptyResult {
size_t group_index;
size_t group_offset;
bool is_empty;
NextBucketOrEmptyResult(bool is_empty, size_t group_index, size_t group_offset) : group_index(group_index), group_offset(group_offset), is_empty(is_empty) {
}
};
inline size_t capacity() const {
return metadata_groups_.size() * kGroupSize;
}
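// Probing scheme: start at the slot the hash maps to, scan forward to the end of
// that group, then continue through subsequent groups from offset 0, wrapping at
// the end of the table. Tombstoned slots are skipped; the search ends at either a
// slot whose summary and key both match, or the first empty slot.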
NextBucketOrEmptyResult next_bucket_or_empty(HashCode hash_code, const std::string_view &key) const {
size_t bucket = hash_code & (capacity() - 1); // Requires capacity to be a power of two
static_assert(kGroupSize == 8); // The group index/offset math below assumes this size
size_t group_index = bucket / kGroupSize;
size_t group_offset = bucket % kGroupSize;
size_t starting_group_index = group_index;
Summary target_summary = summary_from_hash_code(hash_code);
while (true) {
const MetadataGroup &group = metadata_groups_[group_index];
for (size_t i = group_offset; i < kGroupSize; i++) {
if (group.summaries[i] == kEmptySummary) {
return {true, group_index, i};
} else if (group.summaries[i] == target_summary) {
const KeyValueBucket &bucket = *group.bucket_ptrs[i];
if (hash_code == bucket.hash_code && bucket.key() == key) {
return {false, group_index, i};
}
}
}
group_offset = 0;
group_index++;
if (group_index == metadata_groups_.size()) {
group_index = 0;
}
if (group_index == starting_group_index) {
// We wrapped all the way around: every group is full, or at least every group
// besides the starting one. The table should have been rehashed before this point
throw std::runtime_error("hash map is full");
}
}
}
// We could use try_emplace here, but that's wasteful since we don't want to rehash. We don't
// want to modify the storage, we just want to move around pointers
void rehash(size_t new_capacity) {
auto old_metadata_groups = metadata_groups_;
metadata_groups_.clear();
metadata_groups_.resize(new_capacity / kGroupSize);
// std::cout << "Rehashing to " << capacity() << "\n";
for (size_t group_index = 0; group_index < old_metadata_groups.size(); group_index++) {
auto &old_group = old_metadata_groups[group_index];
for (size_t group_offset = 0; group_offset < kGroupSize; group_offset++) {
Summary summary = old_group.summaries[group_offset];
if (summary != kEmptySummary && summary != kTombstoneSummary) {
KeyValueBucket &bucket = *old_group.bucket_ptrs[group_offset];
size_t new_index = bucket.hash_code % capacity();
size_t new_group_index = new_index / kGroupSize;
size_t new_group_offset = new_index % kGroupSize;
while (true) {
auto &new_group = metadata_groups_[new_group_index];
for (; new_group_offset < kGroupSize; new_group_offset++) {
if (new_group.summaries[new_group_offset] == kEmptySummary) {
new_group.summaries[new_group_offset] = summary_from_hash_code(bucket.hash_code);
new_group.bucket_ptrs[new_group_offset] = &bucket;
goto did_insert;
}
}
new_group_offset = 0;
new_group_index++;
if (new_group_index == capacity() / kGroupSize) {
new_group_index = 0;
}
}
did_insert:
;
}
}
}
/*for (const auto &group : metadata_groups_) {
for (size_t i = 0; i < kGroupSize; i++) {
std::cout << "Summary: " << (size_t)group.summaries[i] << ", " << group.bucket_ptrs[i] << std::endl;
}
}*/
}
~Hash() {
for (auto pair : *this) {
pair.second.~Value();
}
free(storage_);
}
};
// We don't actually destruct the whole thing, just the value
// static_assert(std::is_trivially_destructible<Hash<uint8_t>::KeyValueBucket>::value, "");
#endif // HASH_H
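// Benchmark and correctness harness: times insertions against absl::flat_hash_map
// and verifies the final contents against std::unordered_map.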
std::string random_string(int length) {
std::string string;
for (int i = 0; i < length; i++) {
string.push_back('a' + (rand() % 26));
}
return string;
}
void hard_assert(bool condition) {
if (!condition) {
abort();
}
}
double current_time() {
struct timespec ts = {};
clock_gettime(CLOCK_MONOTONIC, &ts);
return ts.tv_sec + ((double)ts.tv_nsec) / 1e9;
}
template <typename F>
void time_code(std::string name, F f) {
f(); // Warmup
double start_time = current_time();
double end_time = 0;
int runs_done = 0;
while (true) {
f();
runs_done++;
end_time = current_time();
if (end_time - start_time > 1.0) {
break;
}
}
double rate = runs_done / (end_time - start_time);
std::cout << name << " ran at " << rate << " runs/sec" << std::endl;
}
int main() {
std::vector<std::pair<std::string, int>> possibilities;
srand(0);
for (int i = 0; i < 1e6; i++) {
possibilities.emplace_back(random_string(25), rand());
}
std::vector<std::pair<std::string, int>> insertions;
for (size_t i = 0; i < possibilities.size() * 10; i++) {
insertions.emplace_back(possibilities[rand() % possibilities.size()]);
}
// Note: h persists across timing runs, so every run after the first re-inserts the
// same keys and mostly exercises the found-key path, while the absl map below is
// rebuilt from scratch on each run
Hash h;
time_code("mine", [&]() {
for (const auto &[key, value] : insertions) {
h.try_emplace(std::string_view(key), value);
}
});
time_code("absl", [&]() {
absl::flat_hash_map<std::string, int> absl_map;
for (const auto &[key, value] : insertions) {
absl_map.try_emplace(key, value);
}
hard_assert(absl_map.size() != 23943242); // Prevent optimizing stuff out
});
std::vector<std::pair<std::string, int>> results;
for (auto iter = h.begin(); iter != h.end(); iter++) {
const std::pair<const std::string_view, size_t &> result = *iter;
// std::cout << result.first << std::endl;
results.emplace_back(std::string(result.first), result.second);
}
std::sort(results.begin(), results.end(), [](auto &a, auto &b) {
// Break ties on the key so the order is deterministic when two values collide
return std::tie(a.second, a.first) < std::tie(b.second, b.first);
});
std::unordered_map<std::string, int> expected_map;
for (const auto &[key, value] : insertions) {
expected_map.try_emplace(key, value);
}
std::vector<std::pair<std::string, int>> expected_vector;
for (const auto &[key, value] : expected_map) {
expected_vector.emplace_back(key, value);
}
std::sort(expected_vector.begin(), expected_vector.end(), [](auto &a, auto &b) {
// Same deterministic tie-breaking as above, so the two vectors can be compared elementwise
return std::tie(a.second, a.first) < std::tie(b.second, b.first);
});
hard_assert(results.size() == expected_vector.size());
for (size_t i = 0; i < results.size(); i++) {
hard_assert(results[i] == expected_vector[i]);
}
return 0;
}