Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Last active June 12, 2025 21:12
Show Gist options
  • Save jweinst1/3641ff3317dd622237c21b401e595c69 to your computer and use it in GitHub Desktop.
sorted trie (bitmap-indexed radix trie) for fast concurrent integer insertion, lookup, and in-order traversal
#include <iostream>
#include <set>
#include <vector>
#include <thread>
#include <mutex>
#include <random>
#include <chrono>
#include <atomic>
// Radix-trie layout: 6 key bits per level, 64-way fanout, 6 levels to
// cover all 32 bits of a key (the top level uses only 2 of its 6 bits).
constexpr int BITS_PER_LEVEL = 6;
constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
constexpr int LEVELS = (32 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL;

// One trie node: a 64-bit occupancy bitmap plus 64 child slots.
// The bitmap doubles as the sorted-iteration index (scan set bits low→high).
struct SortNode {
    std::atomic<uint64_t> sortFinder{0};          // bit i set => slot i populated (or being populated)
    std::atomic<void*> data{nullptr};             // leaf payload: the key, published with release
    std::atomic<SortNode*> children[BRANCHES];    // atomic so concurrent insert/find are race-free
    SortNode() {
        for (int i = 0; i < BRANCHES; ++i)
            children[i].store(nullptr, std::memory_order_relaxed);
    }
};

/// Concurrent set of uint32_t keys stored in a fixed-depth radix trie.
/// insert() and find() may run from multiple threads; collect_sorted()
/// and the destructor require quiescence (no concurrent writers).
///
/// Fix over the original: `children` and `data` were plain pointers that
/// were CAS'd through a reinterpret_cast to std::atomic and otherwise read
/// and written non-atomically — undefined-behavior data races. They are now
/// genuine atomics with acquire/release publication.
class SortTrie {
public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { clear(root); }

    /// Add `key` to the set (idempotent). Child nodes are published with a
    /// CAS so two racing inserters neither leak nor lose a node.
    void insert(uint32_t key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode* child = node->children[index].load(std::memory_order_acquire);
            if (!child) {
                SortNode* fresh = new SortNode();
                SortNode* expected = nullptr;
                if (node->children[index].compare_exchange_strong(
                        expected, fresh, std::memory_order_acq_rel)) {
                    child = fresh;
                } else {
                    delete fresh;        // lost the publish race; `expected` holds the winner
                    child = expected;
                }
            }
            node = child;
        }
        // The leaf remembers the key itself, smuggled through a void*
        // (release pairs with find's acquire load).
        node->data.store(reinterpret_cast<void*>(static_cast<uintptr_t>(key)),
                         std::memory_order_release);
    }

    /// True iff `key` was inserted. Lock-free for readers.
    bool find(uint32_t key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = (key >> shift) & (BRANCHES - 1);
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index)))
                return false;
            // A bit may become visible before its child is published; tolerate null.
            node = node->children[index].load(std::memory_order_acquire);
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(
                   node->data.load(std::memory_order_acquire)) == key;
    }

    /// Append every stored key to `out` in ascending order.
    void collect_sorted(std::vector<uint32_t>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    // DFS that consumes set bits from least to most significant, visiting
    // children — and therefore keys — in ascending order.
    void in_order(SortNode* node, uint32_t prefix, int level, std::vector<uint32_t>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<uint32_t>(reinterpret_cast<uintptr_t>(
                node->data.load(std::memory_order_acquire))));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit = smallest child index
            mask &= ~(1ULL << i);
            uint32_t next_prefix = (prefix << BITS_PER_LEVEL) | i;
            in_order(node->children[i].load(std::memory_order_acquire),
                     next_prefix, level + 1, out);
        }
    }

    // Recursive teardown; only safe once all writer threads have joined.
    void clear(SortNode* node) {
        if (!node) return;
        for (int i = 0; i < BRANCHES; ++i)
            clear(node->children[i].load(std::memory_order_relaxed));
        delete node;
    }
};
// Benchmark harness: inserts N pseudo-random 32-bit keys, then looks every
// key up again, comparing a mutex-guarded std::set against the SortTrie,
// each driven by `num_threads` worker threads. Prints wall-clock timings
// (ms) and lookup hit counts to stdout.
// NOTE(review): assumes num_threads >= 1 — num_threads == 0 would divide
// by zero in the round-robin chunking below.
void bench_multithreaded(size_t N, int num_threads) {
// Fixed seed so every run (and both data structures) sees the same keys.
std::vector<uint32_t> data(N);
std::mt19937 rng(12345);
for (auto& x : data) x = rng();
// Pre-split the work round-robin: chunk t is thread t's private input,
// so reading it needs no synchronization.
std::vector<std::vector<uint32_t>> chunks(num_threads);
for (size_t i = 0; i < N; ++i)
chunks[i % num_threads].push_back(data[i]);
// std::set with mutex: a single global lock serializes every insert.
std::set<uint32_t> s;
std::mutex m;
auto set_start = std::chrono::high_resolution_clock::now();
std::vector<std::thread> threads_set;
for (int t = 0; t < num_threads; ++t) {
threads_set.emplace_back([&s, &m, &chunks, t]() {
for (uint32_t x : chunks[t]) {
std::lock_guard<std::mutex> lock(m);
s.insert(x);
}
});
}
for (auto& th : threads_set) th.join();
auto set_end = std::chrono::high_resolution_clock::now();
double set_insert_time = std::chrono::duration<double, std::milli>(set_end - set_start).count();
// SortTrie insert: threads insert concurrently with no external lock.
SortTrie trie;
auto trie_start = std::chrono::high_resolution_clock::now();
std::vector<std::thread> threads_trie;
for (int t = 0; t < num_threads; ++t) {
threads_trie.emplace_back([&trie, &chunks, t]() {
for (uint32_t x : chunks[t]) {
trie.insert(x);
}
});
}
for (auto& th : threads_trie) th.join();
auto trie_end = std::chrono::high_resolution_clock::now();
double trie_insert_time = std::chrono::duration<double, std::milli>(trie_end - trie_start).count();
std::cout << "=== Multithreaded Benchmark with N = " << N << ", T = " << num_threads << " ===\n";
std::cout << "std::set insert time: " << set_insert_time << " ms\n";
std::cout << "SortTrie insert time: " << trie_insert_time << " ms\n";
// --- Multithreaded lookup ---
// std::set lookup (again serialized by the mutex). The hit counters are
// atomic so threads can tally without extra locking.
auto set_lookup_start = std::chrono::high_resolution_clock::now();
std::atomic<size_t> set_found{0};
std::vector<std::thread> threads_set_lookup;
for (int t = 0; t < num_threads; ++t) {
threads_set_lookup.emplace_back([&s, &m, &chunks, &set_found, t]() {
for (uint32_t x : chunks[t]) {
std::lock_guard<std::mutex> lock(m);
if (s.find(x) != s.end()) {
set_found.fetch_add(1, std::memory_order_relaxed);
}
}
});
}
for (auto& th : threads_set_lookup) th.join();
auto set_lookup_end = std::chrono::high_resolution_clock::now();
double set_lookup_time = std::chrono::duration<double, std::milli>(set_lookup_end - set_lookup_start).count();
// SortTrie lookup: readers take no lock.
auto trie_lookup_start = std::chrono::high_resolution_clock::now();
std::atomic<size_t> trie_found{0};
std::vector<std::thread> threads_trie_lookup;
for (int t = 0; t < num_threads; ++t) {
threads_trie_lookup.emplace_back([&trie, &chunks, &trie_found, t]() {
for (uint32_t x : chunks[t]) {
if (trie.find(x)) {
trie_found.fetch_add(1, std::memory_order_relaxed);
}
}
});
}
for (auto& th : threads_trie_lookup) th.join();
auto trie_lookup_end = std::chrono::high_resolution_clock::now();
double trie_lookup_time = std::chrono::duration<double, std::milli>(trie_lookup_end - trie_lookup_start).count();
// Every key looked up was previously inserted, so both found-counts
// should equal N (duplicate keys are looked up once per occurrence).
std::cout << "std::set lookup time: " << set_lookup_time << " ms (found " << set_found.load() << ")\n";
std::cout << "SortTrie lookup time: " << trie_lookup_time << " ms (found " << trie_found.load() << ")\n";
}
// Entry point: run the set-vs-trie comparison on one million random
// 32-bit keys spread across eight worker threads.
int main() {
    bench_multithreaded(1'000'000, 8);
}
#include <iostream>
#include <set>
#include <vector>
#include <thread>
#include <mutex>
#include <random>
#include <chrono>
#include <atomic>
// Radix-trie layout: 6 key bits per level, 64-way fanout, 11 levels to
// cover all 64 bits of a key (the top level uses only 4 of its 6 bits).
constexpr int BITS_PER_LEVEL = 6;
constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
constexpr int LEVELS = (64 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL; // 11 levels

// One trie node: a 64-bit occupancy bitmap plus 64 child slots.
// The bitmap doubles as the sorted-iteration index (scan set bits low→high).
struct SortNode {
    std::atomic<uint64_t> sortFinder{0};          // bit i set => slot i populated (or being populated)
    std::atomic<void*> data{nullptr};             // leaf payload: the key, published with release
    std::atomic<SortNode*> children[BRANCHES];    // atomic so concurrent insert/find are race-free
    SortNode() {
        for (int i = 0; i < BRANCHES; ++i)
            children[i].store(nullptr, std::memory_order_relaxed);
    }
};

/// Concurrent set of uint64_t keys stored in a fixed-depth radix trie.
/// insert() and find() may run from multiple threads; collect_sorted()
/// and the destructor require quiescence (no concurrent writers).
///
/// Fix over the original: `children` and `data` were plain pointers that
/// were CAS'd through a reinterpret_cast to std::atomic and otherwise read
/// and written non-atomically — undefined-behavior data races. They are now
/// genuine atomics with acquire/release publication.
class SortTrie {
public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { clear(root); }

    /// Add `key` to the set (idempotent). Child nodes are published with a
    /// CAS so two racing inserters neither leak nor lose a node.
    void insert(uint64_t key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode* child = node->children[index].load(std::memory_order_acquire);
            if (!child) {
                SortNode* fresh = new SortNode();
                SortNode* expected = nullptr;
                if (node->children[index].compare_exchange_strong(
                        expected, fresh, std::memory_order_acq_rel)) {
                    child = fresh;
                } else {
                    delete fresh;        // lost the publish race; `expected` holds the winner
                    child = expected;
                }
            }
            node = child;
        }
        // The leaf remembers the key itself, smuggled through a void*
        // (release pairs with find's acquire load).
        node->data.store(reinterpret_cast<void*>(static_cast<uintptr_t>(key)),
                         std::memory_order_release);
    }

    /// True iff `key` was inserted. Lock-free for readers.
    bool find(uint64_t key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index)))
                return false;
            // A bit may become visible before its child is published; tolerate null.
            node = node->children[index].load(std::memory_order_acquire);
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(
                   node->data.load(std::memory_order_acquire)) == key;
    }

    /// Append every stored key to `out` in ascending order.
    void collect_sorted(std::vector<uint64_t>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    // DFS that consumes set bits from least to most significant, visiting
    // children — and therefore keys — in ascending order.
    void in_order(SortNode* node, uint64_t prefix, int level, std::vector<uint64_t>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(reinterpret_cast<uintptr_t>(
                node->data.load(std::memory_order_acquire)));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit = smallest child index
            mask &= ~(1ULL << i);
            uint64_t next_prefix = (prefix << BITS_PER_LEVEL) | i;
            in_order(node->children[i].load(std::memory_order_acquire),
                     next_prefix, level + 1, out);
        }
    }

    // Recursive teardown; only safe once all writer threads have joined.
    void clear(SortNode* node) {
        if (!node) return;
        for (int i = 0; i < BRANCHES; ++i)
            clear(node->children[i].load(std::memory_order_relaxed));
        delete node;
    }
};
// Benchmark harness (64-bit key variant): inserts N pseudo-random 64-bit
// keys, then looks every key up again, comparing a mutex-guarded std::set
// against the SortTrie, each driven by `num_threads` worker threads.
// Prints wall-clock timings (ms) and lookup hit counts to stdout.
// NOTE(review): assumes num_threads >= 1 — num_threads == 0 would divide
// by zero in the round-robin chunking below.
void bench_multithreaded(size_t N, int num_threads) {
// Fixed seed so every run (and both data structures) sees the same keys.
std::vector<uint64_t> data(N);
std::mt19937_64 rng(12345);
for (auto& x : data) x = rng();
// Pre-split the work round-robin: chunk t is thread t's private input,
// so reading it needs no synchronization.
std::vector<std::vector<uint64_t>> chunks(num_threads);
for (size_t i = 0; i < N; ++i)
chunks[i % num_threads].push_back(data[i]);
// std::set with mutex: a single global lock serializes every insert.
std::set<uint64_t> s;
std::mutex m;
auto set_start = std::chrono::high_resolution_clock::now();
std::vector<std::thread> threads_set;
for (int t = 0; t < num_threads; ++t) {
threads_set.emplace_back([&s, &m, &chunks, t]() {
for (uint64_t x : chunks[t]) {
std::lock_guard<std::mutex> lock(m);
s.insert(x);
}
});
}
for (auto& th : threads_set) th.join();
auto set_end = std::chrono::high_resolution_clock::now();
double set_insert_time = std::chrono::duration<double, std::milli>(set_end - set_start).count();
// SortTrie insert: threads insert concurrently with no external lock.
SortTrie trie;
auto trie_start = std::chrono::high_resolution_clock::now();
std::vector<std::thread> threads_trie;
for (int t = 0; t < num_threads; ++t) {
threads_trie.emplace_back([&trie, &chunks, t]() {
for (uint64_t x : chunks[t]) {
trie.insert(x);
}
});
}
for (auto& th : threads_trie) th.join();
auto trie_end = std::chrono::high_resolution_clock::now();
double trie_insert_time = std::chrono::duration<double, std::milli>(trie_end - trie_start).count();
std::cout << "=== Multithreaded Benchmark with N = " << N << ", T = " << num_threads << " (64-bit keys) ===\n";
std::cout << "std::set insert time: " << set_insert_time << " ms\n";
std::cout << "SortTrie insert time: " << trie_insert_time << " ms\n";
// --- Multithreaded lookup ---
// std::set lookup (again serialized by the mutex). The hit counters are
// atomic so threads can tally without extra locking.
auto set_lookup_start = std::chrono::high_resolution_clock::now();
std::atomic<size_t> set_found{0};
std::vector<std::thread> threads_set_lookup;
for (int t = 0; t < num_threads; ++t) {
threads_set_lookup.emplace_back([&s, &m, &chunks, &set_found, t]() {
for (uint64_t x : chunks[t]) {
std::lock_guard<std::mutex> lock(m);
if (s.find(x) != s.end()) {
set_found.fetch_add(1, std::memory_order_relaxed);
}
}
});
}
for (auto& th : threads_set_lookup) th.join();
auto set_lookup_end = std::chrono::high_resolution_clock::now();
double set_lookup_time = std::chrono::duration<double, std::milli>(set_lookup_end - set_lookup_start).count();
// SortTrie lookup: readers take no lock.
auto trie_lookup_start = std::chrono::high_resolution_clock::now();
std::atomic<size_t> trie_found{0};
std::vector<std::thread> threads_trie_lookup;
for (int t = 0; t < num_threads; ++t) {
threads_trie_lookup.emplace_back([&trie, &chunks, &trie_found, t]() {
for (uint64_t x : chunks[t]) {
if (trie.find(x)) {
trie_found.fetch_add(1, std::memory_order_relaxed);
}
}
});
}
for (auto& th : threads_trie_lookup) th.join();
auto trie_lookup_end = std::chrono::high_resolution_clock::now();
double trie_lookup_time = std::chrono::duration<double, std::milli>(trie_lookup_end - trie_lookup_start).count();
// Every key looked up was previously inserted, so both found-counts
// should equal N (duplicate keys are looked up once per occurrence).
std::cout << "std::set lookup time: " << set_lookup_time << " ms (found " << set_found.load() << ")\n";
std::cout << "SortTrie lookup time: " << trie_lookup_time << " ms (found " << trie_found.load() << ")\n";
}
// Entry point: run the set-vs-trie comparison on one million random
// 64-bit keys spread across eight worker threads.
int main() {
    bench_multithreaded(1'000'000, 8);
}
#include <atomic>
#include <cstdint>
#include <cstring>
#include <vector>
#include <mutex>
#include <shared_mutex>
// Generic concurrent radix trie mapping unsigned keys to caller-owned
// void* payloads. 2^BitsPerLevel-way fanout; interior levels hold child
// SortNodes, while the FINAL level's `children` slots store the user's
// data pointers directly (no separate leaf allocation).
// NOTE(review): BitsPerLevel must be <= 6 so the 64-bit sortFinder bitmap
// has one bit per branch — not enforced here; confirm at instantiation.
template <typename KeyT, int BitsPerLevel = 6>
class SortTrie {
static_assert(std::is_unsigned<KeyT>::value, "KeyT must be unsigned");
static constexpr int BRANCHES = 1 << BitsPerLevel;
// Ceiling division: enough levels to consume every bit of KeyT.
static constexpr int LEVELS = (sizeof(KeyT) * 8 + BitsPerLevel - 1) / BitsPerLevel;
struct SortNode {
// Occupancy bitmap: bit i set once slot i has been (or is being)
// populated. Doubles as the sorted-iteration index.
std::atomic<uint64_t> sortFinder{0};
// Interior levels: SortNode*. Final level: the caller's data pointer.
std::atomic<void*> children[BRANCHES];
SortNode() {
for (int i = 0; i < BRANCHES; ++i) children[i] = nullptr;
}
};
SortNode* root;
public:
SortTrie() : root(new SortNode()) {}
// Frees trie nodes only — leaf data pointers remain caller-owned.
// Requires quiescence: no concurrent insert/find during destruction.
~SortTrie() { clear(root, 0); }
// Associate `data` with `key`. Child nodes are published with a CAS, so
// a racing inserter's extra allocation is deleted, not leaked.
void insert(KeyT key, void* data) {
SortNode* node = root;
for (int level = 0; level < LEVELS - 1; ++level) {
int shift = (LEVELS - 1 - level) * BitsPerLevel;
int index = (key >> shift) & (BRANCHES - 1);
// The bit is set before the child is published; readers must
// tolerate a set bit with a still-null slot (find() does).
node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
void* child = node->children[index].load(std::memory_order_acquire);
if (!child) {
auto* new_node = new SortNode();
// On failure the CAS refreshes `child` with the winner's pointer.
if (!node->children[index].compare_exchange_strong(child, new_node))
delete new_node; // another thread inserted
else
child = new_node;
}
node = reinterpret_cast<SortNode*>(child);
}
// Leaf level: the payload goes straight into the child slot.
int final_index = key & (BRANCHES - 1);
node->sortFinder.fetch_or(1ULL << final_index, std::memory_order_relaxed);
node->children[final_index].store(data, std::memory_order_release);
}
// Return the payload stored for `key`, or nullptr if absent.
void* find(KeyT key) const {
SortNode* node = root;
for (int level = 0; level < LEVELS - 1; ++level) {
int shift = (LEVELS - 1 - level) * BitsPerLevel;
int index = (key >> shift) & (BRANCHES - 1);
if (!(node->sortFinder.load(std::memory_order_relaxed) & (1ULL << index)))
return nullptr;
void* child = node->children[index].load(std::memory_order_acquire);
if (!child) return nullptr;
node = reinterpret_cast<SortNode*>(child);
}
int final_index = key & (BRANCHES - 1);
if (!(node->sortFinder.load(std::memory_order_relaxed) & (1ULL << final_index)))
return nullptr;
return node->children[final_index].load(std::memory_order_acquire);
}
// Append every stored key to `out` in ascending order.
void collect_sorted(std::vector<KeyT>& out) const {
in_order(root, 0, 0, out);
}
private:
// DFS that consumes set bits low-to-high, so children — and therefore
// keys — are visited in ascending order; `prefix` accumulates key bits.
void in_order(SortNode* node, KeyT prefix, int level, std::vector<KeyT>& out) const {
if (!node) return;
uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
while (mask) {
int i = __builtin_ctzll(mask); // lowest set bit (GCC/Clang builtin)
mask &= ~(1ULL << i);
KeyT next_prefix = (prefix << BitsPerLevel) | i;
void* child = node->children[i].load(std::memory_order_acquire);
if (level == LEVELS - 1) {
out.push_back(next_prefix); // reached leaf, where child is actual data
} else {
in_order(reinterpret_cast<SortNode*>(child), next_prefix, level + 1, out);
}
}
}
// Recursively delete interior nodes; `level` distinguishes SortNode
// children from caller-owned data pointers at the final level.
void clear(SortNode* node, int level) {
if (!node) return;
for (int i = 0; i < BRANCHES; ++i) {
void* ptr = node->children[i].load();
if (ptr) {
if (level < LEVELS - 1)
clear(reinterpret_cast<SortNode*>(ptr), level + 1);
// else it's a data pointer and we don't own/delete it
}
}
delete node;
}
};
#include <iostream>
#include <vector>
#include <atomic>
#include <cstdint>
#include <cassert>
#include <memory>
#include <type_traits>  // std::is_unsigned, used by SortTrie's static_assert
/// Concurrent set of unsigned integer keys stored in a fixed-depth radix
/// trie (2^BITS_PER_LEVEL-way fanout). insert() and find() are safe from
/// multiple threads; collect_sorted() and destruction require quiescence
/// (no concurrent writers).
///
/// Fix over the original: `children` and `data` were plain pointers that
/// were CAS'd through a reinterpret_cast to std::atomic and otherwise read
/// and written non-atomically — undefined-behavior data races. They are now
/// genuine atomics with acquire/release publication.
template <typename KeyType = uint32_t, int BITS_PER_LEVEL = 6>
class SortTrie {
    static_assert(std::is_unsigned<KeyType>::value, "KeyType must be unsigned integer");
    static_assert(BITS_PER_LEVEL >= 1 && BITS_PER_LEVEL <= 6,
                  "sortFinder is a 64-bit bitmap: at most 2^6 branches per node");
    static constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
    // Ceiling division: enough levels to consume every bit of KeyType.
    static constexpr int LEVELS = (sizeof(KeyType) * 8 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL;

    struct SortNode {
        std::atomic<uint64_t> sortFinder{0};        // occupancy bitmap; doubles as sorted index
        std::atomic<void*> data{nullptr};           // leaf payload: the key, published with release
        std::atomic<SortNode*> children[BRANCHES];  // atomic so concurrent insert/find are race-free
        SortNode() {
            for (int i = 0; i < BRANCHES; ++i)
                children[i].store(nullptr, std::memory_order_relaxed);
        }
        ~SortNode() {
            // Recursive teardown; leaf nodes hold only null children.
            for (int i = 0; i < BRANCHES; ++i)
                delete children[i].load(std::memory_order_relaxed);
        }
    };

public:
    SortTrie() : root(new SortNode()) {}
    ~SortTrie() { delete root; }

    /// Add `key` to the set (idempotent). Child nodes are published with a
    /// CAS so racing inserters neither leak nor lose a node.
    void insert(KeyType key) {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
            SortNode* child = node->children[index].load(std::memory_order_acquire);
            if (!child) {
                SortNode* fresh = new SortNode();
                SortNode* expected = nullptr;
                if (node->children[index].compare_exchange_strong(
                        expected, fresh, std::memory_order_acq_rel)) {
                    child = fresh;
                } else {
                    delete fresh;        // lost the publish race; `expected` is the winner
                    child = expected;
                }
            }
            node = child;
        }
        // The leaf remembers the key itself, smuggled through a void*
        // (release pairs with find's acquire load).
        node->data.store(reinterpret_cast<void*>(static_cast<uintptr_t>(key)),
                         std::memory_order_release);
    }

