Created
April 4, 2013 15:26
-
-
Save wilburding/5311360 to your computer and use it in GitHub Desktop.
A lock-free stack.
Original work: http://developers.memsql.com/blog/common-pitfalls-in-writing-lock-free-algorithms/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using namespace std;
// Number of loop iterations each worker thread runs in both tests.
constexpr int ITERATIONS = 10 * 1000 * 1000;
// Exponential ("binary") back-off helper for retry loops.
//
// Every call to backoff() sleeps for the current delay and then doubles that
// delay, clamping it so it never grows past the configured maximum.
class BinaryBackOff
{
public:
    // initial_backoff: delay slept by the first backoff() call.
    // max_backoff:     upper bound the delay is clamped to while it doubles.
    // Author's note: the smallest effective sleep time is about 20us.
    BinaryBackOff(std::chrono::microseconds initial_backoff = std::chrono::microseconds{10},
                  std::chrono::microseconds max_backoff = std::chrono::microseconds{640})
        : current_backoff_(initial_backoff),
          max_backoff_(max_backoff)
    {}

    // Sleeps for the current delay, then doubles the delay for the next call.
    // An unconditional early `return;` previously turned this whole function
    // into a no-op, so callers retried without ever yielding.
    void backoff()
    {
        std::this_thread::sleep_for(current_backoff_);
        if(current_backoff_ < max_backoff_)
        {
            // Clamp so a doubling step can never overshoot the maximum.
            current_backoff_ = std::min(current_backoff_ * 2, max_backoff_);
        }
    }

private:
    std::chrono::microseconds current_backoff_; // delay used by the next backoff()
    std::chrono::microseconds max_backoff_;     // ceiling for current_backoff_
};
// A pointer paired with a 64-bit modification counter, updated together by a
// single 16-byte compare-and-swap (x86-64 `lock cmpxchg16b`, GCC/Clang inline
// asm). The counter lets callers distinguish "same pointer value" from
// "nothing has changed since I looked".
//
// The object must be 16-byte aligned for cmpxchg16b, hence alignas(16).
template<class T>
class alignas(16) TaggedPointer
{
public:
    // Starts with the given pointer and a counter of zero.
    explicit TaggedPointer(T* ptr = nullptr) noexcept
        : ptr_(ptr),
          counter_(0)
    {}

    // Acquire-load of the pointer half only.
    inline T* load_ptr() const noexcept
    {
        return ptr_.load(std::memory_order_acquire);
    }

    // Acquire-load of the counter half only. This is a separate load from
    // load_ptr(): the two together are not an atomic snapshot; callers rely on
    // compare_exchange() to validate the pair.
    inline std::uint64_t load_counter() const noexcept
    {
        return counter_.load(std::memory_order_acquire);
    }

    // Atomically replaces (pointer, counter) with (desired_ptr, desired_counter)
    // if and only if the current pair equals (expected_ptr, expected_counter).
    // Returns true when the swap happened. The expected values are taken by
    // value, so the observed pair is not reported back on failure.
    inline bool compare_exchange(T* expected_ptr, std::uint64_t expected_counter, T* desired_ptr, std::uint64_t desired_counter) noexcept
    {
        // cmpxchg16b operates on exactly 16 bytes: ptr_ in the low half
        // (compared with RAX, replaced from RBX) and counter_ in the high half
        // (compared with RDX, replaced from RCX).
        static_assert(sizeof(TaggedPointer) == 16,
                      "TaggedPointer must be exactly 16 bytes for cmpxchg16b");
        bool result;
        asm volatile (
            "lock cmpxchg16b %0;"
            "setz %3;"
            :"+m"(*this), "+a"(expected_ptr), "+d"(expected_counter), "=q"(result)
            :"b"(desired_ptr), "c"(desired_counter)
            :"cc", "memory"
        );
        return result;
    }

private:
    // Author's note: atomic in case the compiler keeps the values in registers.
    std::atomic<T*> ptr_;
    std::atomic<std::uint64_t> counter_;
};
// Intrusive singly linked element stored in Stack. Kept a plain aggregate so
// it can be created with `Node{value, next}`.
struct Node
{
    int   value; // payload carried by this element
    Node* next;  // element below this one; open question from the author: does this need to be atomic?
};
class Stack | |
{ | |
public: | |
inline bool try_push(Node* node) | |
{ | |
Node* head = head_.load_ptr(); | |
uint64_t counter = head_.load_counter(); | |
//node->next.store(head, memory_order_relaxed); | |
node->next = head; | |
return head_.compare_exchange(head, counter, node, counter + 1); | |
} | |
void push(Node* node) | |
{ | |
BinaryBackOff bb; | |
while(!try_push(node)) | |
{ | |
bb.backoff(); | |
} | |
} | |
inline bool try_pop(int& value) | |
{ | |
Node* head = head_.load_ptr(); | |
uint64_t counter = head_.load_counter(); | |
if(!head) | |
{ | |
value = -1; | |
return true; | |
} | |
if(head_.compare_exchange(head, counter, head->next, counter + 1)) | |
{ | |
/*delete head;*/ | |
value = head->value; | |
return true; | |
} | |
else | |
{ | |
return false; | |
} | |
} | |
int pop() | |
{ | |
int res; | |
BinaryBackOff bb; | |
while(!try_pop(res)) | |
{ | |
bb.backoff(); | |
} | |
return res; | |
} | |
private: | |
TaggedPointer<Node> head_; | |
}; | |
// Results merged from all correctness workers: every value pushed and every
// value popped. Workers append to both vectors while holding stat_lock.
std::mutex stat_lock;
std::vector<int> total_pushes;
std::vector<int> total_pops;
void worker_correctness(int id, shared_ptr<Stack> stack) | |
{ | |
minstd_rand rd; | |
uniform_int_distribution<> uid(0, 1); | |
vector<int> pushes; | |
pushes.reserve(ITERATIONS * 3 / 2); | |
vector<int> pops; | |
pops.reserve(ITERATIONS * 3 / 2); | |
auto begin_time = chrono::high_resolution_clock::now(); | |
for(int i = 0; i < ITERATIONS; ++i) | |
{ | |
if(uid(rd) == 0) | |
{ | |
pushes.push_back(i); | |
stack->push(new Node{i, nullptr}); | |
} | |
else | |
{ | |
auto value = stack->pop(); | |
if(value >= 0) | |
pops.push_back(value); | |
} | |
} | |
while(true) | |
{ | |
auto value = stack->pop(); | |
if(value >= 0) | |
pops.push_back(value); | |
else | |
break; | |
} | |
auto end_time = chrono::high_resolution_clock::now(); | |
{ | |
lock_guard<mutex> holder(stat_lock); | |
total_pushes.insert(end(total_pushes), begin(pushes), end(pushes)); | |
total_pops.insert(end(total_pops), begin(pops), end(pops)); | |
} | |
printf("%d: %lldms\n", id, chrono::duration_cast<chrono::milliseconds>(end_time - begin_time).count()); | |
} | |
void test_correctness() | |
{ | |
auto stack = make_shared<Stack>(); | |
thread threads[] = { | |
thread{worker_correctness, 1, stack}, | |
thread{worker_correctness, 2, stack}, | |
thread{worker_correctness, 3, stack}, | |
thread{worker_correctness, 4, stack} | |
}; | |
for(auto& thread: threads) | |
{ | |
thread.join(); | |
} | |
sort(begin(total_pushes), end(total_pushes)); | |
sort(begin(total_pops), end(total_pops)); | |
if(total_pushes == total_pops) | |
{ | |
cout << "good" << endl; | |
} | |
else | |
{ | |
cout << "bad" << endl; | |
} | |
} | |
void worker_speed(int id, Stack& stack, Node* pool) | |
{ | |
auto begin_time = chrono::high_resolution_clock::now(); | |
for(int i = 0; i < ITERATIONS; ++i) | |
{ | |
if((i & 1) == 0) | |
{ | |
pool[i].value = i; | |
stack.push(&pool[i]); | |
} | |
else | |
{ | |
stack.pop(); | |
} | |
} | |
while(true) | |
{ | |
if(stack.pop() < 0) | |
break; | |
} | |
auto end_time = chrono::high_resolution_clock::now(); | |
printf("%d: %lldms\n", id, chrono::duration_cast<chrono::milliseconds>(end_time - begin_time).count()); | |
} | |
void test_speed() | |
{ | |
auto stack = make_shared<Stack>(); | |
Node* pool = new Node[ITERATIONS * 4]; | |
thread threads[] = { | |
thread{worker_speed, 1, ref(*stack.get()), pool}, | |
thread{worker_speed, 2, ref(*stack.get()), pool + ITERATIONS}, | |
thread{worker_speed, 3, ref(*stack.get()), pool + 2 * ITERATIONS}, | |
thread{worker_speed, 4, ref(*stack.get()), pool + 3 * ITERATIONS} | |
}; | |
for(auto& thread: threads) | |
{ | |
thread.join(); | |
} | |
} | |
// Calls f(args...) `repeat` times and prints the total elapsed time in
// microseconds and the average time per call in nanoseconds.
//
// f:      callable to time.
// repeat: number of invocations; 0 is allowed and reports an average of 0.
// args:   arguments handed to every invocation. They are passed as lvalues on
//         purpose: forwarding an rvalue inside the loop would let the first
//         call move from it and every later call see a moved-from object.
template<class F, class ...Args>
void timeit(F f, std::uint32_t repeat, Args&&... args)
{
    const auto begin_time = std::chrono::high_resolution_clock::now();
    for(std::uint32_t i = 0; i < repeat; ++i)
        f(args...);
    const auto end_time = std::chrono::high_resolution_clock::now();

    const auto elapsed = end_time - begin_time;
    // Cast explicitly: duration::rep is not guaranteed to be `long long`.
    const long long total_us = static_cast<long long>(
        std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count());
    const long long total_ns = static_cast<long long>(
        std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed).count());
    // Guard against division by zero when nothing was run.
    const long long average_ns = (repeat != 0) ? total_ns / static_cast<long long>(repeat) : 0LL;

    std::printf("total time: %lldus\n", total_us);
    std::printf("average time: %lldns\n", average_ns);
}
int main(int argc, char* argv[]) | |
{ | |
test_correctness(); | |
test_speed(); | |
/* | |
* | |
* timeit(rand, 10 * 1000 * 1000); | |
* | |
* minstd_rand mr; | |
* uniform_int_distribution<> uid(0, 1000000); | |
* timeit([&uid](minstd_rand& r){ uid(r); }, 10 * 1000 * 1000, ref(mr)); | |
*/ | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment