Skip to content

Instantly share code, notes, and snippets.

@jweinst1
Created February 10, 2025 23:21
Show Gist options
  • Save jweinst1/1ea695a1f3db66bd519422436737457c to your computer and use it in GitHub Desktop.
Save jweinst1/1ea695a1f3db66bd519422436737457c to your computer and use it in GitHub Desktop.
Hash map that stores pointers, using hazard-style reference counting (RC)
#include <atomic>
#include <thread>
#include <chrono>
#include <cstdint>
#include <limits>
#include <cstdio>
#include <string>
#include <memory>
#include <vector>
/// One slot of a pointer table: an atomically published pointer plus a
/// reference count.
///
/// The slot is empty while `_ptr` is null. claim() publishes a pointer with a
/// count of one. The decRef() call that drops the count from one to zero
/// swaps `_ptr` back to null and returns the previously stored pointer.
/// All atomic operations use the default (sequentially consistent) ordering.
struct PtrStore {
    std::atomic<size_t> _refCount = 0;
    std::atomic<void*> _ptr = nullptr;

    /// Tries to publish `ptr` into this slot.
    /// @param ptr pointer to store; must be non-null.
    /// @return true if the slot was empty and now holds `ptr` with a count of
    ///         one; false if the slot was occupied or `ptr` is null.
    [[nodiscard]] bool claim(void* ptr) {
        // A null pointer is indistinguishable from "empty". Accepting it would
        // let a later claim() reset the refcount of this "owner" and have its
        // decRef() steal the next pointer stored here.
        if (ptr == nullptr) return false;
        void* current = _ptr.load();
        if (current != nullptr) return false;  // cheap read before the CAS
        if (!_ptr.compare_exchange_strong(current, ptr)) return false;
        // Between the CAS and this store the count is still 0, so concurrent
        // incRef()/decRef() calls briefly treat the slot as unreferenced.
        _refCount.store(1);
        return true;
    }

    /// Drops one reference, never wrapping the count below zero.
    /// @return the stored pointer if this call moved the count from one to
    ///         zero (the slot is reset to null); nullptr otherwise, including
    ///         when the count was already zero.
    [[nodiscard]] void* decRef() {
        size_t seen = _refCount.load();
        do {
            if (seen == 0) return nullptr;
        } while (!_refCount.compare_exchange_weak(seen, seen - 1));
        return (seen == 1) ? _ptr.exchange(nullptr) : nullptr;
    }

    /// Adds a reference unconditionally, even when the count is zero.
    void incRefUnchecked() {
        _refCount.fetch_add(1);
    }

    /// Adds a reference only if the count is currently non-zero.
    /// @return true on success; false if the count was zero.
    [[nodiscard]] bool incRef() {
        size_t seen = _refCount.load();
        do {
            if (seen == 0) return false;
        } while (!_refCount.compare_exchange_weak(seen, seen + 1));
        return true;
    }
};
struct PtrHash {
size_t size = 0;
PtrStore* table = nullptr;
std::atomic<PtrHash*> next = nullptr;
PtrHash(size_t width) {
size = width;
table = new PtrStore[size];
}
PtrStore* findAndStore(void* obj) {
size_t hashSpot = ((size_t)obj) % size;
size_t attempts = 0;
size_t skew = hashSpot ^ (size_t)obj;
PtrStore* candidate = &table[hashSpot];
while (!candidate->claim(obj)) {
//printf("hash %zu\n", hashSpot);
skew = hashSpot ^ skew;
hashSpot = skew % size;
candidate = &table[hashSpot];
if (++attempts > size)
return nullptr;
}
return candidate;
}
~PtrHash() {
delete[] table;
}
};
/// Stress test: kThreads threads share one 49-slot PtrHash for about 500 ms.
/// In each iteration a thread allocates an int, stores it in the table, drops
/// the slot's only reference, and frees the int.
/// Each thread then prints how many stores succeeded and how many found no
/// claimable slot.
static void thread_test() {
    using namespace std::chrono_literals;
    constexpr size_t kThreads = 8;
    constexpr size_t usedWidth = 49;
    std::thread tpool[kThreads];
    std::atomic<bool> keepGoing = true;
    PtrHash h(usedWidth);
    for (std::thread& worker : tpool) {
        worker = std::thread([&] {
            size_t counter = 0;
            size_t failed = 0;
            while (keepGoing.load()) {
                int* num = new int(8);
                PtrStore* foo = h.findAndStore(num);
                if (foo != nullptr) {
                    // Dropping the count from one to zero returns the stored
                    // pointer and empties the slot. Free it, otherwise every
                    // successful iteration leaks one int.
                    delete static_cast<int*>(foo->decRef());
                    ++counter;
                } else {
                    // Never published into the table, so still ours to free.
                    delete num;
                    ++failed;
                }
            }
            printf("%zu ops completed, %zu ops failed\n", counter, failed);
        });
    }
    std::this_thread::sleep_for(500ms);
    keepGoing.store(false);
    printf("now joining\n");
    for (std::thread& worker : tpool) {
        worker.join();
    }
}
/// Program entry point: runs the multithreaded table stress test.
/// Command-line arguments are accepted but not used.
int main([[maybe_unused]] int argc, [[maybe_unused]] char const* argv[]) {
    thread_test();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment