sorted trie for fast integer lookups: a fixed-depth trie over 6-bit key chunks, with a 64-bit occupancy bitmap per node for ordered traversal and compare-and-swap child insertion for concurrent use.
32-bit keys, with a multithreaded insert/lookup benchmark against a mutex-guarded std::set:
#include <iostream>
#include <set>
#include <vector>
#include <thread>
#include <mutex>
#include <random>
#include <chrono>
#include <atomic>
#include <cstdint>

constexpr int BITS_PER_LEVEL = 6;
constexpr int BRANCHES = 1 << BITS_PER_LEVEL;                      // 64-way fan-out
constexpr int LEVELS = (32 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL; // 6 levels of 6 bits cover 32-bit keys

struct SortNode {
    std::atomic<uint64_t> sortFinder{0};      // occupancy bitmap: one bit per populated child
    void* data = nullptr;
    SortNode* children[BRANCHES] = {nullptr};
};

class SortTrie {
public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { clear(root); }

    void insert(uint32_t key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode*& child = node->children[index];
            if (!child) {
                SortNode* new_node = new SortNode();
                SortNode* expected = nullptr;
                // CAS on the child slot (treated as an atomic pointer); if another
                // thread already installed a child, discard ours and use theirs.
                if (!std::atomic_compare_exchange_strong(
                        reinterpret_cast<std::atomic<SortNode*>*>(&child),
                        &expected, new_node)) {
                    delete new_node;
                }
            }
            node = node->children[index];
        }
        node->data = reinterpret_cast<void*>(static_cast<uintptr_t>(key)); // leaf stores the key itself
    }

    bool find(uint32_t key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index)))
                return false;
            node = node->children[index];
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(node->data) == key;
    }

    void collect_sorted(std::vector<uint32_t>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    void in_order(SortNode* node, uint32_t prefix, int level, std::vector<uint32_t>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(node->data)));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit first => ascending key order
            mask &= ~(1ULL << i);
            uint32_t next_prefix = (prefix << BITS_PER_LEVEL) | i;
            in_order(node->children[i], next_prefix, level + 1, out);
        }
    }

    void clear(SortNode* node) {
        if (!node) return;
        for (int i = 0; i < BRANCHES; ++i)
            clear(node->children[i]);
        delete node;
    }
};

void bench_multithreaded(size_t N, int num_threads) {
    std::vector<uint32_t> data(N);
    std::mt19937 rng(12345);
    for (auto& x : data) x = rng();

    std::vector<std::vector<uint32_t>> chunks(num_threads);
    for (size_t i = 0; i < N; ++i)
        chunks[i % num_threads].push_back(data[i]);

    // std::set with mutex
    std::set<uint32_t> s;
    std::mutex m;
    auto set_start = std::chrono::high_resolution_clock::now();
    std::vector<std::thread> threads_set;
    for (int t = 0; t < num_threads; ++t) {
        threads_set.emplace_back([&s, &m, &chunks, t]() {
            for (uint32_t x : chunks[t]) {
                std::lock_guard<std::mutex> lock(m);
                s.insert(x);
            }
        });
    }
    for (auto& th : threads_set) th.join();
    auto set_end = std::chrono::high_resolution_clock::now();
    double set_insert_time = std::chrono::duration<double, std::milli>(set_end - set_start).count();

    // SortTrie insert
    SortTrie trie;
    auto trie_start = std::chrono::high_resolution_clock::now();
    std::vector<std::thread> threads_trie;
    for (int t = 0; t < num_threads; ++t) {
        threads_trie.emplace_back([&trie, &chunks, t]() {
            for (uint32_t x : chunks[t]) {
                trie.insert(x);
            }
        });
    }
    for (auto& th : threads_trie) th.join();
    auto trie_end = std::chrono::high_resolution_clock::now();
    double trie_insert_time = std::chrono::duration<double, std::milli>(trie_end - trie_start).count();

    std::cout << "=== Multithreaded Benchmark with N = " << N << ", T = " << num_threads << " ===\n";
    std::cout << "std::set insert time: " << set_insert_time << " ms\n";
    std::cout << "SortTrie insert time: " << trie_insert_time << " ms\n";

    // --- Multithreaded lookup ---
    // std::set lookup
    auto set_lookup_start = std::chrono::high_resolution_clock::now();
    std::atomic<size_t> set_found{0};
    std::vector<std::thread> threads_set_lookup;
    for (int t = 0; t < num_threads; ++t) {
        threads_set_lookup.emplace_back([&s, &m, &chunks, &set_found, t]() {
            for (uint32_t x : chunks[t]) {
                std::lock_guard<std::mutex> lock(m);
                if (s.find(x) != s.end()) {
                    set_found.fetch_add(1, std::memory_order_relaxed);
                }
            }
        });
    }
    for (auto& th : threads_set_lookup) th.join();
    auto set_lookup_end = std::chrono::high_resolution_clock::now();
    double set_lookup_time = std::chrono::duration<double, std::milli>(set_lookup_end - set_lookup_start).count();

    // SortTrie lookup
    auto trie_lookup_start = std::chrono::high_resolution_clock::now();
    std::atomic<size_t> trie_found{0};
    std::vector<std::thread> threads_trie_lookup;
    for (int t = 0; t < num_threads; ++t) {
        threads_trie_lookup.emplace_back([&trie, &chunks, &trie_found, t]() {
            for (uint32_t x : chunks[t]) {
                if (trie.find(x)) {
                    trie_found.fetch_add(1, std::memory_order_relaxed);
                }
            }
        });
    }
    for (auto& th : threads_trie_lookup) th.join();
    auto trie_lookup_end = std::chrono::high_resolution_clock::now();
    double trie_lookup_time = std::chrono::duration<double, std::milli>(trie_lookup_end - trie_lookup_start).count();

    std::cout << "std::set lookup time: " << set_lookup_time << " ms (found " << set_found.load() << ")\n";
    std::cout << "SortTrie lookup time: " << trie_lookup_time << " ms (found " << trie_found.load() << ")\n";
}

int main() {
    bench_multithreaded(1'000'000, 8);
    return 0;
}
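The benchmark above only exercises insert and find. As a quick illustration (my own sketch, not part of the gist), the same class also returns keys in ascending order via collect_sorted; demo_collect_sorted is a hypothetical helper that assumes the SortTrie definition above is in scope:

#include <cstdint>
#include <iostream>
#include <vector>

int demo_collect_sorted() {
    SortTrie trie;
    for (uint32_t k : {42u, 7u, 1000000u, 7u, 0u})   // duplicate 7 collapses to one entry
        trie.insert(k);

    std::vector<uint32_t> sorted;
    trie.collect_sorted(sorted);                     // expected output: 0 7 42 1000000
    for (uint32_t k : sorted) std::cout << k << ' ';
    std::cout << '\n';
    return 0;
}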
The same trie and benchmark with 64-bit keys (11 levels of 6 bits):
#include <iostream>
#include <set>
#include <vector>
#include <thread>
#include <mutex>
#include <random>
#include <chrono>
#include <atomic>
#include <cstdint>

constexpr int BITS_PER_LEVEL = 6;
constexpr int BRANCHES = 1 << BITS_PER_LEVEL;                      // 64-way fan-out
constexpr int LEVELS = (64 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL; // 11 levels of 6 bits cover 64-bit keys

struct SortNode {
    std::atomic<uint64_t> sortFinder{0};      // occupancy bitmap: one bit per populated child
    void* data = nullptr;
    SortNode* children[BRANCHES] = {nullptr};
};

class SortTrie {
public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { clear(root); }

    void insert(uint64_t key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode*& child = node->children[index];
            if (!child) {
                SortNode* new_node = new SortNode();
                SortNode* expected = nullptr;
                // CAS on the child slot (treated as an atomic pointer); if another
                // thread already installed a child, discard ours and use theirs.
                if (!std::atomic_compare_exchange_strong(
                        reinterpret_cast<std::atomic<SortNode*>*>(&child),
                        &expected, new_node)) {
                    delete new_node;
                }
            }
            node = node->children[index];
        }
        node->data = reinterpret_cast<void*>(static_cast<uintptr_t>(key)); // leaf stores the key itself
    }

    bool find(uint64_t key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index)))
                return false;
            node = node->children[index];
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(node->data) == key;
    }

    void collect_sorted(std::vector<uint64_t>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    void in_order(SortNode* node, uint64_t prefix, int level, std::vector<uint64_t>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(node->data)));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit first => ascending key order
            mask &= ~(1ULL << i);
            uint64_t next_prefix = (prefix << BITS_PER_LEVEL) | i;
            in_order(node->children[i], next_prefix, level + 1, out);
        }
    }

    void clear(SortNode* node) {
        if (!node) return;
        for (int i = 0; i < BRANCHES; ++i)
            clear(node->children[i]);
        delete node;
    }
};

void bench_multithreaded(size_t N, int num_threads) {
    std::vector<uint64_t> data(N);
    std::mt19937_64 rng(12345);
    for (auto& x : data) x = rng();

    std::vector<std::vector<uint64_t>> chunks(num_threads);
    for (size_t i = 0; i < N; ++i)
        chunks[i % num_threads].push_back(data[i]);

    // std::set with mutex
    std::set<uint64_t> s;
    std::mutex m;
    auto set_start = std::chrono::high_resolution_clock::now();
    std::vector<std::thread> threads_set;
    for (int t = 0; t < num_threads; ++t) {
        threads_set.emplace_back([&s, &m, &chunks, t]() {
            for (uint64_t x : chunks[t]) {
                std::lock_guard<std::mutex> lock(m);
                s.insert(x);
            }
        });
    }
    for (auto& th : threads_set) th.join();
    auto set_end = std::chrono::high_resolution_clock::now();
    double set_insert_time = std::chrono::duration<double, std::milli>(set_end - set_start).count();

    // SortTrie insert
    SortTrie trie;
    auto trie_start = std::chrono::high_resolution_clock::now();
    std::vector<std::thread> threads_trie;
    for (int t = 0; t < num_threads; ++t) {
        threads_trie.emplace_back([&trie, &chunks, t]() {
            for (uint64_t x : chunks[t]) {
                trie.insert(x);
            }
        });
    }
    for (auto& th : threads_trie) th.join();
    auto trie_end = std::chrono::high_resolution_clock::now();
    double trie_insert_time = std::chrono::duration<double, std::milli>(trie_end - trie_start).count();

    std::cout << "=== Multithreaded Benchmark with N = " << N << ", T = " << num_threads << " (64-bit keys) ===\n";
    std::cout << "std::set insert time: " << set_insert_time << " ms\n";
    std::cout << "SortTrie insert time: " << trie_insert_time << " ms\n";

    // --- Multithreaded lookup ---
    // std::set lookup
    auto set_lookup_start = std::chrono::high_resolution_clock::now();
    std::atomic<size_t> set_found{0};
    std::vector<std::thread> threads_set_lookup;
    for (int t = 0; t < num_threads; ++t) {
        threads_set_lookup.emplace_back([&s, &m, &chunks, &set_found, t]() {
            for (uint64_t x : chunks[t]) {
                std::lock_guard<std::mutex> lock(m);
                if (s.find(x) != s.end()) {
                    set_found.fetch_add(1, std::memory_order_relaxed);
                }
            }
        });
    }
    for (auto& th : threads_set_lookup) th.join();
    auto set_lookup_end = std::chrono::high_resolution_clock::now();
    double set_lookup_time = std::chrono::duration<double, std::milli>(set_lookup_end - set_lookup_start).count();

    // SortTrie lookup
    auto trie_lookup_start = std::chrono::high_resolution_clock::now();
    std::atomic<size_t> trie_found{0};
    std::vector<std::thread> threads_trie_lookup;
    for (int t = 0; t < num_threads; ++t) {
        threads_trie_lookup.emplace_back([&trie, &chunks, &trie_found, t]() {
            for (uint64_t x : chunks[t]) {
                if (trie.find(x)) {
                    trie_found.fetch_add(1, std::memory_order_relaxed);
                }
            }
        });
    }
    for (auto& th : threads_trie_lookup) th.join();
    auto trie_lookup_end = std::chrono::high_resolution_clock::now();
    double trie_lookup_time = std::chrono::duration<double, std::milli>(trie_lookup_end - trie_lookup_start).count();

    std::cout << "std::set lookup time: " << set_lookup_time << " ms (found " << set_found.load() << ")\n";
    std::cout << "SortTrie lookup time: " << trie_lookup_time << " ms (found " << trie_found.load() << ")\n";
}

int main() {
    bench_multithreaded(1'000'000, 8);
    return 0;
}
A templated variant that stores arbitrary (non-owned) data pointers at the leaves and keeps every child slot in a std::atomic<void*>:
#include <atomic>
#include <cstdint>
#include <cstring>
#include <vector>
#include <mutex>
#include <shared_mutex>
#include <type_traits>

template <typename KeyT, int BitsPerLevel = 6>
class SortTrie {
    static_assert(std::is_unsigned<KeyT>::value, "KeyT must be unsigned");
    static constexpr int BRANCHES = 1 << BitsPerLevel;
    static constexpr int LEVELS = (sizeof(KeyT) * 8 + BitsPerLevel - 1) / BitsPerLevel;

    struct SortNode {
        std::atomic<uint64_t> sortFinder{0};   // occupancy bitmap: one bit per populated child
        std::atomic<void*> children[BRANCHES]; // inner levels hold SortNode*, leaf level holds user data
        SortNode() {
            for (int i = 0; i < BRANCHES; ++i) children[i] = nullptr;
        }
    };

    SortNode* root;

public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { clear(root, 0); }

    void insert(KeyT key, void* data) {
        SortNode* node = root;
        for (int level = 0; level < LEVELS - 1; ++level) {
            int shift = (LEVELS - 1 - level) * BitsPerLevel;
            int index = (key >> shift) & (BRANCHES - 1);
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            void* child = node->children[index].load(std::memory_order_acquire);
            if (!child) {
                auto* new_node = new SortNode();
                if (!node->children[index].compare_exchange_strong(child, new_node))
                    delete new_node; // another thread inserted; child now holds its node
                else
                    child = new_node;
            }
            node = reinterpret_cast<SortNode*>(child);
        }
        // Leaf level
        int final_index = key & (BRANCHES - 1);
        node->sortFinder.fetch_or(1ULL << final_index, std::memory_order_relaxed);
        node->children[final_index].store(data, std::memory_order_release);
    }

    void* find(KeyT key) const {
        SortNode* node = root;
        for (int level = 0; level < LEVELS - 1; ++level) {
            int shift = (LEVELS - 1 - level) * BitsPerLevel;
            int index = (key >> shift) & (BRANCHES - 1);
            if (!(node->sortFinder.load(std::memory_order_relaxed) & (1ULL << index)))
                return nullptr;
            void* child = node->children[index].load(std::memory_order_acquire);
            if (!child) return nullptr;
            node = reinterpret_cast<SortNode*>(child);
        }
        int final_index = key & (BRANCHES - 1);
        if (!(node->sortFinder.load(std::memory_order_relaxed) & (1ULL << final_index)))
            return nullptr;
        return node->children[final_index].load(std::memory_order_acquire);
    }

    void collect_sorted(std::vector<KeyT>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    void in_order(SortNode* node, KeyT prefix, int level, std::vector<KeyT>& out) const {
        if (!node) return;
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit first => ascending key order
            mask &= ~(1ULL << i);
            KeyT next_prefix = (prefix << BitsPerLevel) | i;
            void* child = node->children[i].load(std::memory_order_acquire);
            if (level == LEVELS - 1) {
                out.push_back(next_prefix); // reached leaf, where child is actual data
            } else {
                in_order(reinterpret_cast<SortNode*>(child), next_prefix, level + 1, out);
            }
        }
    }

    void clear(SortNode* node, int level) {
        if (!node) return;
        for (int i = 0; i < BRANCHES; ++i) {
            void* ptr = node->children[i].load();
            if (ptr) {
                if (level < LEVELS - 1)
                    clear(reinterpret_cast<SortNode*>(ptr), level + 1);
                // else it's a data pointer and we don't own/delete it
            }
        }
        delete node;
    }
};
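A minimal usage sketch for the key-to-pointer variant above (my addition, not part of the gist); Record and demo_pointer_trie are hypothetical names, and the trie stores the pointers without owning them:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Record { std::string name; };   // hypothetical payload type

int demo_pointer_trie() {
    SortTrie<uint32_t> trie;
    Record a{"alpha"}, b{"beta"};

    trie.insert(10, &a);        // non-owning: clear() never deletes leaf data
    trie.insert(200, &b);

    if (auto* r = static_cast<Record*>(trie.find(200)))
        std::cout << "found: " << r->name << '\n';   // prints "found: beta"

    std::vector<uint32_t> keys;
    trie.collect_sorted(keys);  // keys come back in ascending order: 10 200
    for (uint32_t k : keys) std::cout << k << ' ';
    std::cout << '\n';
    return 0;
}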
A templated, key-only variant where each node's destructor frees its children:
#include <iostream>
#include <vector>
#include <atomic>
#include <cstdint>
#include <cassert>
#include <memory>
#include <type_traits>

template <typename KeyType = uint32_t, int BITS_PER_LEVEL = 6>
class SortTrie {
    static_assert(std::is_unsigned<KeyType>::value, "KeyType must be unsigned integer");
    static constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
    static constexpr int LEVELS = (sizeof(KeyType) * 8 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL;

    struct SortNode {
        std::atomic<uint64_t> sortFinder{0};   // occupancy bitmap: one bit per populated child
        void* data = nullptr;
        SortNode* children[BRANCHES] = {nullptr};
        ~SortNode() {
            for (int i = 0; i < BRANCHES; ++i) {
                delete children[i];
            }
        }
    };

public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { delete root; }

    void insert(KeyType key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode*& child = node->children[index];
            if (!child) {
                SortNode* new_node = new SortNode();
                SortNode* expected = nullptr;
                // CAS on the child slot (treated as an atomic pointer); if another
                // thread already installed a child, discard ours and use theirs.
                if (!std::atomic_compare_exchange_strong(
                        reinterpret_cast<std::atomic<SortNode*>*>(&child),
                        &expected, new_node)) {
                    delete new_node;
                }
            }
            node = node->children[index];
        }
        node->data = reinterpret_cast<void*>(static_cast<uintptr_t>(key)); // leaf stores the key itself
    }

    bool find(KeyType key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index))) return false;
            node = node->children[index];
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(node->data) == key;
    }

    void collect_sorted(std::vector<KeyType>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    void in_order(SortNode* node, KeyType prefix, int level, std::vector<KeyType>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<KeyType>(reinterpret_cast<uintptr_t>(node->data)));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit first => ascending key order
            mask &= ~(1ULL << i);
            KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
            in_order(node->children[i], next_prefix, level + 1, out);
        }
    }
};
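Since the key width is a template parameter here, the number of levels adapts automatically. A brief instantiation sketch (my addition, not part of the gist) with 16-bit keys; demo_small_keys is a hypothetical helper assuming the class above is in scope:

#include <cstdint>
#include <iostream>
#include <vector>

int demo_small_keys() {
    // For uint16_t, LEVELS = (16 + 6 - 1) / 6 = 3 levels of 6 bits.
    SortTrie<uint16_t> trie;
    for (uint16_t k : {uint16_t{500}, uint16_t{3}, uint16_t{65535}})
        trie.insert(k);

    std::vector<uint16_t> out;
    trie.collect_sorted(out);   // expected: 3 500 65535
    for (uint16_t k : out) std::cout << k << ' ';
    std::cout << '\n';
    return 0;
}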
The same templated variant extended with an inclusive range query (query_range):
#include <iostream>
#include <vector>
#include <atomic>
#include <cstdint>
#include <cassert>
#include <memory>
#include <type_traits>

template <typename KeyType = uint32_t, int BITS_PER_LEVEL = 6>
class SortTrie {
    static_assert(std::is_unsigned<KeyType>::value, "KeyType must be unsigned integer");
    static constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
    static constexpr int LEVELS = (sizeof(KeyType) * 8 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL;

    struct SortNode {
        std::atomic<uint64_t> sortFinder{0};   // occupancy bitmap: one bit per populated child
        void* data = nullptr;
        SortNode* children[BRANCHES] = {nullptr};
        ~SortNode() {
            for (int i = 0; i < BRANCHES; ++i) {
                delete children[i];
            }
        }
    };

public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { delete root; }

    void insert(KeyType key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode*& child = node->children[index];
            if (!child) {
                SortNode* new_node = new SortNode();
                SortNode* expected = nullptr;
                // CAS on the child slot (treated as an atomic pointer); if another
                // thread already installed a child, discard ours and use theirs.
                if (!std::atomic_compare_exchange_strong(
                        reinterpret_cast<std::atomic<SortNode*>*>(&child),
                        &expected, new_node)) {
                    delete new_node;
                }
            }
            node = node->children[index];
        }
        node->data = reinterpret_cast<void*>(static_cast<uintptr_t>(key)); // leaf stores the key itself
    }

    bool find(KeyType key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index))) return false;
            node = node->children[index];
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(node->data) == key;
    }

    void collect_sorted(std::vector<KeyType>& out) const {
        in_order(root, 0, 0, out);
    }

    // New query method: collects keys between low and high, inclusive
    void query_range(KeyType low, KeyType high, std::vector<KeyType>& out) const {
        query_range_helper(root, 0, 0, low, high, out);
    }

private:
    SortNode* root;

    void in_order(SortNode* node, KeyType prefix, int level, std::vector<KeyType>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<KeyType>(reinterpret_cast<uintptr_t>(node->data)));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit first => ascending key order
            mask &= ~(1ULL << i);
            KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
            in_order(node->children[i], next_prefix, level + 1, out);
        }
    }

    void query_range_helper(SortNode* node, KeyType prefix, int level, KeyType low, KeyType high, std::vector<KeyType>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            KeyType val = static_cast<KeyType>(reinterpret_cast<uintptr_t>(node->data));
            if (val >= low && val <= high) {
                out.push_back(val);
            }
            return;
        }
        // Calculate the key range covered by each child of this node
        int shift = (LEVELS - level - 1) * BITS_PER_LEVEL;
        KeyType node_range_size = KeyType(1) << shift;
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);
            mask &= ~(1ULL << i);
            KeyType child_min = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
            child_min <<= shift;
            KeyType child_max = child_min + node_range_size - 1;
            // Prune subtree if out of range
            if (child_max < low || child_min > high) {
                continue;
            }
            // Otherwise recurse
            KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
            query_range_helper(node->children[i], next_prefix, level + 1, low, high, out);
        }
    }
};
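A minimal range-query sketch (my addition, not part of the gist) showing query_range pruning subtrees whose covered key range falls outside [low, high]; demo_query_range is a hypothetical helper assuming the class above is in scope:

#include <cstdint>
#include <iostream>
#include <vector>

int demo_query_range() {
    SortTrie<uint32_t> trie;
    for (uint32_t k : {5u, 64u, 70u, 4096u, 90000u})
        trie.insert(k);

    std::vector<uint32_t> hits;
    trie.query_range(60, 5000, hits);   // inclusive bounds; expected: 64 70 4096
    for (uint32_t k : hits) std::cout << k << ' ';
    std::cout << '\n';
    return 0;
}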