    /// True iff `key` was inserted. Lock-free for readers.
    bool find(KeyType key) const {
        SortNode* node = root;
        for (int i = LEVELS - 1; i >= 0; --i) {
            int shift = i * BITS_PER_LEVEL;
            int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
            uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
            if (!(mask & (1ULL << index))) return false;
            // A bit may become visible before its child is published; tolerate null.
            node = node->children[index].load(std::memory_order_acquire);
            if (!node) return false;
        }
        return reinterpret_cast<uintptr_t>(
                   node->data.load(std::memory_order_acquire)) == key;
    }

    /// Append every stored key to `out` in ascending order.
    void collect_sorted(std::vector<KeyType>& out) const {
        in_order(root, 0, 0, out);
    }

private:
    SortNode* root;

    // DFS that consumes set bits from least to most significant, visiting
    // children — and therefore keys — in ascending order.
    void in_order(SortNode* node, KeyType prefix, int level, std::vector<KeyType>& out) const {
        if (!node) return;
        if (level == LEVELS) {
            out.push_back(static_cast<KeyType>(reinterpret_cast<uintptr_t>(
                node->data.load(std::memory_order_acquire))));
            return;
        }
        uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
        while (mask) {
            int i = __builtin_ctzll(mask);   // lowest set bit = smallest child index
            mask &= ~(1ULL << i);
            KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
            in_order(node->children[i].load(std::memory_order_acquire),
                     next_prefix, level + 1, out);
        }
    }
};
#include <iostream>
#include <vector>
#include <atomic>
#include <cstdint>
#include <cassert>
#include <memory>
// Set of unsigned integer keys in a fixed-depth radix trie, storing each
// key itself as the leaf payload, plus an inclusive range query.
// NOTE(review): insert() CASes through a reinterpret_cast of a plain
// SortNode* slot to std::atomic<SortNode*> and reads `children` / writes
// `data` non-atomically — the C++ standard does not guarantee this is safe
// under concurrency; treat multi-threaded writes as unverified.
template <typename KeyType = uint32_t, int BITS_PER_LEVEL = 6>
class SortTrie {
static_assert(std::is_unsigned<KeyType>::value, "KeyType must be unsigned integer");
static constexpr int BRANCHES = 1 << BITS_PER_LEVEL;
// Ceiling division: enough levels to consume every bit of KeyType.
static constexpr int LEVELS = (sizeof(KeyType) * 8 + BITS_PER_LEVEL - 1) / BITS_PER_LEVEL;
struct SortNode {
std::atomic<uint64_t> sortFinder{0}; // occupancy bitmap: bit i set => children[i] populated
void* data = nullptr; // leaf payload: the inserted key stored as a pointer-sized integer
SortNode* children[BRANCHES] = {nullptr};
// Recursive teardown: deleting the root frees the entire trie.
~SortNode() {
for (int i = 0; i < BRANCHES; ++i) {
delete children[i];
}
}
};
public:
SortTrie() : root(new SortNode()) {}
~SortTrie() { delete root; }
// Add `key` to the set (idempotent for duplicates).
void insert(KeyType key) {
SortNode* node = root;
for (int i = LEVELS - 1; i >= 0; --i) {
int shift = i * BITS_PER_LEVEL;
int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
node->sortFinder.fetch_or(1ULL << index, std::memory_order_relaxed);
SortNode*& child = node->children[index];
if (!child) {
SortNode* new_node = new SortNode();
SortNode* expected = nullptr;
// NOTE(review): UB — `child` is not a std::atomic object;
// see the class-level concurrency note.
if (!std::atomic_compare_exchange_strong(
reinterpret_cast<std::atomic<SortNode*>*>(&child),
&expected, new_node)) {
delete new_node;
}
}
node = node->children[index];
}
// The leaf stores the key itself, smuggled through a void*.
node->data = reinterpret_cast<void*>(static_cast<uintptr_t>(key));
}
// True iff `key` was inserted.
bool find(KeyType key) const {
SortNode* node = root;
for (int i = LEVELS - 1; i >= 0; --i) {
int shift = i * BITS_PER_LEVEL;
int index = static_cast<int>((key >> shift) & (BRANCHES - 1));
uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
if (!(mask & (1ULL << index))) return false;
node = node->children[index];
if (!node) return false;
}
return reinterpret_cast<uintptr_t>(node->data) == key;
}
// Append every stored key to `out` in ascending order.
void collect_sorted(std::vector<KeyType>& out) const {
in_order(root, 0, 0, out);
}
// New query method: collects keys between low and high inclusive
void query_range(KeyType low, KeyType high, std::vector<KeyType>& out) const {
query_range_helper(root, 0, 0, low, high, out);
}
private:
SortNode* root;
// DFS that consumes set bits low-to-high => keys emitted in ascending order.
void in_order(SortNode* node, KeyType prefix, int level, std::vector<KeyType>& out) const {
if (!node) return;
if (level == LEVELS) {
out.push_back(static_cast<KeyType>(reinterpret_cast<uintptr_t>(node->data)));
return;
}
uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
while (mask) {
int i = __builtin_ctzll(mask); // lowest set bit (GCC/Clang builtin)
mask &= ~(1ULL << i);
KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
in_order(node->children[i], next_prefix, level + 1, out);
}
}
// Range DFS: derives each child's covered key interval [child_min,
// child_max] from the accumulated prefix and prunes whole subtrees that
// fall outside [low, high]; leaves are filtered exactly.
void query_range_helper(SortNode* node, KeyType prefix, int level, KeyType low, KeyType high, std::vector<KeyType>& out) const {
if (!node) return;
if (level == LEVELS) {
KeyType val = static_cast<KeyType>(reinterpret_cast<uintptr_t>(node->data));
if (val >= low && val <= high) {
out.push_back(val);
}
return;
}
// Calculate the bit range that this node covers for keys
int shift = (LEVELS - level - 1) * BITS_PER_LEVEL;
KeyType node_range_size = KeyType(1) << shift;
uint64_t mask = node->sortFinder.load(std::memory_order_relaxed);
while (mask) {
int i = __builtin_ctzll(mask);
mask &= ~(1ULL << i);
// Smallest key reachable through child i: prefix bits, then i, then zeros.
KeyType child_min = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
child_min <<= shift;
KeyType child_max = child_min + node_range_size - 1;
// Prune subtree if out of range
if (child_max < low || child_min > high) {
continue;
}
// Otherwise recurse
KeyType next_prefix = (prefix << BITS_PER_LEVEL) | static_cast<KeyType>(i);
query_range_helper(node->children[i], next_prefix, level + 1, low, high, out);
}
}
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment