Last active
March 26, 2023 04:15
-
-
Save jcdickinson/ec2b93f78afc4c72ae74 to your computer and use it in GitHub Desktop.
C++ Lock-Free Work Stealing Stack
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <atomic> | |
// A bounded, lock-free work-stealing stack backed by a circular array.
//
// Threading contract:
//   Push  = single producer (the owning thread)
//   Pop   = single consumer (same thread as Push); takes the newest item
//   Steal = multiple consumers (any other threads); takes the oldest item
//
// All methods, including Push, may fail (return false): Push when the
// stack is full, Pop and Steal when it is empty or when they lose a race
// for an item. Re-issue the request if that occurs (spinwait).
//
// T must be usable with std::atomic<T>. For capacity >= 2 the stack holds
// at most capacity - 1 items at a time. The storage is embedded in the
// object, so with the default capacity prefer a heap allocation.
template<class T, size_t capacity = 131072>
class WorkStealingStack {
    static_assert(capacity > 0, "WorkStealingStack requires a non-zero capacity");

public:
    // Creates an empty stack. Both indices start at 1 (not 0) so that the
    // speculative decrement of _bottom performed by Pop on an empty stack
    // cannot wrap the unsigned index around.
    inline WorkStealingStack() : _top(1), _bottom(1) {
    }

    WorkStealingStack(const WorkStealingStack&) = delete;
    WorkStealingStack& operator=(const WorkStealingStack&) = delete;

    inline ~WorkStealingStack() = default;

    // Single producer.
    // Appends `item` at the bottom of the stack.
    // Returns false (without storing anything) when the stack is full; the
    // caller can decide what to do, they will probably spinwait.
    inline bool Push(const T& item) {
        // Only the owning thread ever writes _bottom, so a relaxed load is enough.
        const size_t oldbottom = _bottom.load(std::memory_order_relaxed);
        // Acquire pairs with the thieves' CAS on _top: a thief reads its slot
        // *before* that CAS, so once we observe the advanced _top the thief's
        // read is ordered before our overwrite of the recycled slot.
        const size_t oldtop = _top.load(std::memory_order_acquire);

        // size_t is unsigned, so only trust the difference when it is positive.
        if (oldbottom > oldtop && oldbottom - oldtop >= capacity - 1) {
            return false;
        }

        _values[oldbottom % capacity].store(item, std::memory_order_relaxed);
        // Release publishes the slot contents together with the new bottom.
        _bottom.store(oldbottom + 1, std::memory_order_release);
        return true;
    }

    // Single consumer (must be the thread that calls Push).
    // Removes the most recently pushed item into `result`.
    // Returns false, leaving `result` untouched, when the stack is empty or
    // when a thief won the race for the last remaining item.
    inline bool Pop(T& result) {
        // Reserve the bottom slot first, then look at _top. Both operations are
        // sequentially consistent: the store-then-load pattern needs a total
        // order with the load-then-load in Steal, which release/acquire alone
        // does not provide.
        const size_t oldbottom = _bottom.fetch_sub(1, std::memory_order_seq_cst);
        size_t oldtop = _top.load(std::memory_order_seq_cst);
        const size_t newbottom = oldbottom - 1;

        // Empty (or bottom ended up behind top): undo the reservation.
        if (oldbottom <= oldtop) {
            _bottom.store(oldtop, std::memory_order_relaxed);
            return false;
        }

        // Exactly one item left: thieves may be going for it too, so it has to
        // be claimed through the same CAS on _top that they use.
        if (newbottom == oldtop) {
            const T item = _values[newbottom % capacity].load(std::memory_order_relaxed);
            const bool won = _top.compare_exchange_strong(
                oldtop, oldtop + 1, std::memory_order_seq_cst);
            // Win or lose, the stack is now empty with _top == oldbottom.
            _bottom.store(oldbottom, std::memory_order_release);
            if (won) {
                result = item;
            }
            return won;
        }

        // More than one item: thieves cannot reach the bottom slot, it's uncontended.
        result = _values[newbottom % capacity].load(std::memory_order_relaxed);
        return true;
    }

