// A hash map implementation for maps with string keys only, designed to maximize
// performance. Similar in some ways to LLVM's StringMap.
#include <algorithm>
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <absl/container/flat_hash_map.h>
#ifndef HASH_H
#define HASH_H
#include <bit>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <new>
#include <stdexcept>
#include <string_view>
#include <utility>
#include <vector>
using Value = size_t;
// template <class Value, class Hasher = std::hash<std::string_view>>
// template <class Hasher = std::hash<std::string_view>>
class Hash {
public:
  using HashCode = uint64_t;
  using Summary = uint8_t;
  static const size_t kGroupSize = 8;
  static const Summary kTombstoneSummary = 1;
  static const Summary kEmptySummary = 0; // 0-initialization leaves everything empty, as it should
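  // Sanity check (added as a sketch): summary_from_hash_code() below relies on 0 and 1
  // being reserved for "empty" and "tombstone", so any live summary must be >= 2.
  static_assert(kEmptySummary == 0 && kTombstoneSummary == 1,
                "summaries 0 and 1 are reserved for empty/tombstone slots");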
  template <typename T>
  static T *ptr_add(T *ptr, size_t offset) {
    return reinterpret_cast<T *>(reinterpret_cast<uint8_t *>(ptr) + offset);
  }
  static const void *ptr_add_void(const void *ptr, size_t offset) {
    return static_cast<const void *>(static_cast<const uint8_t *>(ptr) + offset);
  }
  static void *ptr_add_void(void *ptr, size_t offset) {
    return static_cast<void *>(static_cast<uint8_t *>(ptr) + offset);
  }
  class Iterator {
  public:
    size_t index_;
    Hash &hash_;
    Iterator(size_t index, Hash &hash) : index_(index), hash_(hash) {
      skip_to_next_filled_or_end();
    }
    void skip_to_next_filled_or_end() {
      size_t end_index = hash_.capacity();
      while (index_ != end_index) {
        auto &group = hash_.metadata_groups_[index_ / kGroupSize];
        auto offset = index_ % kGroupSize;
        if (group.summaries[offset] != kEmptySummary && group.summaries[offset] != kTombstoneSummary) {
          return;
        }
        index_++;
      }
    }
    Iterator &operator++() {
      index_++;
      skip_to_next_filled_or_end();
      return *this;
    }
    Iterator operator++(int) {
      Iterator temp = *this;
      ++(*this);
      return temp;
    }
    bool operator==(const Iterator &other) const {
      return &(hash_) == &(other.hash_) && index_ == other.index_;
    }
    // Note: returning a reference to the string_view caused compile errors, so the pair
    // holds the view by value and only the Value by reference
    std::pair<const std::string_view, Value &> operator*() {
      auto &group = hash_.metadata_groups_[index_ / kGroupSize];
      auto &bucket = *group.bucket_ptrs[index_ % kGroupSize];
      return {bucket.key(), bucket.value};
    }
  };
  Iterator begin() {
    return {0, *this};
  }
  Iterator end() {
    return {capacity(), *this};
  }
  struct KeyValueBucket {
    HashCode hash_code;
    size_t key_length;
    Value value;
    template <class... Args>
    KeyValueBucket(HashCode hash_code, size_t key_length, Args&&... value_args)
        : hash_code(hash_code), key_length(key_length), value(std::forward<Args>(value_args)...) {}
    // We could also do this by adding a "char str[0]" member at the end of the class, but that
    // seems tricky with respect to alignment (what if the struct's size exceeds the location of str?)
    std::string_view key() const {
      const char *ptr = static_cast<const char *>(ptr_add_void(this, sizeof(*this)));
      return std::string_view(ptr, key_length);
    }
  };
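  // Layout sketch (illustrative, assuming 64-bit size_t): for a 3-byte key "abc",
  // one allocation in storage_ looks like
  //   [ hash_code (8) | key_length (8) | value (8) ][ 'a' 'b' 'c' ][ pad to alignof ]
  // key() reconstructs the string_view from (this + sizeof(*this)) on every call,
  // so a bucket stays self-describing when storage_ is compacted.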
  // TODO: Try to align this to the cache line size
  template <size_t N>
  struct VariableMetadataGroup {
    // The summaries are packed together to facilitate SIMD operations
    Summary summaries[N] = {};
    KeyValueBucket *bucket_ptrs[N] = {};
    VariableMetadataGroup() = default;
  };
  // static_assert(std::is_pod<VariableMetadataGroup<1>>::value); // For speed
  using MetadataGroup = VariableMetadataGroup<kGroupSize>;
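  // Size sketch (assuming 64-bit pointers): one MetadataGroup is 8 summary bytes plus
  // 8 bucket pointers = 72 bytes, so the packed summaries of a group fit in one load,
  // but the group slightly overshoots a 64-byte cache line, hence the TODO above.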
  Hash(size_t capacity) : size_(0) {
    metadata_groups_.resize(capacity / kGroupSize);
  }
  Hash() : Hash(0) {
  }
  std::vector<MetadataGroup> metadata_groups_;
  // We can't, for example, use a deque here, because we need the whole thing to be contiguous
  // for the sake of the strings
  KeyValueBucket *storage_ = nullptr;
  size_t storage_capacity_ = 0;
  size_t storage_size_ = 0;
  size_t size_;
  // Faster variant, valid only for power-of-2 alignments:
  /*size_t align_to(size_t number, size_t alignment) {
    if (alignment == 0) {
      throw std::invalid_argument("Alignment must be greater than 0");
    }
    return (number + alignment - 1) & ~(alignment - 1);
  }*/
  size_t align_to(size_t number, size_t alignment) {
    if (alignment == 0) {
      throw std::invalid_argument("Alignment must be greater than 0");
    }
    return ((number + alignment - 1) / alignment) * alignment;
  }
  size_t allocation_size(size_t key_length) {
    size_t data_size = sizeof(KeyValueBucket) + key_length;
    return align_to(data_size, alignof(KeyValueBucket));
  }
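  // Worked example (assuming the 24-byte KeyValueBucket above with alignof == 8):
  // a 5-byte key needs 24 + 5 = 29 bytes, which align_to rounds up to 32, so
  // consecutive buckets in storage_ always start at properly aligned offsets.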
  void ensure_storage_capacity(size_t ensured_capacity) {
    if (ensured_capacity <= storage_capacity_) {
      return;
    }
    // Next power of 2, rounded so aligned_alloc gets a size that's a multiple of the alignment
    size_t new_capacity = align_to(std::bit_ceil(ensured_capacity), alignof(KeyValueBucket));
    KeyValueBucket *new_storage = static_cast<KeyValueBucket *>(std::aligned_alloc(alignof(KeyValueBucket), new_capacity));
    KeyValueBucket *curr_ptr = new_storage;
    for (auto iter = begin(); iter != end(); iter++) {
      auto &group = metadata_groups_[iter.index_ / kGroupSize];
      size_t group_offset = iter.index_ % kGroupSize;
      auto pair = *iter;
      KeyValueBucket *old_bucket = group.bucket_ptrs[group_offset];
      group.bucket_ptrs[group_offset] = curr_ptr;
      new (curr_ptr) KeyValueBucket(old_bucket->hash_code, old_bucket->key_length, std::move(pair.second));
      KeyValueBucket &new_bucket = *curr_ptr;
      std::string_view view = old_bucket->key();
      memcpy(const_cast<char *>(new_bucket.key().data()), view.data(), view.size());
      // The old bucket's destructor runs here, after its value has been moved from
      old_bucket->~KeyValueBucket();
      curr_ptr = ptr_add(curr_ptr, allocation_size(pair.first.size()));
    }
    free(storage_);
    storage_ = new_storage;
    storage_capacity_ = new_capacity;
    // storage_size_ shouldn't change, even if things got reordered
  }
  template <class... Args>
  std::pair<Value &, bool> try_emplace(const std::string_view &key, Args&&... args) {
    if (size_ == 0) {
      // The capacity must be a multiple of kGroupSize, and it has to be decently bigger than 8
      // because of the way the wraparound logic works. TODO: support low capacities
      rehash(16);
      ensure_storage_capacity(1024);
    } else if (size_ >= (size_t)(capacity() * 0.75)) {
      rehash(capacity() * 2);
    }
    HashCode hash_code = std::hash<std::string_view>()(key);
    auto search_result = next_bucket_or_empty(hash_code, key);
    MetadataGroup &group = metadata_groups_[search_result.group_index];
    if (search_result.is_empty) {
      // Make sure not to take a pointer to the bucket before calling ensure_storage_capacity,
      // because growing the storage would invalidate it
      size_++;
      size_t bucket_size = allocation_size(key.size());
      ensure_storage_capacity(storage_size_ + bucket_size);
      KeyValueBucket *bucket_ptr = ptr_add(storage_, storage_size_);
      new (bucket_ptr) KeyValueBucket(hash_code, key.size(), std::forward<Args>(args)...);
      group.bucket_ptrs[search_result.group_offset] = bucket_ptr;
      group.summaries[search_result.group_offset] = summary_from_hash_code(hash_code);
      storage_size_ += bucket_size;
      // The key bytes are only copied in after the bucket is constructed, so key()
      // becomes valid from this point on
      memcpy(const_cast<char *>(bucket_ptr->key().data()), key.data(), key.length());
      return {bucket_ptr->value, true};
    } else {
      // Only dereference the bucket pointer when the slot is occupied; for an empty slot it's null
      auto &bucket = *group.bucket_ptrs[search_result.group_offset];
      return {bucket.value, false};
    }
  }
  /*inline size_t alignment_addition(const void *ptr, size_t size, size_t alignment) {
    uintptr_t pointer = reinterpret_cast<uintptr_t>(ptr_add(ptr, size));
    size_t remainder = pointer % alignment;
    if (remainder == 0) {
      return 0;
    } else {
      return alignment - remainder;
    }
  }*/
  inline Summary summary_from_hash_code(HashCode hash_code) const {
    Summary summary = (Summary)(hash_code << 1);
    if (summary <= kTombstoneSummary) {
      summary = kTombstoneSummary + 1;
    }
    return summary;
  }
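  // Example (illustrative): for a hash_code ending in byte 0x57, (Summary)(hash_code << 1)
  // keeps the low byte of the shifted value, giving 0xae; a hash ending in 0x00 or 0x80
  // shifts to 0x00 (== kEmptySummary), which the clamp above bumps to 2.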
  struct NextBucketOrEmptyResult {
    size_t group_index;
    size_t group_offset;
    bool is_empty;
    NextBucketOrEmptyResult(bool is_empty, size_t group_index, size_t group_offset)
        : group_index(group_index), group_offset(group_offset), is_empty(is_empty) {
    }
  };
  inline size_t capacity() const {
    return metadata_groups_.size() * kGroupSize;
  }
  NextBucketOrEmptyResult next_bucket_or_empty(HashCode hash_code, const std::string_view &key) const {
    size_t bucket = hash_code & (capacity() - 1); // capacity() is always a power of 2
    static_assert(kGroupSize == 8); // The group arithmetic below is only tuned for 8
    size_t group_index = bucket / kGroupSize;
    size_t group_offset = bucket % kGroupSize;
    size_t starting_group_index = group_index;
    Summary target_summary = summary_from_hash_code(hash_code);
    while (true) {
      const MetadataGroup &group = metadata_groups_[group_index];
      for (size_t i = group_offset; i < kGroupSize; i++) {
        if (group.summaries[i] == kEmptySummary) {
          return {true, group_index, i};
        } else if (group.summaries[i] == target_summary) {
          const KeyValueBucket &bucket = *group.bucket_ptrs[i];
          if (hash_code == bucket.hash_code && bucket.key() == key) {
            return {false, group_index, i};
          }
        }
      }
      group_offset = 0;
      group_index++;
      if (group_index == metadata_groups_.size()) {
        group_index = 0;
      }
      if (group_index == starting_group_index) {
        // Either the map is completely full, or every group besides the starting one is.
        // Undesirable either way
        throw std::runtime_error("hash map is full: no empty slot found");
      }
    }
  }
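  // Probe-order sketch (illustrative): with capacity 32 (4 groups) and a hash that lands
  // on slot 13 (group 1, offset 5), the scan visits offsets 5..7 of group 1, then groups
  // 2, 3, and 0 in full, and throws once group_index wraps back to 1. Offsets 0..4 of the
  // starting group are never revisited, which is why try_emplace() keeps the table well
  // below 100% load.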
  // We could use try_emplace here, but that's wasteful since we don't want to rehash.
  // We don't want to modify the storage at all; we just want to move pointers around
  void rehash(size_t new_capacity) {
    auto old_metadata_groups = metadata_groups_;
    metadata_groups_.clear();
    metadata_groups_.resize(new_capacity / kGroupSize);
    // std::cout << "Rehashing to " << capacity() << "\n";
    for (size_t group_index = 0; group_index < old_metadata_groups.size(); group_index++) {
      auto &old_group = old_metadata_groups[group_index];
      for (size_t group_offset = 0; group_offset < kGroupSize; group_offset++) {
        Summary summary = old_group.summaries[group_offset];
        if (summary != kEmptySummary && summary != kTombstoneSummary) {
          KeyValueBucket &bucket = *old_group.bucket_ptrs[group_offset];
          size_t new_index = bucket.hash_code % capacity();
          size_t new_group_index = new_index / kGroupSize;
          size_t new_group_offset = new_index % kGroupSize;
          while (true) {
            auto &new_group = metadata_groups_[new_group_index];
            for (; new_group_offset < kGroupSize; new_group_offset++) {
              if (new_group.summaries[new_group_offset] == kEmptySummary) {
                new_group.summaries[new_group_offset] = summary_from_hash_code(bucket.hash_code);
                new_group.bucket_ptrs[new_group_offset] = &bucket;
                goto did_insert;
              }
            }
            new_group_offset = 0;
            new_group_index++;
            if (new_group_index == capacity() / kGroupSize) {
              new_group_index = 0;
            }
          }
        did_insert:;
        }
      }
    }
    /*for (const auto &group : metadata_groups_) {
      for (size_t i = 0; i < kGroupSize; i++) {
        std::cout << "Summary: " << (size_t)group.summaries[i] << ", " << group.bucket_ptrs[i] << std::endl;
      }
    }*/
  }
  ~Hash() {
    for (auto pair : *this) {
      // We don't destruct the whole bucket, just the value; the key bytes need no destructor
      pair.second.~Value();
    }
    free(storage_);
  }
};
// static_assert(std::is_trivially_destructible<Hash::KeyValueBucket>::value, "");
#endif // HASH_H
std::string random_string(int length) {
  std::string string;
  for (int i = 0; i < length; i++) {
    string.push_back('a' + (rand() % 26));
  }
  return string;
}
void hard_assert(bool condition) {
  if (!condition) {
    abort();
  }
}
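// Minimal usage sketch (added; not part of the original benchmark): exercises the
// try_emplace contract, where the bool is true only on first insertion.
[[maybe_unused]] static void usage_example() {
  Hash map;
  auto [first, inserted] = map.try_emplace("answer", 42);
  hard_assert(inserted && first == 42);
  // A second emplace with the same key leaves the stored value untouched
  auto [second, inserted_again] = map.try_emplace("answer", 7);
  hard_assert(!inserted_again && second == 42);
}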
double current_time() {
  struct timespec ts = {};
  clock_gettime(CLOCK_MONOTONIC, &ts);
  return ts.tv_sec + ((double)ts.tv_nsec) / 1e9;
}
// Runs f repeatedly for about a second and reports throughput in runs per second
template <typename F>
void time_code(std::string name, F f) {
  f(); // Warmup
  double start_time = current_time();
  double end_time = 0;
  int runs_done = 0;
  while (true) {
    f();
    runs_done++;
    end_time = current_time();
    if (end_time - start_time > 1.0) {
      break;
    }
  }
  double rate = runs_done / (end_time - start_time);
  std::cout << name << ": " << rate << " runs/sec" << std::endl;
}
int main() {
  std::vector<std::pair<std::string, int>> possibilities;
  srand(0);
  for (int i = 0; i < 1000000; i++) {
    possibilities.emplace_back(random_string(25), rand());
  }
  std::vector<std::pair<std::string, int>> insertions;
  for (size_t i = 0; i < possibilities.size() * 10; i++) {
    insertions.emplace_back(possibilities[rand() % possibilities.size()]);
  }
  // Note that h is reused across timed runs, while the absl benchmark below constructs a
  // fresh map each run, so runs after the first only exercise the found/lookup path here
  Hash h;
  time_code("mine", [&]() {
    for (const auto &[key, value] : insertions) {
      h.try_emplace(std::string_view(key), value);
    }
  });
  time_code("absl", [&]() {
    absl::flat_hash_map<std::string, int> absl_map;
    for (const auto &[key, value] : insertions) {
      absl_map.try_emplace(key, value);
    }
    hard_assert(absl_map.size() != 23943242); // Prevent optimizing stuff out
  });
  std::vector<std::pair<std::string, int>> results;
  for (auto iter = h.begin(); iter != h.end(); iter++) {
    const std::pair<const std::string_view, size_t &> result = *iter;
    results.emplace_back(std::string(result.first), result.second);
  }
  std::sort(results.begin(), results.end(), [](auto &a, auto &b) {
    return a.second < b.second;
  });
  std::unordered_map<std::string, int> expected_map;
  for (const auto &[key, value] : insertions) {
    expected_map.try_emplace(key, value);
  }
  std::vector<std::pair<std::string, int>> expected_vector;
  for (const auto &[key, value] : expected_map) {
    expected_vector.emplace_back(key, value);
  }
  std::sort(expected_vector.begin(), expected_vector.end(), [](auto &a, auto &b) {
    return a.second < b.second;
  });
  hard_assert(results.size() == expected_vector.size());
  for (size_t i = 0; i < results.size(); i++) {
    hard_assert(results[i] == expected_vector[i]);
  }
  return 0;
}