Lewis Baker: https://twitter.com/lewissbaker/status/1528563836995125248?t=LdBu-IXre-UTDaCTdW2nBg&s=09
/////////////////////////////////////////////////////////////////////////
// This is a minimal example that shows the 4 main building-blocks needed to
// write concurrent/async coroutine code.
//
// 1. A coroutine type that lets users write their coroutine logic
//    and call and co_await other coroutines that they write.
//    This allows composition of async coroutine code.
//
//    This example defines a basic `task<T>` type that fulfills this purpose.
//
// 2. A mechanism for launching concurrent async operations and later
//    waiting for launched concurrent work to complete.
//
//    To be able to have multiple coroutines executing independently we need
//    some way to introduce concurrency. And to ensure that we are able to
//    safely compose concurrent operations and shut down cleanly, we need
//    some way to wait for concurrent operations to complete so
//    that we can, e.g., ensure that the completion of those concurrent operations
//    "happens before" the destruction of resources used by those concurrent
//    operations.
//
//    This example defines a simple `async_scope` facility that lets you
//    spawn multiple coroutines that can run independently, keeping an atomic
//    reference count of the number of unfinished tasks that have been launched.
//    It also gives a way to asynchronously wait until all tasks have completed.
//
// 3. A mechanism for blocking a thread until some async operation completes.
//
//    The main() function is a synchronous function and so if we launch some
//    async code we need some way to block until that code completes
//    so that we don't return from main() until all concurrent work has been
//    joined.
//
//    This example defines a sync_wait() function that takes an awaitable which
//    it co_awaits and then blocks the current thread until the co_await expression
//    completes, returning the result of the co_await expression.
//
// 4. A mechanism for multiplexing multiple coroutines onto a set of worker threads.
//
//    One of the main reasons for writing asynchronous code is to allow a thread
//    to do something else while waiting for some operation to complete. This
//    requires some way to schedule/multiplex multiple coroutines onto a smaller
//    number of threads, typically using a queue and having an event-loop run by
//    each thread that allows it to do some work until that work suspends, and then
//    pick up the next item in the queue and execute that in the meantime to keep
//    the thread busy.
//
//    This example provides a basic `manual_event_loop` implementation that allows
//    a coroutine to `co_await loop.schedule()` to suspend and enqueue itself to
//    the loop's queue, whereby a thread that is calling `loop.run()` will eventually
//    pick up that item and resume it.
//
//    In practice, such multiplexers often also support other kinds of scheduling,
//    such as 'schedule when an I/O operation completes' or 'schedule when a timer
//    elapses'.
//
// These 4 components are essential to being able to write asynchronous coroutine code.
//
// Different coroutine library implementations may structure these facilities in
// different ways, sometimes combining these items into one abstraction. E.g. sometimes
// a multiplexer implementation might combine items 2, 3 and 4 by providing a mechanism
// to launch a coroutine on that multiplexer and also wait for all launched work on
// that multiplexer.
//
// This example chooses to separate them so that you can understand each component
// separately - each of the classes is relatively short (roughly 100 lines) and so
// should hopefully be relatively easy to study.
//
// However, keeping them separate also generally gives better flexibility with
// how to compose them into an application. E.g. see how we can reuse async_scope
// in the nested_scopes() example below.
//
// This example also defines a number of helper concepts/traits needed by some of
// the implementations:
// - `awaitable` concept
// - `awaiter` concept
// - `await_result_t<T>` type-trait
// - `awaiter_type_t<T>` type-trait
// - `get_awaiter(x) -> awaiter` helper function
// And some other helpers:
// - `lazy_task` - useful for improving coroutine allocation-elision
// - `scope_guard`
//
//
// Please feel free to use this code however you like - it is primarily intended
// for learning coroutine mechanisms rather than necessarily as production-quality
// code. However, attribution is appreciated if you do use it somewhere.
//
// By Lewis Baker <[email protected]>
/////////////////////////////////////////////////////////////////////////
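//
// At a glance, the four pieces compose roughly like this. This is only an
// illustrative sketch (the work() coroutine is hypothetical); main() at the
// bottom of this file is the real, complete version.
//
//   task<void> work(manual_event_loop& loop) {        // 1. user-written coroutine logic
//     co_await loop.schedule();                       //    hop onto a worker thread
//   }
//
//   int main() {
//     manual_event_loop loop;                         // 4. multiplexer for worker threads
//     std::jthread worker{[&](std::stop_token st) { loop.run(st); }};
//     async_scope scope;                              // 2. launch + track concurrent work
//     scope.spawn_detached(work(loop));
//     sync_wait(scope.join_async());                  // 3. block until all work is joined
//   }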
#include <atomic> | |
#include <cassert> | |
#include <concepts> | |
#include <condition_variable> | |
#include <coroutine> | |
#include <cstdio> | |
#include <exception> | |
#include <mutex> | |
#include <optional> | |
#include <semaphore> | |
#include <stop_token> | |
#include <thread> | |
#include <utility> | |
#include <variant> | |
/////////////////////////////////////////////////// | |
// general helpers | |
#define FORCE_INLINE __attribute__((always_inline)) | |
template<typename T> | |
concept decay_copyable = std::constructible_from<std::decay_t<T>, T>; | |
template<typename F> | |
struct scope_guard { | |
F func; | |
bool cancelled{false}; | |
template<typename F2> | |
requires std::constructible_from<F, F2> | |
explicit scope_guard(F2&& f) noexcept(std::is_nothrow_constructible_v<F, F2>) | |
: func(static_cast<F2&&>(f))
{} | |
scope_guard(scope_guard&& g) noexcept requires std::is_nothrow_move_constructible_v<F> | |
: func(std::move(g.func)) | |
, cancelled(std::exchange(g.cancelled, true)) {} | |
~scope_guard() { | |
call_now(); | |
} | |
void cancel() noexcept { cancelled = true; } | |
void call_now() noexcept { | |
if (!cancelled) { | |
cancelled = true; | |
func(); | |
} | |
} | |
}; | |
template<typename F> | |
scope_guard(F) -> scope_guard<F>; | |
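// Usage sketch for scope_guard (illustrative only): run some cleanup code when
// a scope is exited, unless cancel() was called first. main() below uses it as
// 'join_on_exit' to guarantee that each async_scope is joined.
//
//   {
//     scope_guard undo{[&]{ std::puts("rolling back"); }};
//     // ... do work that might throw or return early ...
//     undo.cancel();   // success: don't run the rollback
//   }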
/////////////////////////////////////////////////// | |
// coroutine helpers | |
// Concept that checks if a type is a valid "awaiter" type. | |
// i.e. has the await_ready/await_suspend/await_resume methods. | |
// | |
// Note that we can't check whether await_suspend() is valid here because | |
// we don't know what type of coroutine_handle<> to test it with. | |
// So we just check for await_ready/await_resume and assume that if it has both
// of those then it will also have the await_suspend() method.
template<typename T> | |
concept awaiter = | |
requires(T& x) { | |
(x.await_ready() ? (void)0 : (void)0); | |
x.await_resume(); | |
}; | |
template<typename T> | |
concept _member_co_await = | |
requires(T&& x) { | |
{ static_cast<T&&>(x).operator co_await() } -> awaiter; | |
}; | |
template<typename T> | |
concept _free_co_await = | |
requires (T&& x) { | |
{ operator co_await(static_cast<T&&>(x)) } -> awaiter; | |
}; | |
template<typename T> | |
concept awaitable = | |
_member_co_await<T> || _free_co_await<T> || awaiter<T>; | |
// get_awaiter(x) -> awaiter | |
// | |
// Helper function that tries to mimic what the compiler does in `co_await` | |
// expressions to obtain the awaiter for a given awaitable argument. | |
// | |
// It's not a perfect match, however, as we can't exactly match the overload | |
// resolution which combines both free-function overloads and member-function overloads | |
// of `operator co_await()` into a single overload-set. | |
// | |
// The `get_awaiter()` function will be an ambiguous call if a type has both | |
// a free-function `operator co_await()` and a member-function `operator co_await()` | |
// even if the compiler's overload resolution would not consider this to be | |
// ambiguous. | |
template<typename T> | |
requires _member_co_await<T> | |
decltype(auto) get_awaiter(T&& x) noexcept(noexcept(static_cast<T&&>(x).operator co_await())) { | |
return static_cast<T&&>(x).operator co_await(); | |
} | |
template<typename T> | |
requires _free_co_await<T> | |
decltype(auto) get_awaiter(T&& x) noexcept(noexcept(operator co_await(static_cast<T&&>(x)))) {
return operator co_await(static_cast<T&&>(x)); | |
} | |
template<typename T> | |
requires awaiter<T> && (!_free_co_await<T> && !_member_co_await<T>) | |
T&& get_awaiter(T&& x) noexcept { | |
return static_cast<T&&>(x); | |
} | |
template<typename T> | |
requires awaitable<T> | |
using awaiter_type_t = decltype(get_awaiter(std::declval<T>())); | |
template<typename T> | |
requires awaitable<T> | |
using await_result_t = decltype(std::declval<awaiter_type_t<T>&>().await_resume()); | |
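// A few sanity checks (illustrative) of how these concepts/traits behave,
// using std::suspend_always, which is itself an awaiter:
static_assert(awaiter<std::suspend_always>);
static_assert(awaitable<std::suspend_always>);
static_assert(std::same_as<await_result_t<std::suspend_always>, void>);
static_assert(std::same_as<awaiter_type_t<std::suspend_always&>, std::suspend_always&>);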
/////////////////////////////////////////////////// | |
// task<T> - basic async task type | |
template<typename T> | |
struct task; | |
template<typename T> | |
struct task_promise { | |
task<T> get_return_object() noexcept; | |
std::suspend_always initial_suspend() noexcept { return {}; } | |
struct final_awaiter { | |
bool await_ready() noexcept { return false; } | |
std::coroutine_handle<> await_suspend(std::coroutine_handle<task_promise> h) noexcept { | |
return h.promise().continuation; | |
} | |
[[noreturn]] void await_resume() noexcept { | |
std::terminate(); | |
} | |
}; | |
final_awaiter final_suspend() noexcept { return {}; } | |
template<typename U> | |
requires std::convertible_to<U, T> | |
void return_value(U&& value) noexcept(std::is_nothrow_constructible_v<T, U>) { | |
result.template emplace<1>(std::forward<U>(value)); | |
} | |
void unhandled_exception() noexcept { | |
result.template emplace<2>(std::current_exception()); | |
} | |
std::coroutine_handle<> continuation; | |
std::variant<std::monostate, T, std::exception_ptr> result; | |
}; | |
template<> | |
struct task_promise<void> { | |
task<void> get_return_object() noexcept; | |
std::suspend_always initial_suspend() noexcept { return {}; } | |
struct final_awaiter { | |
bool await_ready() noexcept { return false; } | |
std::coroutine_handle<> await_suspend(std::coroutine_handle<task_promise> h) noexcept { | |
return h.promise().continuation; | |
} | |
[[noreturn]] void await_resume() noexcept { | |
std::terminate(); | |
} | |
}; | |
final_awaiter final_suspend() noexcept { return {}; } | |
void return_void() noexcept { | |
result.emplace<1>(); | |
} | |
void unhandled_exception() noexcept { | |
result.emplace<2>(std::current_exception()); | |
} | |
struct empty {}; | |
std::coroutine_handle<> continuation; | |
std::variant<std::monostate, empty, std::exception_ptr> result; | |
}; | |
template<typename T> | |
struct [[nodiscard]] task { | |
private: | |
using handle_t = std::coroutine_handle<task_promise<T>>; | |
handle_t coro; | |
struct awaiter { | |
handle_t coro; | |
bool await_ready() noexcept { return false; } | |
handle_t await_suspend(std::coroutine_handle<> h) noexcept { | |
coro.promise().continuation = h; | |
return coro; | |
} | |
T await_resume() { | |
if (coro.promise().result.index() == 2) { | |
std::rethrow_exception(std::get<2>(std::move(coro.promise().result))); | |
} | |
assert(coro.promise().result.index() == 1); | |
if constexpr (!std::is_void_v<T>) { | |
return std::get<1>(std::move(coro.promise().result)); | |
} | |
} | |
}; | |
friend class task_promise<T>; | |
explicit task(handle_t h) noexcept : coro(h) {} | |
public: | |
using promise_type = task_promise<T>; | |
task(task&& other) noexcept : coro(std::exchange(other.coro, {})) {} | |
~task() { if (coro) coro.destroy(); } | |
awaiter operator co_await() && { | |
return awaiter{coro}; | |
} | |
}; | |
template<typename T> | |
task<T> task_promise<T>::get_return_object() noexcept { | |
return task<T>{std::coroutine_handle<task_promise<T>>::from_promise(*this)}; | |
} | |
task<void> task_promise<void>::get_return_object() noexcept { | |
return task<void>{std::coroutine_handle<task_promise<void>>::from_promise(*this)}; | |
} | |
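// Usage sketch for task<T> (illustrative only; f()/g()/h() near the bottom of
// this file are the real examples driven by main()):
//
//   task<int> answer() { co_return 42; }
//
//   task<void> consumer() {
//     int x = co_await answer();   // lazy: answer()'s body only starts when awaited
//     std::printf("%d\n", x);
//   }
//
// Note that operator co_await() is rvalue-qualified, so a task is intended to be
// awaited exactly once, typically directly on the temporary returned by the call.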
//////////////////////////////////////// | |
// async_scope | |
// | |
// Used to launch new tasks and then later wait until all tasks have completed. | |
struct async_scope { | |
private: | |
struct detached_task { | |
struct promise_type { | |
async_scope& scope; | |
promise_type(async_scope& scope, auto&) noexcept : scope(scope) {} | |
detached_task get_return_object() noexcept { return {}; } | |
std::suspend_never initial_suspend() noexcept { | |
scope.add_ref(); | |
return {}; | |
} | |
struct final_awaiter { | |
bool await_ready() noexcept { return false; } | |
void await_suspend(std::coroutine_handle<promise_type> h) noexcept { | |
async_scope& s = h.promise().scope; | |
h.destroy(); | |
s.notify_task_finished(); | |
} | |
void await_resume() noexcept {} | |
}; | |
final_awaiter final_suspend() noexcept { return {}; } | |
void return_void() noexcept {} | |
[[noreturn]] void unhandled_exception() noexcept { std::terminate(); } | |
}; | |
}; | |
template<typename A> | |
detached_task spawn_detached_impl(A a) { | |
co_await std::forward<A>(a); | |
} | |
void add_ref() noexcept { | |
ref_count.fetch_add(ref_increment, std::memory_order_relaxed); | |
} | |
void notify_task_finished() noexcept {
    std::size_t oldValue = ref_count.load(std::memory_order_acquire);
    assert(oldValue >= ref_increment);
    if (oldValue != (joiner_flag + ref_increment)) {
        // Not the last task with a joiner already waiting: just drop our reference.
        // The fetch_sub() result is re-checked below in case other tasks finished
        // and/or the joiner attached between the load() and the fetch_sub().
        oldValue = ref_count.fetch_sub(ref_increment, std::memory_order_acq_rel);
    }
    if (oldValue == (joiner_flag + ref_increment)) {
        // last ref and there is a joining coroutine -> resume the coroutine
        joiner.resume();
    }
}
struct join_awaiter { | |
async_scope& scope; | |
bool await_ready() noexcept {
return scope.ref_count.load(std::memory_order_acquire) == 0; | |
} | |
bool await_suspend(std::coroutine_handle<> h) noexcept { | |
scope.joiner = h; | |
std::size_t oldValue = scope.ref_count.fetch_add(joiner_flag, std::memory_order_acq_rel); | |
return (oldValue != 0); | |
} | |
void await_resume() noexcept {} | |
}; | |
static constexpr std::size_t joiner_flag = 1; | |
static constexpr std::size_t ref_increment = 2; | |
std::atomic<std::size_t> ref_count{0}; | |
std::coroutine_handle<> joiner; | |
public: | |
template<typename A> | |
requires decay_copyable<A> && awaitable<std::decay_t<A>> | |
void spawn_detached(A&& a) { | |
spawn_detached_impl(std::forward<A>(a)); | |
} | |
[[nodiscard]] join_awaiter join_async() noexcept { | |
return join_awaiter{*this}; | |
} | |
}; | |
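// Usage sketch for async_scope (illustrative only; some_task()/other_task() are
// hypothetical coroutines returning task<void>). main() below shows the complete
// pattern, including joining via a scope_guard.
//
//   async_scope scope;
//   scope.spawn_detached(some_task());   // starts running immediately on this
//   scope.spawn_detached(other_task());  // thread, up to the first suspension point
//   co_await scope.join_async();         // resumes once every spawned task has finished
//
// join_async() must complete before the async_scope (and any resources the spawned
// tasks use) is destroyed, otherwise detached tasks could outlive what they reference.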
//////////////////////////////////////////////////////////////// | |
// sync_wait() | |
template<typename Task> | |
await_result_t<Task> sync_wait(Task&& t) { | |
struct _void {}; | |
using return_type = await_result_t<Task>; | |
using storage_type = std::add_pointer_t<std::conditional_t< | |
std::is_void_v<return_type>, _void, return_type>>; | |
using result_type = std::variant<std::monostate, storage_type, std::exception_ptr>; | |
struct _sync_task { | |
struct promise_type { | |
std::binary_semaphore sem{0}; | |
result_type result; | |
_sync_task get_return_object() noexcept { | |
return _sync_task{std::coroutine_handle<promise_type>::from_promise(*this)}; | |
} | |
struct final_awaiter { | |
bool await_ready() noexcept { return false; } | |
void await_suspend(std::coroutine_handle<promise_type> h) noexcept { | |
// Now that the coroutine has suspended we can signal the semaphore,
// unblocking the waiting thread. The other thread will then
// destroy this coroutine (which is safe because it is suspended).
h.promise().sem.release(); | |
} | |
void await_resume() noexcept {} | |
}; | |
std::suspend_always initial_suspend() noexcept { return {}; } | |
final_awaiter final_suspend() noexcept { return {}; } | |
using non_void_return_type = std::conditional_t<std::is_void_v<return_type>, _void, return_type>; | |
final_awaiter yield_value(non_void_return_type&& x) requires (!std::is_void_v<return_type>) { | |
// Note that we just store the address here and then suspend | |
// and unblock the waiting thread which then copies/moves the | |
// result from this address directly to the return-value of | |
// sync_wait(). This avoids having to make an extra intermediate | |
// copy of the result value. | |
result.template emplace<1>(std::addressof(x)); | |
return {}; | |
} | |
void return_void() noexcept { | |
result.template emplace<1>(); | |
} | |
void unhandled_exception() noexcept { | |
result.template emplace<2>(std::current_exception()); | |
} | |
}; | |
using handle_t = std::coroutine_handle<promise_type>; | |
handle_t coro; | |
explicit _sync_task(handle_t h) noexcept : coro(h) {} | |
_sync_task(_sync_task&& o) noexcept : coro(std::exchange(o.coro, {})) {} | |
~_sync_task() { if (coro) coro.destroy(); } | |
// The force-inline here is required to get the _sync_task coroutine elided. | |
// Otherwise the compiler doesn't know that this function hasn't modified the
// 'coro' member variable and so can't deduce that it is always destroyed before
// sync_wait() returns.
FORCE_INLINE return_type run() { | |
coro.resume(); | |
coro.promise().sem.acquire(); | |
auto& result = coro.promise().result; | |
if (result.index() == 2) { | |
std::rethrow_exception(std::get<2>(std::move(result))); | |
} | |
assert(result.index() == 1); | |
if constexpr (!std::is_void_v<return_type>) { | |
return static_cast<return_type&&>(*std::get<1>(result)); | |
} | |
} | |
}; | |
return [&]() -> _sync_task { | |
if constexpr (std::is_void_v<return_type>) { | |
co_await static_cast<Task&&>(t); | |
} else { | |
// use co_yield instead of co_return so we suspend while the | |
// potentially temporary result of co_await is still alive. | |
co_yield co_await static_cast<Task&&>(t); | |
} | |
}().run(); | |
} | |
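// Usage sketch for sync_wait() (illustrative only; compute() is a hypothetical
// coroutine): bridge from synchronous code into coroutine code, blocking the
// calling thread until the awaited operation completes.
//
//   task<int> compute();
//
//   int main() {
//     int result = sync_wait(compute());   // blocks here until compute() finishes
//   }
//
// In this file it is used to block main() until async_scope::join_async() completes.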
///////////////////////////////////////////////// | |
// manual_event_loop | |
// | |
// A simple scheduler context with an intrusive list. | |
// | |
// Uses a mutex/condition_variable for synchronisation and supports
// multiple worker threads running tasks.
struct manual_event_loop { | |
private: | |
struct queue_item { | |
queue_item* next; | |
std::coroutine_handle<> coro; | |
}; | |
std::mutex mut; | |
std::condition_variable cv; | |
queue_item* head{nullptr}; | |
queue_item* tail{nullptr}; | |
void enqueue(queue_item* item) noexcept { | |
std::lock_guard lock{mut}; | |
item->next = nullptr; | |
if (head == nullptr) { | |
head = item; | |
} else { | |
tail->next = item; | |
} | |
tail = item; | |
cv.notify_one(); | |
} | |
queue_item* pop_item() noexcept { | |
queue_item* front = head; | |
if (head != nullptr) { | |
head = front->next; | |
if (head == nullptr) { | |
tail = nullptr; | |
} | |
} | |
return front; | |
} | |
struct schedule_awaitable { | |
manual_event_loop* loop; | |
queue_item item; | |
explicit schedule_awaitable(manual_event_loop& loop) noexcept : loop(&loop) {} | |
bool await_ready() noexcept { return false; } | |
void await_suspend(std::coroutine_handle<> coro) noexcept { | |
item.coro = coro; | |
loop->enqueue(&item); | |
} | |
void await_resume() noexcept {} | |
}; | |
public: | |
schedule_awaitable schedule() noexcept { return schedule_awaitable{*this}; } | |
void run(std::stop_token st) noexcept { | |
std::stop_callback cb{st, [&]() noexcept { | |
std::lock_guard lock{mut}; | |
cv.notify_all(); | |
}}; | |
std::unique_lock lock{mut}; | |
while (true) { | |
cv.wait(lock, [&]() noexcept { | |
return head != nullptr || st.stop_requested(); | |
}); | |
if (st.stop_requested()) { | |
return; | |
} | |
queue_item* item = pop_item(); | |
lock.unlock(); | |
item->coro.resume(); | |
lock.lock(); | |
} | |
} | |
}; | |
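// Usage sketch for manual_event_loop (illustrative only; do_work() is a
// hypothetical task<void> coroutine): start one or more worker threads that
// drain the queue, and have coroutines hop onto them by awaiting schedule().
//
//   manual_event_loop loop;
//   std::jthread worker{[&](std::stop_token st) { loop.run(st); }};
//
//   task<void> do_work(manual_event_loop& loop) {
//     co_await loop.schedule();   // suspends, enqueues, resumes on a worker thread
//     // ... now running on the worker thread ...
//   }
//
// Note that run() returns as soon as a stop is requested, even if items are still
// queued, so all spawned work should be joined (as main() below does) before the
// worker jthreads are destroyed.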
//////////////////////////////////////////////// | |
// Helper for improving allocation elision for composed operations.
//
// Instead of doing something like:
//
//   task<T> h(int arg);
//   scope.spawn_detached(h(42));
//
// which will generally separately allocate the h() coroutine as well
// as the internal detached_task coroutine, if we write:
//
//   scope.spawn_detached(lazy_task{[] { return h(42); }});
//
// then this defers calling the `h()` coroutine function to the evaluation
// of `operator co_await()` in the `detached_task` coroutine, which then
// permits the compiler to elide the allocation of the `h()` coroutine and
// combine its storage into the `detached_task` coroutine state, meaning
// that we now have one allocation per spawned task instead of two.
template<typename F> | |
struct lazy_task { | |
F func; | |
using task_t = std::invoke_result_t<F&>; | |
using awaiter_t = awaiter_type_t<task_t>; | |
struct awaiter { | |
task_t task; | |
awaiter_t inner; | |
explicit awaiter(F& func) noexcept(std::is_nothrow_invocable_v<F&> && noexcept(get_awaiter(static_cast<task_t&&>(task)))) | |
: task(func()) | |
, inner(get_awaiter(static_cast<task_t&&>(task))) | |
{} | |
decltype(auto) await_ready() noexcept(noexcept(inner.await_ready())) { | |
return inner.await_ready(); | |
} | |
decltype(auto) await_suspend(auto h) noexcept(noexcept(inner.await_suspend(h))) { | |
return inner.await_suspend(h); | |
} | |
decltype(auto) await_resume() noexcept(noexcept(inner.await_resume())) { | |
return inner.await_resume(); | |
} | |
}; | |
awaiter operator co_await() noexcept(std::is_nothrow_constructible_v<awaiter, F&>) { | |
return awaiter{func}; | |
} | |
}; | |
template<typename F> | |
lazy_task(F) -> lazy_task<F>; | |
///////////////////////////////////////////////// | |
// Example code | |
#include <unistd.h> | |
static task<int> f(int i) { | |
using namespace std::chrono_literals; | |
std::this_thread::sleep_for(1ms); | |
co_return i; | |
} | |
static task<int> g(int i, manual_event_loop& loop) { | |
co_await loop.schedule(); | |
int x = co_await f(i); | |
co_return x + 1; | |
} | |
static task<void> h(int i, manual_event_loop& loop) { | |
int x = co_await g(i, loop); | |
auto ts = std::chrono::steady_clock::now().time_since_epoch().count(); | |
std::printf("[%u] %i -> %i (on %i)\n", (unsigned int)ts, i, x, (int)::gettid()); | |
} | |
static task<void> nested_scopes(int x, manual_event_loop& loop) { | |
co_await loop.schedule(); | |
async_scope scope; | |
try { | |
for (int i = 0; i < 10; ++i) { | |
scope.spawn_detached(h(i, loop)); | |
} | |
} catch (...) { | |
std::printf("failure!\n"); | |
} | |
co_await scope.join_async(); | |
std::printf("nested %i done\n", x); | |
std::fflush(stdout); | |
} | |
int main() { | |
manual_event_loop loop; | |
std::jthread thd{[&](std::stop_token st) { loop.run(st); }}; | |
std::jthread thd2{[&](std::stop_token st) { loop.run(st); }}; | |
std::printf("starting example\n"); | |
{ | |
async_scope scope; | |
scope_guard join_on_exit{[&] { sync_wait(scope.join_async()); }}; | |
for (int i = 0; i < 10; ++i) { | |
// Use lazy_task here so that the h() coroutine allocation is elided
// and incorporated into the detached_task allocation made by spawn_detached().
scope.spawn_detached(lazy_task{[i, &loop] { | |
return h(i, loop); | |
}}); | |
} | |
} | |
std::printf("starting nested_scopes example\n"); | |
{ | |
async_scope scope; | |
scope_guard join_on_exit{[&] { sync_wait(scope.join_async()); }}; | |
for (int i = 0; i < 10; ++i) { | |
scope.spawn_detached(lazy_task{[i, &loop] { | |
return nested_scopes(i, loop); | |
}}); | |
} | |
} | |
return 0; | |
} |