    // Multiple consumer.
    // Removes the oldest item into `result`.
    // Returns false, leaving `result` untouched, when the stack looks empty or
    // when another thread claimed the item first.
    inline bool Steal(T& result) {
        size_t oldtop = _top.load(std::memory_order_seq_cst);
        // This load also acts as the acquire that pairs with Push's release
        // store, making the pushed slot contents visible to this thread.
        const size_t oldbottom = _bottom.load(std::memory_order_seq_cst);
        if (oldbottom <= oldtop) {
            return false;
        }

        // Read the slot BEFORE claiming it. Once the CAS below succeeds the slot
        // is free for reuse, and with other thieves advancing _top the producer
        // may overwrite it at any moment.
        const T item = _values[oldtop % capacity].load(std::memory_order_relaxed);

        // Make sure that we are not contending for the item.
        if (!_top.compare_exchange_strong(oldtop, oldtop + 1, std::memory_order_seq_cst)) {
            return false;
        }
        result = item;
        return true;
    }

private:
    // Circular array, indexed by (index % capacity).
    std::atomic<T> _values[capacity];
    std::atomic<size_t> _top;    // queue end: advanced by Steal (and by Pop for the last item)
    std::atomic<size_t> _bottom; // stack end: written only by the owning thread
};
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "stdafx.h" | |
#include <thread> | |
#include <functional> | |
#include <chrono> | |
#include "workstealingstack.h" | |
#include "catch.h" | |
using namespace std; | |
TEST_CASE("Work stealing stack: Single-threaded push and pop", "[wss][serial]") {
    // Push four items from the owning thread.
    auto stack = make_unique<WorkStealingStack<int>>();
    for (int item = 100; item <= 400; item += 100) {
        stack->Push(item);
    }

    // Pop five times: one more attempt than there are items.
    int popped[5];
    bool ok[5];
    for (int i = 0; i < 5; ++i) {
        ok[i] = stack->Pop(popped[i]);
    }

    // The first four pops succeed, the fifth finds the stack empty.
    for (int i = 0; i < 4; ++i) {
        REQUIRE(ok[i]);
    }
    REQUIRE_FALSE(ok[4]);

    // Pop returns items newest-first (LIFO).
    const int expected[4] = { 400, 300, 200, 100 };
    for (int i = 0; i < 4; ++i) {
        REQUIRE(popped[i] == expected[i]);
    }
}
TEST_CASE("Work stealing stack: Single-threaded push and steal", "[wss][serial]") {
    // Push four items from the owning thread.
    auto stack = make_unique<WorkStealingStack<int>>();
    for (int item = 100; item <= 400; item += 100) {
        stack->Push(item);
    }

    // Steal five times: one more attempt than there are items.
    int stolen[5];
    bool ok[5];
    for (int i = 0; i < 5; ++i) {
        ok[i] = stack->Steal(stolen[i]);
    }

    // The first four steals succeed, the fifth finds the stack empty.
    for (int i = 0; i < 4; ++i) {
        REQUIRE(ok[i]);
    }
    REQUIRE_FALSE(ok[4]);

    // Steal returns items oldest-first (FIFO).
    const int expected[4] = { 100, 200, 300, 400 };
    for (int i = 0; i < 4; ++i) {
        REQUIRE(stolen[i] == expected[i]);
    }
}
TEST_CASE("Work stealing stack: Single-threaded push, pop and steal", "[wss][serial]") {
    auto stack = make_unique<WorkStealingStack<int>>();
    int taken[5];
    bool ok[5];

    // Interleave pushes with removals from both ends of the stack.
    stack->Push(100);
    stack->Push(200);
    ok[0] = stack->Pop(taken[0]);      // newest: 200
    stack->Push(300);
    ok[1] = stack->Steal(taken[1]);    // oldest: 100
    stack->Push(400);
    ok[2] = stack->Steal(taken[2]);    // oldest remaining: 300
    ok[3] = stack->Pop(taken[3]);      // last item: 400
    ok[4] = stack->Steal(taken[4]);    // nothing left

    // Every removal except the final one succeeds.
    for (int i = 0; i < 4; ++i) {
        REQUIRE(ok[i]);
    }
    REQUIRE_FALSE(ok[4]);

    const int expected[4] = { 200, 100, 300, 400 };
    for (int i = 0; i < 4; ++i) {
        REQUIRE(taken[i] == expected[i]);
    }
}
TEST_CASE("Work stealing stack: Mulithreaded one consumer one producer", "[wss][concurrent]") {
    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Written by the producer thread and polled by the consumer thread, so it
    // must be atomic: a plain bool here is a data race and lets the compiler
    // hoist the read out of the polling loop.
    atomic<bool> done{false};
    // Only the consumer thread touches these; the main thread reads them
    // after join(), so they need no synchronization of their own.
    auto result = 0;
    auto count = 0;

    // Consumer: steal until the producer signals completion.
    thread consumer([&]() {
        while (!done.load()) {
            int val;
            while (wss->Steal(val)) {
                count++;
                result += val % 101;
            }
        }
    });

    // Producer: push 1..10000, spinning while the stack is full.
    thread producer([&]() {
        for (auto i = 1; i <= 10000; i++) {
            while (!wss->Push(i)) {}
        }
        // Give the consumer time to drain the stack before stopping it.
        this_thread::sleep_for(chrono::seconds(1));
        done.store(true);
    });

    consumer.join();
    producer.join();

    // Every item must have been stolen exactly once.
    REQUIRE(count == 10000);
    REQUIRE(result == 499951);
}
TEST_CASE("Work stealing stack: Mulithreaded one consumer one producer large iteration", "[wss][concurrent]") {
    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Shared between the producer and consumer threads, so it must be atomic;
    // a plain bool here is a data race.
    atomic<bool> done{false};
    // Brace-initialized: copy-initializing an atomic is ill-formed before C++17.
    atomic<int> result{0};

    // Consumer: steal until the producer is done, then keep polling for a
    // further bounded number of rounds to drain whatever is left.
    thread consumer([&]() {
        auto i = 0;
        while (!done.load() || (++i < 10000)) {
            int val;
            while (wss->Steal(val)) {
                result.fetch_add(1);
            }
        }
    });

    // Producer: push 1..100000, spinning while the stack is full.
    thread producer([&]() {
        for (auto i = 1; i <= 100000; i++) {
            while (!wss->Push(i)) {}
        }
        done.store(true);
    });

    consumer.join();
    producer.join();

    // Every pushed item must have been stolen.
    auto r = result.load();
    REQUIRE(r == 100000);
}
TEST_CASE("Work stealing stack: Mulithreaded many consumers one producer", "[wss][concurrent]") {
    const int consumerCount = 50;

    auto wss = make_unique<WorkStealingStack<int, 200>>();
    // Polled by every consumer and written by the producer, so it must be
    // atomic; a plain bool here is a data race.
    atomic<bool> done{false};
    atomic<int> result{0};
    atomic<int> count{0};

    // Owned array: released automatically, including when a REQUIRE fails.
    auto consumers = make_unique<thread[]>(consumerCount);
    for (auto i = 0; i < consumerCount; i++) {
        consumers[i] = thread([&]() {
            while (!done.load()) {
                int val;
                while (wss->Steal(val)) {
                    count.fetch_add(1);
                    result.fetch_add(val % 101);
                }
            }
        });
    }

    // Producer: push 1..10000, spinning while the stack is full.
    thread producer([&]() {
        for (auto i = 1; i <= 10000; i++) {
            while (!wss->Push(i)) {}
        }
        // Give the consumers time to drain the stack before stopping them.
        this_thread::sleep_for(chrono::seconds(1));
        done.store(true);
    });

    for (auto i = 0; i < consumerCount; i++) {
        consumers[i].join();
    }
    producer.join();

    // Every item must have been stolen exactly once across all consumers.
    auto c = count.load();
    auto r = result.load();
    REQUIRE(c == 10000);
    REQUIRE(r == 499951);
}
TEST_CASE("Work stealing stack: Mulithreaded many consumers one consuming producer", "[wss][concurrent]") {
    const int consumerCount = 11;
    const int itemCount = 10000;

    auto wss = make_unique<WorkStealingStack<int>>();
    // Polled by every consumer and written by the producer, so it must be
    // atomic; a plain bool here is a data race.
    atomic<bool> done{false};
    // Written only by the producer thread and read after join().
    auto selfConsumed = false;
    atomic<int> result{0};
    atomic<int> count1{0};   // items taken by the stealing threads
    atomic<int> count2{0};   // items popped back by the producer itself

    // Per-item marker recording who consumed it (1 = stolen, 2 = popped).
    // Nothing in this test reads it back; it is a debugging aid.
    // Zero-initialized and owned, so it is released automatically.
    auto bb = make_unique<int[]>(itemCount);

    auto consumers = make_unique<thread[]>(consumerCount);
    for (auto i = 0; i < consumerCount; i++) {
        consumers[i] = thread([&]() {
            while (!done.load()) {
                int val;
                while (wss->Steal(val)) {
                    bb[val] = 1;
                    count1.fetch_add(1);
                    result.fetch_add(val % 101);
                }
            }
        });
    }

    // Producer: push 0..9999 and, on every seventh item, pop until Pop fails,
    // competing with the thieves for its own work.
    thread producer([&]() {
        for (auto i = 0; i < itemCount; i++) {
            while (!wss->Push(i)) {}
            if (i % 7 == 0) {
                int val;
                while (wss->Pop(val)) {
                    bb[val] = 2;
                    selfConsumed = true;
                    count2.fetch_add(1);
                    result.fetch_add(val % 101);
                }
            }
        }
        // Give the consumers time to drain the stack before stopping them.
        this_thread::sleep_for(chrono::seconds(1));
        done.store(true);
    });

    for (auto i = 0; i < consumerCount; i++) {
        consumers[i].join();
    }
    producer.join();

    if (!selfConsumed) {
        WARN("The producer did not consume any of its own items, test result may be compromised.");
    }

    // Between stealing and popping, every item must be consumed exactly once.
    auto c = count1.load() + count2.load();
    auto r = result.load();
    REQUIRE(c == 10000);
    REQUIRE(r == 499950);
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Don't suppose you'd consider making this code available under a permissive license (MIT/BSD)? Thanks!