Skip to content

Instantly share code, notes, and snippets.

@oxyflour
Last active March 21, 2020 10:12
Show Gist options
  • Select an option

  • Save oxyflour/d2e9e360496778c2d6fe3ee0893dfd07 to your computer and use it in GitHub Desktop.

Select an option

Save oxyflour/d2e9e360496778c2d6fe3ee0893dfd07 to your computer and use it in GitHub Desktop.
lock free set
#include <stdio.h>
// Reference:
// https://nosferalatu.com/SimpleGPUHashTable.html
// Replace the function below with the CUDA implementation.
// Sequential stand-in for a compare-and-swap primitive (meant to be swapped
// for the CUDA one). It uses plain loads and stores, so it is NOT atomic.
//
// addr: location to update.
// comp: value *addr is expected to hold.
// val:  value written to *addr when the expectation holds.
// Returns the value *addr held on entry, whether or not the swap happened.
unsigned long atomicCAS(unsigned long *addr, unsigned long comp, unsigned long val) {
    const unsigned long observed = *addr;
    if (observed == comp) {
        *addr = val;
    }
    return observed;
}
// Pointer-typed compare-and-swap layered on the integer atomicCAS.
//
// addr: slot holding a T* to update.
// comp: pointer value the slot is expected to hold.
// val:  pointer stored into the slot when the expectation holds.
// Returns the pointer the slot held before the call; the swap happened
// exactly when that return value equals comp.
template<typename T> T* cas(T **addr, T *comp, T* val) {
    // atomicCAS works on unsigned long, so a pointer has to fit in one exactly.
    // That is not the case on LLP64 targets (e.g. 64-bit Windows); fail the
    // build there instead of silently truncating pointers.
    static_assert(sizeof(unsigned long) == sizeof(T *),
                  "cas() requires unsigned long to have the same size as a pointer");
    return reinterpret_cast<T *>(atomicCAS(reinterpret_cast<unsigned long *>(addr),
                                           reinterpret_cast<unsigned long>(comp),
                                           reinterpret_cast<unsigned long>(val)));
}
// Scale applied to coordinates before they are truncated to an integer in
// float2::hash().
constexpr int FLOAT_FAC = 1000000;

