Skip to content

Instantly share code, notes, and snippets.

@Jackarain
Last active September 29, 2024 13:52
Show Gist options
  • Save Jackarain/acb9a782ed8425c0c210ffb406c8adc0 to your computer and use it in GitHub Desktop.
Save Jackarain/acb9a782ed8425c0c210ffb406c8adc0 to your computer and use it in GitHub Desktop.
C++20 协程简易实现（A minimal C++20 coroutine implementation）
#pragma once
#include <coroutine>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <new>
#include <type_traits>
#include <utility>
#if defined(DEBUG) || defined(_DEBUG)
#include <unordered_set>
// Debug-only registry of live coroutine frames: the promise types' debug
// operator new/delete insert and erase each frame address here, presumably so
// frames that are never freed can be spotted.
// `inline` (C++17) makes this a single shared definition across every
// translation unit that includes this header; a plain definition in a header
// violates the ODR and fails to link once two .cpp files include it.
inline std::unordered_set<void*> global_crors;
#endif
namespace cppcoro
{
// Forward declarations: awaitable and awaitable_promise refer to each other.
template <class T>
struct awaitable;
template <class T>
struct awaitable_promise;
//////////////////////////////////////////////////////////////////////////
// Return type for fire-and-forget coroutines (used by awaitable::detach()).
// The coroutine starts running as soon as it is called and, because
// final_suspend never suspends, its frame is freed automatically when the
// body finishes; the returned object holds no handle.
struct awaitable_detached {
    struct promise_type {
        // Start eagerly.
        std::suspend_never initial_suspend() noexcept { return {}; }
        // Do not suspend at the end, so the frame is destroyed automatically.
        std::suspend_never final_suspend() noexcept { return {}; }
        void return_void() noexcept {}
        // NOTE(review): exceptions thrown by a detached coroutine are
        // discarded; there is no awaiter to report them to.
        void unhandled_exception() {}
        awaitable_detached get_return_object() noexcept {
            return awaitable_detached();
        }
#if defined(DEBUG) || defined(_DEBUG)
        // Debug-only frame allocation that records live frames in global_crors.
        void* operator new(std::size_t size) {
            void* ptr = malloc(size);
            if (!ptr) throw std::bad_alloc{};
            try {
                global_crors.insert(ptr);
            } catch (...) {
                // Bookkeeping failed (e.g. the set could not allocate a node):
                // release the frame memory instead of leaking it.
                free(ptr);
                throw;
            }
            return ptr;
        }
        void operator delete(void* ptr, std::size_t size) {
            global_crors.erase(ptr);
            (void)size;
            free(ptr);
        }
#endif
    };
};
//////////////////////////////////////////////////////////////////////////
// Awaiter returned by final_suspend(): when the coroutine finishes it
// symmetric-transfers to the stored continuation, or to a no-op coroutine
// when nothing is waiting (leaving the frame suspended at its final point).
template <typename T>
struct final_awaitable {
    bool await_ready() noexcept { return false; }

    std::coroutine_handle<> await_suspend(
        std::coroutine_handle<awaitable_promise<T>> finished) noexcept {
        std::coroutine_handle<> next = finished.promise().continuation_;
        return next ? next : std::coroutine_handle<>(std::noop_coroutine());
    }

    void await_resume() noexcept {}
};
//////////////////////////////////////////////////////////////////////////
// Promise type for coroutines that return a T (i.e. return awaitable<T>).
template <typename T>
struct awaitable_promise {
    awaitable<T> get_return_object();

    // Lazy start: the body runs only once the awaitable is co_awaited or
    // detached.
    auto initial_suspend() noexcept {
        return std::suspend_always{};
    }

    // On completion, transfer to the stored continuation (if any).
    auto final_suspend() noexcept {
        return final_awaitable<T>{};
    }

    // NOTE(review): exceptions escaping the coroutine body are swallowed here;
    // the awaiting coroutine is then resumed and receives whatever `value`
    // currently holds.
    void unhandled_exception() {}

    // Stores the co_return'ed value for the awaiter to pick up.
    // noexcept: a copy/move of T that throws here terminates the program.
    template <typename V>
    void return_value(V&& v) noexcept {
        value = std::forward<V>(v);
    }

    // Records the coroutine to resume when this one finishes.
    void reset_handle(std::coroutine_handle<> h) {
        continuation_ = h;
    }

#if defined(DEBUG) || defined(_DEBUG)
    // Debug-only frame allocation that records live frames in global_crors.
    void* operator new(std::size_t size) {
        void* ptr = malloc(size);
        if (!ptr) throw std::bad_alloc{};
        try {
            global_crors.insert(ptr);
        } catch (...) {
            // Bookkeeping failed (e.g. the set could not allocate a node):
            // release the frame memory instead of leaking it.
            free(ptr);
            throw;
        }
        return ptr;
    }
    void operator delete(void* ptr, std::size_t size) {
        global_crors.erase(ptr);
        (void)size;
        free(ptr);
    }
#endif

    std::coroutine_handle<> continuation_; // coroutine to resume on completion
    T value; // holds the value returned by the coroutine
};
//////////////////////////////////////////////////////////////////////////
// Explicit specialization of the promise for coroutines returning void
// (awaitable<void>): like the primary template, but with no stored value.
template <>
struct awaitable_promise<void> {
    awaitable<void> get_return_object();

    // Lazy start: the body runs only once the awaitable is co_awaited or
    // detached.
    auto initial_suspend() noexcept {
        return std::suspend_always{};
    }

    // On completion, transfer to the stored continuation (if any).
    auto final_suspend() noexcept {
        return final_awaitable<void>{};
    }

    // NOTE(review): exceptions escaping the coroutine body are swallowed; the
    // awaiting coroutine is resumed as if the body had completed normally.
    void unhandled_exception() {}

    void return_void() {}

    // Records the coroutine to resume when this one finishes.
    void reset_handle(std::coroutine_handle<> h) {
        continuation_ = h;
    }

#if defined(DEBUG) || defined(_DEBUG)
    // Debug-only frame allocation that records live frames in global_crors.
    void* operator new(std::size_t size) {
        void* ptr = malloc(size);
        if (!ptr) throw std::bad_alloc{};
        try {
            global_crors.insert(ptr);
        } catch (...) {
            // Bookkeeping failed (e.g. the set could not allocate a node):
            // release the frame memory instead of leaking it.
            free(ptr);
            throw;
        }
        return ptr;
    }
    void operator delete(void* ptr, std::size_t size) {
        global_crors.erase(ptr);
        (void)size;
        free(ptr);
    }
#endif

    std::coroutine_handle<> continuation_; // coroutine to resume on completion
};
//////////////////////////////////////////////////////////////////////////
// awaitable: owning wrapper around a lazily-started coroutine. It is also its
// own awaiter, so one coroutine can `co_await` another; detach() starts it
// without an awaiting parent.
template <typename T>
struct awaitable {
    using promise_type = awaitable_promise<T>;

    awaitable(std::coroutine_handle<promise_type> h)
        : current_coro_handle_(h)
    {}

    // Sole owner of the coroutine frame: destroy it whenever we still hold it.
    // A coroutine that was never awaited or detached is still suspended at
    // initial_suspend (done() == false); checking done() here would leak such
    // frames. destroy() is valid on any suspended coroutine, and this matches
    // what move-assignment does.
    ~awaitable() {
        if (current_coro_handle_)
            current_coro_handle_.destroy();
    }

    awaitable(awaitable&& t) noexcept
        : current_coro_handle_(std::exchange(t.current_coro_handle_, nullptr))
    {}

    awaitable& operator=(awaitable&& t) noexcept {
        if (&t != this) {
            if (current_coro_handle_) current_coro_handle_.destroy();
            current_coro_handle_ = std::exchange(t.current_coro_handle_, nullptr);
        }
        return *this;
    }

    awaitable(const awaitable&) = delete;
    awaitable& operator=(const awaitable&) = delete;

    // Same as get().
    T operator()() {
        return get();
    }

    // Moves out the value stored by co_return. Does not run or destroy the
    // coroutine, so the value is only meaningful once the coroutine finished.
    T get() {
        if constexpr (!std::is_same_v<T, void>)
            return std::move(current_coro_handle_.promise().value);
    }

    // Always suspend: the coroutine is lazy and must be started from
    // await_suspend.
    bool await_ready() const noexcept {
        return false;
    }

    // Called after the awaited coroutine finished: take its result, then free
    // its frame (the value is moved out before the frame is destroyed).
    T await_resume() {
        if constexpr (std::is_void_v<T>) {
            current_coro_handle_.destroy();
            current_coro_handle_ = nullptr;
        } else {
            auto ret = std::move(current_coro_handle_.promise().value);
            current_coro_handle_.destroy();
            current_coro_handle_ = nullptr;
            return ret;
        }
    }

    // Record the awaiting coroutine as our continuation, then transfer control
    // to our own coroutine (symmetric transfer). Its final_suspend resumes the
    // continuation when it finishes.
    auto await_suspend(std::coroutine_handle<> continuation) {
        current_coro_handle_.promise().reset_handle(continuation);
        return current_coro_handle_;
    }

    // Starts the coroutine with no awaiting parent. Ownership moves into an
    // eagerly-started awaitable_detached coroutine that awaits it; *this is
    // left empty.
    void detach() {
        auto launch_coro = [](awaitable<T> lazy) -> awaitable_detached {
            co_await lazy;
        };
        [[maybe_unused]] auto detached = launch_coro(std::move(*this));
    }

    std::coroutine_handle<promise_type> current_coro_handle_;
};
//////////////////////////////////////////////////////////////////////////
// Builds the awaitable handed back to the caller of a T-returning coroutine;
// it wraps the handle of the frame that owns this promise.
template <typename T>
awaitable<T> awaitable_promise<T>::get_return_object() {
    using handle_type = std::coroutine_handle<awaitable_promise<T>>;
    return awaitable<T>{ handle_type::from_promise(*this) };
}
// Builds the awaitable handed back to the caller of a void-returning coroutine.
// `inline` is required: this is a member of an explicit specialization, i.e.
// an ordinary (non-template) function defined in a header, so without it
// every including translation unit emits a definition and linking fails.
inline awaitable<void> awaitable_promise<void>::get_return_object() {
    auto result = awaitable<void>{ std::coroutine_handle<awaitable_promise<void>>::from_promise(*this) };
    return result;
}
}
//////////////////////////////////////////////////////////////////////////
// Awaiter for APIs that deliver a T through a callback invoked synchronously.
//
// await_suspend() passes `callback_function_` a handler taking a T; the handler
// stores the value, and await_suspend then hands control straight back to the
// awaiting coroutine, which reads the stored value in await_resume().
// NOTE(review): the handler captures `this` and the coroutine resumes as soon
// as the callback function returns, so the handler must be called before then;
// use ManualAwaiter for asynchronous completion.
template<typename T, typename CallbackFunction>
struct CallbackAwaiter
{
public:
    CallbackAwaiter(CallbackFunction&& callback_function)
        : callback_function_(std::move(callback_function)) {}

    bool await_ready() noexcept { return false; }

    auto await_suspend(std::coroutine_handle<> handle) {
        auto deliver = [this](T t) mutable { result_ = std::move(t); };
        callback_function_(std::move(deliver));
        // Returning the awaiting coroutine's own handle resumes it immediately.
        return handle;
    }

    T await_resume() noexcept { return std::move(result_); }

private:
    CallbackFunction callback_function_;
    T result_;
};
// Specialization for callbacks that signal completion without a result.
// The callback function must invoke the handler synchronously; the awaiting
// coroutine resumes as soon as the callback function returns.
template<typename CallbackFunction>
struct CallbackAwaiter<void, CallbackFunction>
{
public:
    CallbackAwaiter(CallbackFunction&& callback_function)
        : callback_function_(std::move(callback_function))
    {}
    bool await_ready() noexcept { return false; }
    auto await_suspend(std::coroutine_handle<> handle) {
        // There is no result to store, so the completion handler needs no
        // captures (capturing an unused `this` triggers
        // -Wunused-lambda-capture).
        callback_function_(
            []()
            {}
        );
        // Returning the awaiting coroutine's own handle resumes it immediately.
        return handle;
    }
    void await_resume() noexcept {}
private:
    CallbackFunction callback_function_;
};
// Creates a CallbackAwaiter<T, ...> that owns its own copy/move of `cb`.
// The callable type is decayed so lvalue callables work as well: with the raw
// deduced type (`F&` for a non-const lvalue) CallbackAwaiter would bind an
// `F&` member to `std::move(...)`, which does not compile. For rvalue
// callables (the usual lambda-literal case) the resulting type is unchanged.
template<typename T, typename callback>
CallbackAwaiter<T, std::decay_t<callback>>
callback_awaitable(callback&& cb) {
    using callback_type = std::decay_t<callback>;
    return CallbackAwaiter<T, callback_type>{callback_type(std::forward<callback>(cb))};
}
//////////////////////////////////////////////////////////////////////////
// Awaiter for APIs that complete asynchronously (e.g. from an executor).
//
// await_suspend() gives `callback_function_` a completion handler and leaves
// the coroutine suspended. When the handler is later called with a T, it
// stores the value and resumes the coroutine, which then reads it in
// await_resume().
// NOTE(review): the handler captures `this`; it should be called at most once,
// and only while this awaiter is still alive.
template<typename T, typename CallbackFunction>
struct ManualAwaiter
{
public:
    ManualAwaiter(CallbackFunction&& callback_function)
        : callback_function_(std::move(callback_function)) {}

    bool await_ready() noexcept { return false; }

    void await_suspend(std::coroutine_handle<> handle) {
        auto complete = [handle = std::move(handle), this](T t) mutable {
            result_ = std::move(t);
            handle.resume();
        };
        callback_function_(std::move(complete));
    }

    T await_resume() noexcept { return std::move(result_); }

private:
    CallbackFunction callback_function_;
    T result_;
};
// Specialization for asynchronous operations that complete without a result:
// the completion handler simply resumes the suspended coroutine.
template<typename CallbackFunction>
struct ManualAwaiter<void, CallbackFunction>
{
public:
    ManualAwaiter(CallbackFunction&& callback_function)
        : callback_function_(std::move(callback_function))
    {}

    bool await_ready() noexcept { return false; }

    void await_suspend(std::coroutine_handle<> handle) {
        auto resume_waiter = [handle = std::move(handle)]() mutable {
            handle.resume();
        };
        callback_function_(std::move(resume_waiter));
    }

    void await_resume() noexcept {}

private:
    CallbackFunction callback_function_;
};
// Creates a ManualAwaiter<T, ...> that owns its own copy/move of `cb`.
// The callable type is decayed so lvalue callables work as well: with the raw
// deduced type (`F&` for a non-const lvalue) ManualAwaiter would bind an
// `F&` member to `std::move(...)`, which does not compile. For rvalue
// callables (the usual lambda-literal case) the resulting type is unchanged.
template<typename T, typename callback>
ManualAwaiter<T, std::decay_t<callback>>
manual_awaitable(callback&& cb) {
    using callback_type = std::decay_t<callback>;
    return ManualAwaiter<T, callback_type>{callback_type(std::forward<callback>(cb))};
}
// Launches `a` without waiting for it, by calling its detach() member.
template <typename Awaitable>
void coro_start(Awaitable&& a) {
    auto& target = a;
    target.detach();
}
@Jackarain
Copy link
Author

Jackarain commented Sep 29, 2024

用法示例（没有 executor 的情况）/ Usage example (without an executor):

// Example: awaits a callback that completes synchronously with value * 100,
// then returns value plus that result.
cppcoro::awaitable<int> coro_compute_int(int value) {
    auto scaled = co_await callback_awaitable<int>([value](auto handle) {
        std::cout << value << " value\n";
        handle(value * 100);
    });
    int total = value + scaled;
    co_return total;
}

// Example: awaits coro_compute_int(value) and prints its result.
cppcoro::awaitable<void> coro_compute_exec(int value)
{
    auto ret = co_await coro_compute_int(value);
    // '\n' rather than std::endl: this runs once per loop iteration (a million
    // times in this example) and an explicit flush per line is needless cost.
    std::cout << "return: " << ret << '\n';
    co_return;
}

// Example: sequentially awaits coro_compute_exec(i) for i in [0, 1000000).
cppcoro::awaitable<void> coro_compute() {
    int i = 0;
    while (i < 1000000) {
        co_await coro_compute_exec(i);
        ++i;
    }
}

// Example entry point: launches the root coroutine detached. Because the
// callbacks complete synchronously, the whole chain runs inside coro_start.
int main()
{
    auto root = coro_compute();
    coro_start(root);
    return 0;
}

@Jackarain
Copy link
Author

用法示例（有 executor 的情况）/ Usage example (with an executor, Boost.Asio):

// Example: the io_context that runs the handlers posted by coro_compute_int.
boost::asio::io_context main_ioc{};

// Example: posts work to main_ioc and resumes with value * 100 once the posted
// handler runs, then returns value plus that result.
cppcoro::awaitable<int> coro_compute_int(int value) {
    auto ret = co_await manual_awaitable<int>([value](auto handle)
        {
            // io_context::post() is deprecated (and removed in recent Boost
            // releases); the free function boost::asio::post is its replacement.
            boost::asio::post(main_ioc, [value, handle = std::move(handle)]() mutable {
                std::this_thread::sleep_for(std::chrono::seconds(0));
                std::cout << value << " value\n";
                handle(value * 100);
            });
        });

    co_return (value + ret);
}

// Example: awaits coro_compute_int(value) and prints its result.
cppcoro::awaitable<void> coro_compute_exec(int value)
{
    auto ret = co_await coro_compute_int(value);
    // '\n' rather than std::endl: no explicit flush is needed per line.
    std::cout << "return: " << ret << '\n';
    co_return;
}

// Example: sequentially awaits coro_compute_exec(i) for i in [0, 100).
cppcoro::awaitable<void> coro_compute() {
    for (int i = 0; i != 100; ++i)
        co_await coro_compute_exec(i);
}

// Example entry point: starts the root coroutine detached, then runs main_ioc
// so the posted handlers execute and resume the coroutines.
int main([[maybe_unused]] int argc, [[maybe_unused]] char** argv)
{
    auto root = coro_compute();
    coro_start(root);
    main_ioc.run();
    return 0;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment