Last active
March 11, 2022 21:39
-
-
Save bonzini/78f37bd562e1e18f7bd214dd94bcbea7 to your computer and use it in GitHub Desktop.
Simple C++ coroutine runtime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "coro.h" | |
#include <cstdio> | |
// The Coroutine currently running on this thread, or null if none.
// Set and restored by WithCurrent (used by Coroutine::resume()) and read
// by Yield::await_suspend(). Uses the standard thread_local specifier
// rather than the GNU-specific __thread extension.
static thread_local Coroutine *current = nullptr;
// Out-of-line half of Yield: "parent" is the CoroutineFn that executed
// co_await qemu_coroutine_yield(). Record it as the point where the next
// Coroutine::resume() of the running coroutine will continue.
void Yield::await_suspend(std::coroutine_handle<> parent) const noexcept {
    Coroutine &running = *current;
    running.top = parent;
}
// RAII wrapper to set and restore the current coroutine | |
struct WithCurrent { | |
Coroutine &_co; | |
WithCurrent(Coroutine &co): _co(co) { | |
_co.caller = current; | |
current = &_co; | |
} | |
~WithCurrent() { | |
Coroutine *co = current; | |
current = _co.caller; | |
_co.caller = nullptr; | |
} | |
}; | |
void Coroutine::resume() { | |
auto w = WithCurrent(*this); | |
std::coroutine_handle<> old_top = top; | |
//printf("$$$$ resume %p %d\n", old_top.address(), old_top.done()); | |
top = nullptr; | |
old_top.resume(); | |
} | |
// --------------------------- | |
#include <cstdio> | |
// Run co until it yields again or finishes. A coroutine that finished
// (it left no new yield point in "top") is deleted here, so co must not
// be used after the call that completes it.
void qemu_coroutine_enter(Coroutine *co)
{
    co->resume();
    if (co->top) {
        // Suspended at another yield; the caller may enter it again.
        return;
    }
    //printf("$$$$ deleting\n");
    delete co;
}
// Change the type from CoroutineFn<void> to Coroutine,
// so that it does not start until qemu_coroutine_enter():
// CoroutineFn runs eagerly (its initial_suspend is suspend_never),
// while Coroutine's EntryPromise suspends before running the body.
//
// Only qemu_coroutine_create() below uses this helper, so it gets
// internal linkage instead of polluting the global symbol namespace.
static Coroutine coroutine_trampoline(CoroutineFunc *func, void *opaque)
{
    co_await func(opaque);
}
// Create a suspended coroutine that will run func(opaque). The returned
// object is heap-allocated; qemu_coroutine_enter() deletes it once the
// coroutine has run to completion.
Coroutine *qemu_coroutine_create(CoroutineFunc *func, void *opaque)
{
    Coroutine *co = new Coroutine(coroutine_trampoline(func, opaque));
    return co;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#include <cstdint> | |
#include <cstdio> | |
#include <coroutine> | |
#include <exception> | |
// Forward declaration; the full definition follows further below.
struct Coroutine;

// Enter (or re-enter) a coroutine created by qemu_coroutine_create().
extern "C" void qemu_coroutine_enter(Coroutine *co);
// BaseCoroutine is a simple wrapper type for a Promise. It mostly | |
// exists because C++ says so, but it also provides two extra features: | |
// RAII destruction of the coroutine (which is more efficient but | |
// beware, the promise's final_suspend must always suspend to avoid | |
// double free) and a cast to std::coroutine_handle<>, which makes | |
// it resumable. | |
template<typename Promise> struct BaseCoroutine | |
{ | |
using promise_type = Promise; | |
BaseCoroutine() = default; | |
explicit BaseCoroutine (Promise &promise) : | |
_coroutine{std::coroutine_handle<Promise>::from_promise(promise)} {} | |
BaseCoroutine(BaseCoroutine const&) = delete; | |
BaseCoroutine(BaseCoroutine&& other) : _coroutine{other._coroutine} { | |
other._coroutine = nullptr; | |
} | |
BaseCoroutine& operator=(BaseCoroutine const&) = delete; | |
BaseCoroutine& operator=(BaseCoroutine&& other) { | |
if (&other != this) { | |
_coroutine = other._coroutine; | |
other._coroutine = nullptr; | |
} | |
return *this; | |
} | |
~BaseCoroutine() { | |
//printf("!!!! destroying %p\n", _coroutine); | |
if (_coroutine) _coroutine.destroy(); | |
} | |
operator bool() const noexcept { | |
return _coroutine; | |
} | |
operator std::coroutine_handle<>() const noexcept { | |
return _coroutine; | |
} | |
Promise &promise() const noexcept { | |
return _coroutine.promise(); | |
} | |
private: | |
std::coroutine_handle<Promise> _coroutine = nullptr; | |
}; | |
// This is a simple awaitable object that takes care of resuming a | |
// parent coroutine. It's needed because co_await suspends all | |
// parent coroutines on the stack. It does not need a specific | |
// "kind" of coroutine_handle, so no need to put it inside the | |
// templates below. | |
// | |
// If next is NULL, then this degrades to std::suspend_always. | |
struct ResumeAndFinish { | |
explicit ResumeAndFinish(std::coroutine_handle<> next) noexcept : | |
_next{next} {} | |
bool await_ready() const noexcept { | |
return false; | |
} | |
bool await_suspend(std::coroutine_handle<> ch) const noexcept { | |
if (_next) { | |
_next.resume(); | |
} | |
return true; | |
} | |
void await_resume() const noexcept {} | |
private: | |
std::coroutine_handle<> _next; | |
}; | |
// ------------------------ | |
// Coroutine is the entry point into a coroutine. It stores the | |
// coroutine_handle that last called qemu_coroutine_yield(), and | |
// Coroutine::resume() then resumes from the last yield point. | |
// | |
// Together with a thread-local variable "current", the "caller" | |
// member establishes a stack of active coroutines, so that | |
// qemu_coroutine_yield() knows which coroutine has yielded. | |
// | |
// Its promise type, EntryPromise, is pretty much bog-standard. | |
// It always suspends on entry, so that the coroutine is only | |
// entered by the first call to qemu_coroutine_enter(); and it | |
// always suspends on exit too, because we want to clean up the | |
// coroutine explicitly in BaseCoroutine's destructor. | |
struct EntryPromise; | |
struct Coroutine: BaseCoroutine<EntryPromise> { | |
Coroutine *caller = nullptr; | |
std::coroutine_handle<> top; | |
explicit Coroutine(promise_type &promise) : | |
BaseCoroutine{promise}, top{*this} {} | |
void resume(); | |
}; | |
struct EntryPromise | |
{ | |
Coroutine get_return_object() noexcept { return Coroutine{*this}; } | |
void unhandled_exception() { std::terminate(); } | |
auto initial_suspend() const noexcept { return std::suspend_always{}; } | |
auto final_suspend() const noexcept { return std::suspend_always{}; } | |
void return_void() const noexcept {} | |
}; | |
// ------------------------ | |
// CoroutineFn does not even need anything more than what | |
// BaseCoroutine provides, so it's just a type alias. The magic | |
// is all in ValuePromise<T>. | |
// | |
// Suspended CoroutineFns are chained between themselves. Whenever a | |
// coroutine is suspended, all those that have done a co_await are | |
// also suspended, and whenever a coroutine finishes, it has to | |
// check if its parent can now be resumed. | |
// | |
// The two auxiliary classes Awaiter and ResumeAndFinish take | |
// care of the two sides of this. Awaiter's await_suspend() stores | |
// the parent coroutine into ValuePromise; ResumeAndFinish's runs | |
// after a coroutine returns, and resumes the parent coroutine. | |
template<typename T> struct ValuePromise; | |
template<typename T> | |
using CoroutineFn = BaseCoroutine<ValuePromise<T>>; | |
typedef CoroutineFn<void> CoroutineFunc(void *); | |
// Unfortunately it is forbidden to define both return_void() and | |
// return_value() in the same class. In order to cut on the | |
// code duplication, define a superclass for both ValuePromise<T> | |
// and ValuePromise<void>. | |
// | |
// The "curiously recurring template pattern" is used to substitute | |
// ValuePromise<T> into the methods of the base class and its Awaited. | |
// For example await_resume() needs to retrieve a value with the | |
// correct type from the subclass's value() method. | |
template<typename T, typename Derived> | |
struct BasePromise | |
{ | |
using coro_handle_type = std::coroutine_handle<Derived>; | |
#if 0 | |
// Same as get_return_object().address() but actually works. | |
// Useful as an identifier to identify the promise in debugging | |
// output, because it matches the values passed to await_suspend(). | |
void *coro_address() const { | |
return __builtin_coro_promise((char *)this, __alignof(*this), true); | |
} | |
BasePromise() { | |
printf("!!!! created %p\n", coro_address()); | |
} | |
~BasePromise() { | |
printf("!!!! destroyed %p\n", coro_address()); | |
} | |
#endif | |
CoroutineFn<T> get_return_object() noexcept { return CoroutineFn<T>{downcast()}; } | |
void unhandled_exception() { std::terminate(); } | |
auto initial_suspend() const noexcept { return std::suspend_never{}; } | |
auto final_suspend() noexcept { | |
auto continuation = ResumeAndFinish{_next}; | |
mark_ready(); | |
return continuation; | |
} | |
private: | |
std::coroutine_handle<> _next = nullptr; | |
static const std::uintptr_t READY_MARKER = 1; | |
void mark_ready() { | |
_next = std::coroutine_handle<>::from_address((void *)READY_MARKER); | |
} | |
bool is_ready() const { | |
return _next.address() == (void *)READY_MARKER; | |
} | |
Derived& downcast() noexcept { return *static_cast<Derived*>(this); } | |
Derived const& downcast() const noexcept { return *static_cast<const Derived*>(this); } | |
// This records the parent coroutine, before a co_await suspends | |
// all parent coroutines on the stack. | |
void then(std::coroutine_handle<> parent) { _next = parent; } | |
// This is the object that lets us co_await a CoroutineFn<T> (of which | |
// this class is the corresponding promise object). This is just mapping | |
// C++ awaitable naming into the more conventional promise naming. | |
struct Awaiter { | |
Derived &_promise; | |
explicit Awaiter(Derived &promise) : _promise{promise} {} | |
bool await_ready() const noexcept { | |
return _promise.is_ready(); | |
} | |
void await_suspend(std::coroutine_handle<> parent) const noexcept { | |
_promise.then(parent); | |
} | |
Derived::await_resume_type await_resume() const noexcept { | |
return _promise.value(); | |
} | |
}; | |
// C++ connoisseurs will tell you that this is not private. | |
friend Awaiter operator co_await(CoroutineFn<T> co) { | |
return Awaiter{co.promise()}; | |
} | |
}; | |
// The actual promises, respectively for non-void and void types. | |
// All that's left is storing and retrieving the value. | |
template<typename T> | |
struct ValuePromise: BasePromise<T, ValuePromise<T>> | |
{ | |
using await_resume_type = T&&; | |
T _value; | |
void return_value(T&& value) { _value = std::move(value); } | |
void return_value(T const& value) { _value = value; } | |
T&& value() noexcept { return static_cast<T&&>(_value); } | |
}; | |
template<> | |
struct ValuePromise<void>: BasePromise<void, ValuePromise<void>> | |
{ | |
using await_resume_type = void; | |
void return_void() const {} | |
void value() const {} | |
}; | |
// --------------------------- | |
// This class takes care of yielding, which is just a matter of doing | |
// "co_await Yield{}". This always suspends, and also stores the | |
// suspending CoroutineFn in current->top. | |
struct Yield: std::suspend_always { | |
void await_suspend(std::coroutine_handle<> parent) const noexcept; | |
}; | |
// --------------------------- | |
Coroutine *qemu_coroutine_create(CoroutineFunc *func, void *opaque); | |
// Make it possible to write "co_await qemu_coroutine_yield()" | |
static inline Yield qemu_coroutine_yield() | |
{ | |
return Yield{}; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "coro.h" | |
#include <iostream> | |
#include <string> | |
// Announces itself, yields once, and completes with 30 when re-entered.
CoroutineFn<int> return_int() {
    std::ostream &out = std::cout;
    out << ">>suspending to " << __func__ << '\n';
    co_await qemu_coroutine_yield();

    out << ">>back\n";
    co_return 30;
}
// Announces itself, yields once, and finishes without a value.
CoroutineFn<void> return_void() {
    std::ostream &out = std::cout;
    out << ">>suspending to " << __func__ << '\n';
    co_await qemu_coroutine_yield();

    out << ">>back\n";
}
// Entry coroutine for main(): awaits both helpers (each yields once),
// prints the int result, then yields once more on its own.
CoroutineFn<void> co(void *) {
    co_await return_void();

    const int result = co_await return_int();
    std::cout << result << '\n';

    std::cout << "suspending\n";
    co_await qemu_coroutine_yield();
    std::cout << "back\n";
}
int main() { | |
auto f = qemu_coroutine_create(co, NULL); | |
std::cout << "--- 0\n"; | |
qemu_coroutine_enter(f); | |
std::cout << "--- 1\n"; | |
qemu_coroutine_enter(f); | |
std::cout << "--- 2\n"; | |
qemu_coroutine_enter(f); | |
std::cout << "--- 3\n"; | |
qemu_coroutine_enter(f); | |
std::cout << "--- 4\n"; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "coro.h" | |
#include <iostream> | |
#include <string> | |
// Yields once; when re-entered, completes with s as a std::string.
CoroutineFn<std::string> yield_and_resume(const char *s)
{
    co_await qemu_coroutine_yield();
    std::string result{s};
    co_return result;
}
// Prints two messages, each produced by a coroutine that yields before
// returning it, with one extra yield of its own in between.
CoroutineFn<void> counter(void * /*opaque*/)
{
    const std::string first = co_await yield_and_resume("counter: resumed (#1)\n");
    std::cout << first;

    co_await qemu_coroutine_yield();

    const std::string second = co_await yield_and_resume("counter: resumed (#2)\n");
    std::cout << second;
}
int main () | |
{ | |
std::cout << "main: calling counter\n"; | |
Coroutine *the_counter = qemu_coroutine_create(counter, NULL); | |
qemu_coroutine_enter(the_counter); // in resumed() | |
qemu_coroutine_enter(the_counter); | |
qemu_coroutine_enter(the_counter); // in resumed() | |
qemu_coroutine_enter(the_counter); | |
std::cout << "main: done\n"; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "coro.h" | |
#include <iostream> | |
#include <string> | |
// Never suspends: completes immediately with s as a std::string.
CoroutineFn<std::string> resumed(const char *s)
{
    co_return std::string{s};
}
// Prints message #1, yields once, then prints message #2.
CoroutineFn<void> counter(void * /*opaque*/)
{
    const std::string first = co_await resumed("counter: resumed (#1)\n");
    std::cout << first;

    co_await qemu_coroutine_yield();

    const std::string second = co_await resumed("counter: resumed (#2)\n");
    std::cout << second;
}
int main () | |
{ | |
std::cout << "main: calling counter\n"; | |
Coroutine *the_counter = qemu_coroutine_create(counter, NULL); | |
qemu_coroutine_enter(the_counter); | |
qemu_coroutine_enter(the_counter); | |
std::cout << "main: done\n"; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "coro.h" | |
#include <iostream> | |
#include <string> | |
#include <vector> | |
// Never suspends: completes immediately with the vector {1, 2, 3}.
CoroutineFn<std::vector<int>> resumed()
{
    std::vector<int> x{1, 2, 3};
    co_return x;
}
// Awaits resumed() and prints its three elements from last to first.
CoroutineFn<void> vec(void * /*opaque*/)
{
    std::vector<int> v = co_await resumed();
    for (int remaining = 3; remaining > 0; --remaining) {
        std::cout << v.back() << '\n';
        v.pop_back();
    }
}
int main () | |
{ | |
Coroutine *co = qemu_coroutine_create(vec, NULL); | |
qemu_coroutine_enter(co); | |
std::cout << "main: done\n"; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment