Last active
March 21, 2020 10:12
-
-
Save oxyflour/d2e9e360496778c2d6fe3ee0893dfd07 to your computer and use it in GitHub Desktop.
Lock-free set
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h>
#include <cmath>
#include <memory>
#include <type_traits>
| // Reference: | |
| // https://nosferalatu.com/SimpleGPUHashTable.html | |
| // When building with CUDA, replace the host function below with CUDA's built-in atomicCAS. | |
// Host-side stand-ins for CUDA's atomicCAS.
// NOTE: these are plain read-compare-write sequences and are NOT atomic; they only let
// the algorithm run single-threaded on the CPU. When building with nvcc, delete these
// definitions and use the built-in device atomicCAS instead.
//
// Shared implementation: stores `val` into `*addr` iff `*addr == comp`, and returns the
// value `*addr` held before the call (so the caller can tell whether it won).
template <typename Word>
Word hostCompareAndSwap(Word *addr, Word comp, Word val) {
    const Word prev = *addr;
    if (prev == comp) {
        *addr = val;
    }
    return prev;
}

// Original signature, kept so existing callers continue to compile.
unsigned long atomicCAS(unsigned long *addr, unsigned long comp, unsigned long val) {
    return hostCompareAndSwap(addr, comp, val);
}

// 64-bit overload matching CUDA's `unsigned long long atomicCAS(...)`. CUDA has no
// `unsigned long` overload, and `unsigned long` is only 32 bits on LLP64 platforms.
unsigned long long atomicCAS(unsigned long long *addr, unsigned long long comp,
                             unsigned long long val) {
    return hostCompareAndSwap(addr, comp, val);
}

// Compare-and-swap on a pointer slot: if `*addr == comp`, store `val`.
// Returns the pointer previously held in `*addr` (equal to `comp` on success).
// The slot is reinterpreted as an unsigned integer word of exactly the pointer's width so
// the call maps onto atomicCAS. (Always using `unsigned long` truncated 64-bit pointers
// wherever `unsigned long` is 32 bits.) Like typical CUDA code, this relies on pointer/
// integer type punning.
template<typename T> T* cas(T **addr, T *comp, T* val) {
    using Word = std::conditional_t<sizeof(T *) == sizeof(unsigned long long),
                                    unsigned long long, unsigned long>;
    static_assert(sizeof(Word) == sizeof(T *),
                  "no atomicCAS word type matches the pointer width");
    const Word prev = atomicCAS(reinterpret_cast<Word *>(addr),
                                reinterpret_cast<Word>(comp),
                                reinterpret_cast<Word>(val));
    return reinterpret_cast<T *>(prev);
}
// Scale applied to coordinates before hashing (hash resolution of 1e-6 per unit).
const int FLOAT_FAC = 1e6;

// 2-D point used as the element type of `set`.
// NOTE(review): CUDA's <vector_types.h> also defines a `float2`; this definition will
// clash with it if the file is compiled with nvcc — consider renaming when porting.
struct float2 {
    float x;
    float y;

    // Hash of the scaled coordinate sum x*FLOAT_FAC + y*FLOAT_FAC, truncated toward zero
    // and wrapped to unsigned. Points with the same x + y (e.g. (1,2) and (2,1)) collide.
    unsigned int hash() const {
        const float scaled = x * FLOAT_FAC + y * FLOAT_FAC;
        // Converting a float whose truncated value does not fit in an int is undefined
        // behaviour (the original did this for |x + y| >= ~2147.48), so take the direct
        // conversion only when it is in range.
        if (scaled >= -2147483648.0f && scaled < 2147483648.0f) {
            return static_cast<unsigned int>(static_cast<int>(scaled));
        }
        if (!std::isfinite(scaled)) {
            return 0u;  // NaN / infinity: any fixed value is a valid hash
        }
        // Floats this large are integral; reduce modulo 2^32 so the result agrees with the
        // int -> unsigned wrap-around of the in-range path.
        const double reduced = std::fmod(static_cast<double>(scaled), 4294967296.0);
        return static_cast<unsigned int>(static_cast<long long>(reduced));
    }

    // True when *v lies strictly within distance 1e-3 of this point (squared distance
    // below 1e-6). NOTE: this tolerance is much coarser than the hash resolution, so two
    // points that compare equal may still hash to different slots.
    bool eq(const float2 *v) const {
        const float dx = v->x - x;
        const float dy = v->y - y;
        return dx * dx + dy * dy < 1e-6;
    }
};
// Fixed-capacity hash set using open addressing with linear probing
// (after https://nosferalatu.com/SimpleGPUHashTable.html).
//
// Requirements on T: default-constructible, copy-assignable, and providing
//   hash()   -> an integral hash value
//   eq(T *)  -> bool, true when two elements are to be treated as the same element.
//
// Each slot of `ptrs` is nullptr (never used), tombstone() (element removed), or a
// pointer to the set's own copy of the element in `data`. Empty slots are claimed with
// cas().
// NOTE(review): `size`, removals and the iteration cursor are plain non-atomic updates,
// and the host cas() stand-in is not atomic either, so use from one thread at a time.
template <typename T> class set {
    int maxn = 1024;  // number of slots (capacity)
    int size = 0;     // number of live elements
    // ptrs[i]: nullptr (empty), tombstone() (removed) or data.get() + i (occupied).
    std::unique_ptr<T *[]> ptrs;
    // Element storage; the extra trailing element serves as the tombstone sentinel.
    std::unique_ptr<T[]> data;

    // Sentinel marking a removed slot: a real element address that no slot ever points to
    // as a live entry, so it cannot be confused with one.
    T *tombstone() const {
        return data.get() + maxn;
    }

    // Home slot of *val in [0, maxn). Widening to unsigned keeps the slot non-negative
    // even if T::hash() returns a negative signed value.
    int homeSlot(T *val) const {
        return static_cast<int>(static_cast<unsigned long long>(val->hash()) %
                                static_cast<unsigned long long>(maxn));
    }

public:
    // Creates an empty set with `maxn` slots; values below 1 are treated as 1 (a zero
    // capacity would otherwise make every probe divide by zero).
    set(int maxn = 1024) {
        this->maxn = maxn < 1 ? 1 : maxn;
        // make_unique<U[]> value-initialises, so every slot starts out as nullptr.
        ptrs = std::make_unique<T *[]>(this->maxn);
        data = std::make_unique<T[]>(this->maxn + 1);
    }

    // Slots point into this object's own storage, so a memberwise copy would alias it
    // (and, with the former raw arrays, double-delete it): the set is not copyable.
    set(const set &) = delete;
    set &operator=(const set &) = delete;

    // Number of elements currently stored.
    int len() const {
        return size;
    }

    // Inserts a copy of *val unless an eq()-matching element is already present.
    // Returns the slot holding the new or existing element, or -1 if `val` is null or the
    // element is absent and no free slot exists.
    int add(T *val) {
        if (val == nullptr) {
            return -1;
        }
        const int home = homeSlot(val);
        for (;;) {
            int target = -1;        // slot to claim if *val turns out to be absent
            T *expected = nullptr;  // what `target` held when it was chosen
            int slot = home;
            for (int count = 0; count < maxn; ++count, slot = (slot + 1) % maxn) {
                T *const cur = ptrs[slot];
                if (cur == nullptr) {
                    // End of the probe chain: *val is absent. Prefer an earlier tombstone.
                    if (target < 0) {
                        target = slot;
                    }
                    break;
                }
                if (cur == tombstone()) {
                    // Remember the first reusable slot but keep probing: a matching element
                    // may be stored further along the chain.
                    if (target < 0) {
                        target = slot;
                        expected = cur;
                    }
                    continue;
                }
                if (cur->eq(val)) {
                    return slot;
                }
            }
            if (target < 0) {
                return -1;  // every slot holds some other element
            }
            // Claim the slot with the caller's pointer, then swap in the set's own copy so
            // the caller's object is not referenced after add() returns.
            if (cas(ptrs.get() + target, expected, val) == expected) {
                data[target] = *val;
                ptrs[target] = data.get() + target;
                ++size;
                return target;
            }
            // The slot changed between probing and the CAS: probe again.
        }
    }

    // Removes the element eq()-matching *val. Returns its former slot index, or -1 if it
    // is not present (or `val` is null). The slot becomes a tombstone rather than nullptr
    // so probe chains passing through it still reach elements stored further along.
    int remove(T *val) {
        if (val == nullptr) {
            return -1;
        }
        int slot = homeSlot(val);
        for (int count = 0; count < maxn; ++count, slot = (slot + 1) % maxn) {
            T *const cur = ptrs[slot];
            if (cur == nullptr) {
                return -1;
            }
            if (cur != tombstone() && cur->eq(val)) {
                ptrs[slot] = tombstone();
                --size;
                return slot;
            }
        }
        return -1;  // probed every slot without a match
    }

    // Iteration: `for (T *p = s.begin(); p; p = s.next())`. The cursor lives in the set
    // itself, so only one iteration may be in progress at a time.
private:
    int index = 0;

    // Advances `index` to the next live slot and returns its element, or nullptr at the end.
    T *iter() {
        while (index < maxn && (ptrs[index] == nullptr || ptrs[index] == tombstone())) {
            ++index;
        }
        return index < maxn ? ptrs[index] : nullptr;
    }

public:
    // Restarts iteration and returns the first element, or nullptr if the set is empty.
    T *begin() {
        index = 0;
        return iter();
    }

    // Returns the element after the one last returned, or nullptr when exhausted.
    T *next() {
        ++index;
        return iter();
    }
};
| int main() { | |
| set<float2> s; | |
| for (auto i = 0; i < 40; i ++) { | |
| float2 p = { i * 0.1, 0.2 }; | |
| s.add(&p); | |
| } | |
| printf("len: %d\n", s.len()); | |
| for (auto i = s.begin(); i; i = s.next()) { | |
| printf("p: %f %f\n", i->x, i->y); | |
| } | |
| return 0; | |
| } |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.