// 2-D point that provides the hash()/eq() pair the set template calls.
typedef struct float2 {
    float x;
    float y;

    // Hash of the point: (x + y) scaled by FLOAT_FAC, truncated toward zero to
    // an int, then converted to unsigned int.
    // A float -> int conversion is undefined when the value does not fit, so
    // out-of-range sums saturate to the int limits and NaN hashes to 0.
    unsigned int hash() const {
        const float scaled = x * FLOAT_FAC + y * FLOAT_FAC;
        int quantized = 0;  // stays 0 when scaled is NaN (all comparisons fail)
        if (scaled >= 2147483648.0f) {
            quantized = 2147483647;
        } else if (scaled <= -2147483648.0f) {
            quantized = -2147483647 - 1;
        } else if (scaled == scaled) {
            quantized = static_cast<int>(scaled);
        }
        return static_cast<unsigned int>(quantized);
    }

    // True when the squared distance between *v and this point is below 1e-6.
    // NOTE(review): two points that compare equal here can still produce
    // different hash() values, so a hash-based lookup may miss such a match.
    bool eq(const float2 *v) const {
        const float dx = v->x - x;
        const float dy = v->y - y;
        return dx * dx + dy * dy < 1e-6;
    }
} float2;
// Fixed-capacity hash set using open addressing with linear probing.
//
// T must be default-constructible and copy-assignable and must provide
//   hash()      -> integer used to pick the starting slot, and
//   eq(T *)     -> true when the argument equals the element.
//
// Elements are copied into internal storage, so callers may pass pointers to
// temporaries. An empty slot is claimed through cas(); every other access
// (element counter, tombstone writes, iteration cursor) is a plain,
// unsynchronized read or write.
template <typename T> class set {
    // Slot table. Each entry is nullptr (never used), tombstone() (element was
    // removed) or a pointer to the element stored for that slot.
    T **ptrs;
    // Backing storage: the element of slot i lives in data[i].
    T *data;
    int size = 0;
    int maxn = 1024;

    // Marker for a removed slot. One-past-the-end of data can never be the
    // address of a stored element and is never dereferenced. Keeping removed
    // slots non-null lets probe chains that run through them stay intact.
    T *tombstone() const {
        return data + maxn;
    }

    // True when the slot currently holds an element.
    bool occupied(int slot) const {
        return ptrs[slot] != nullptr && ptrs[slot] != tombstone();
    }

    // First slot to probe for val. Requires maxn > 0. The unsigned arithmetic
    // keeps the result in [0, maxn) even for a hash() that returns a negative
    // value.
    int home(T *val) const {
        return static_cast<int>(static_cast<unsigned long long>(val->hash()) %
                                static_cast<unsigned long long>(maxn));
    }

    // Copies *val into an already claimed slot and publishes the internal copy.
    int store(int slot, T *val) {
        data[slot] = *val;
        ptrs[slot] = data + slot;
        size ++;
        return slot;
    }

public:
    // Creates a set with room for maxn elements (non-positive values give a
    // set that rejects every insertion).
    set(int maxn = 1024) {
        this->maxn = maxn > 0 ? maxn : 0;
        ptrs = new T *[this->maxn];
        data = new T[this->maxn];
        for (int i = 0; i < this->maxn; i ++) {
            ptrs[i] = nullptr;
        }
    }

    ~set() {
        delete [] ptrs;
        delete [] data;
    }

    // The set owns raw arrays; a member-wise copy would free them twice.
    set(const set &) = delete;
    set &operator=(const set &) = delete;

    // Number of elements currently stored.
    int len() const {
        return size;
    }

    // Inserts a copy of *val unless an equal element is already present.
    // Returns the slot holding the element (new or existing), or -1 when no
    // slot is available.
    int add(T *val) {
        if (maxn <= 0) {
            return -1;
        }
        T *const dead = tombstone();
        int firstDead = -1;
        int slot = home(val);
        for (int probes = 0; probes < maxn; probes ++) {
            T *prev = cas(ptrs + slot, static_cast<T *>(nullptr), val);
            if (prev == nullptr) {
                // Claimed a never-used slot. Slots never return to nullptr, so
                // reaching one proves no equal element exists further along.
                return store(slot, val);
            }
            if (prev == dead) {
                if (firstDead < 0) {
                    firstDead = slot;
                }
            } else if (prev->eq(val)) {
                return slot;
            }
            slot = (slot + 1) % maxn;
        }
        // Every slot was visited without finding a free one or a match; fall
        // back to recycling a removed slot if one was seen.
        if (firstDead >= 0 && cas(ptrs + firstDead, dead, val) == dead) {
            return store(firstDead, val);
        }
        return -1;
    }

    // Removes the element equal to *val.
    // Returns the slot it occupied, or -1 when no such element is stored.
    int remove(T *val) {
        if (maxn <= 0) {
            return -1;
        }
        int slot = home(val);
        for (int probes = 0; probes < maxn; probes ++) {
            T *cur = ptrs[slot];
            if (cur == nullptr) {
                return -1;
            }
            if (cur != tombstone() && cur->eq(val)) {
                // Leave a tombstone rather than nullptr so elements that were
                // placed after this one in the probe chain remain reachable.
                ptrs[slot] = tombstone();
                size --;
                return slot;
            }
            slot = (slot + 1) % maxn;
        }
        return -1;
    }

    // iterators
private:
    // Cursor shared by begin()/next(); only one traversal can be active.
    int index = 0;

    // Moves the cursor forward to the next occupied slot (staying put if it is
    // already on one) and returns that element, or nullptr past the end.
    T *iter() {
        while (index < maxn && !occupied(index)) {
            index ++;
        }
        return index < maxn ? ptrs[index] : nullptr;
    }

public:
    // Restarts the traversal; returns the first element or nullptr if empty.
    T *begin() {
        index = 0;
        return iter();
    }

    // Advances the traversal; returns the next element or nullptr at the end.
    T *next() {
        if (index < maxn) {
            index ++;
        }
        return iter();
    }
};
int main() {
set<float2> s;
for (auto i = 0; i < 40; i ++) {
float2 p = { i * 0.1, 0.2 };
s.add(&p);
}
printf("len: %d\n", s.len());
for (auto i = s.begin(); i; i = s.next()) {
printf("p: %f %f\n", i->x, i->y);
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment