Skip to content

Instantly share code, notes, and snippets.

@kammce
Created June 5, 2025 14:40
Show Gist options
  • Save kammce/38ba005273300236232e5720db50951e to your computer and use it in GitHub Desktop.
Save kammce/38ba005273300236232e5720db50951e to your computer and use it in GitHub Desktop.
Getting coroutines working without the heap, using stack-like frame allocation
#include <algorithm>
#include <array>
#include <coroutine>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <memory_resource>
#include <new>
#include <numeric>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
namespace hal {
// Alias for std::uint8_t; element type of the coroutine stack buffer handed
// to coroutine_stack_memory_resource.
using byte = std::uint8_t;
// Alias for std::size_t; used for allocation sizes, frame sizes and stack
// offsets throughout this namespace.
using usize = std::size_t;
class async_context
{
public:
explicit async_context(std::pmr::memory_resource& p_resource)
: m_resource(&p_resource)
{
}
async_context() = default;
[[nodiscard]] constexpr auto* resource()
{
return m_resource;
}
constexpr auto last_allocation_size()
{
return m_last_allocation_size;
}
constexpr void last_allocation_size(usize p_last_allocation_size)
{
m_last_allocation_size = p_last_allocation_size;
}
constexpr std::coroutine_handle<> active_handle()
{
return m_active_handle;
}
constexpr void active_handle(std::coroutine_handle<> p_active_handle)
{
m_active_handle = p_active_handle;
}
private:
std::pmr::memory_resource* m_resource = nullptr;
usize m_last_allocation_size = 0;
std::coroutine_handle<> m_active_handle = std::noop_coroutine();
};
/**
 * @brief State and behavior shared by every `task_promise_type<T>`.
 *
 * Coroutine frames are allocated from the `async_context` passed as the first
 * coroutine argument (or the second one, after the implicit object argument
 * of a member coroutine). The allocation size is recorded in that context so
 * the promise constructor can remember how large its own frame is.
 */
class task_promise_base
{
public:
  // Frame allocation for free-function coroutines: (async_context&, ...).
  template<typename... Args>
  static constexpr void* operator new(std::size_t p_size,
                                      async_context& p_context,
                                      Args&&...)
  {
    return allocate_frame(p_size, p_context);
  }

  // Frame allocation for member-function coroutines: the leading parameter
  // binds the implicit object, the async_context comes right after it.
  template<typename Class, typename... Args>
  static constexpr void* operator new(std::size_t p_size,
                                      Class&,
                                      async_context& p_context,
                                      Args&&...)
  {
    return allocate_frame(p_size, p_context);
  }

  // Deliberately empty: async<T> gives the frame back to the context's memory
  // resource itself once it has destroyed the coroutine.
  static constexpr void operator delete(void*) noexcept
  {
  }
  static constexpr void operator delete(void*, std::size_t) noexcept
  {
  }

  // (async_context&) — the constructor every other overload delegates to.
  task_promise_base(async_context& p_context)
    : m_context(&p_context)
    , m_frame_size(p_context.last_allocation_size())
  {
  }

  // (object&, async_context&) — member coroutine without further arguments.
  template<typename Class>
  task_promise_base(Class&, async_context& p_context)
    : task_promise_base(p_context)
  {
  }

  // (async_context&, args...) — free coroutine with further arguments.
  template<typename... Args>
  task_promise_base(async_context& p_context, Args&&...)
    : task_promise_base(p_context)
  {
  }

  // (object&, async_context&, args...) — member coroutine with arguments.
  template<typename Class, typename... Args>
  task_promise_base(Class&, async_context& p_context, Args&&...)
    : task_promise_base(p_context)
  {
  }

  // Coroutines are created suspended; they run only once awaited/resumed.
  constexpr std::suspend_always initial_suspend() noexcept
  {
    return {};
  }

  // Pass-through: every awaitable is used exactly as written.
  template<typename U>
  constexpr U&& await_transform(U&& awaitable) noexcept
  {
    return std::forward<U>(awaitable);
  }

  // Captures whatever escaped the coroutine body.
  void unhandled_exception() noexcept
  {
    m_error = std::current_exception();
  }

  constexpr auto& context()
  {
    return *m_context;
  }
  constexpr void context(async_context& p_context)
  {
    m_context = &p_context;
  }

  constexpr auto continuation()
  {
    return m_continuation;
  }
  constexpr void continuation(std::coroutine_handle<> p_continuation)
  {
    m_continuation = p_continuation;
  }

  [[nodiscard]] constexpr auto frame_size() const
  {
    return m_frame_size;
  }

  // Makes the continuation the context's active coroutine and returns it so
  // the caller can transfer control to it.
  constexpr std::coroutine_handle<> pop_active_coroutine()
  {
    auto const next = m_continuation;
    m_context->active_handle(next);
    return next;
  }

protected:
  // Coroutine to resume once this one completes.
  std::coroutine_handle<> m_continuation{};
  async_context* m_context{};
  // NOLINTNEXTLINE(bugprone-throw-keyword-missing)
  std::exception_ptr m_error{};
  usize m_frame_size = 0;

private:
  // Records the requested frame size in the context, then carves the frame
  // out of the context's memory resource.
  static void* allocate_frame(std::size_t p_size, async_context& p_context)
  {
    p_context.last_allocation_size(p_size);
    return p_context.resource()->allocate(p_size);
  }
};
// Forward declaration: the promise types below name async<T> as the return
// type of get_return_object() before async itself is defined.
template<typename T>
class async;
/// Empty type that `void_to_placeholder_t` substitutes for `void`, so a
/// result member can always be declared.
struct void_placeholder
{};

/// Yields `void_placeholder` when `T` is (possibly cv-qualified) `void`;
/// otherwise yields `T` unchanged.
template<typename T>
using void_to_placeholder_t =
  typename std::conditional<std::is_void<T>::value, void_placeholder, T>::type;
template<typename T>
class task_promise_type : public task_promise_base
{
public:
using task_promise_base::task_promise_base; // Inherit constructors
using task_promise_base::operator new;
// Add regular delete operators for normal coroutine destruction
static constexpr void operator delete(void*) noexcept
{
}
static constexpr void operator delete(void*, std::size_t) noexcept
{
}
struct final_awaiter
{
constexpr bool await_ready() noexcept
{
return false;
}
template<typename U>
std::coroutine_handle<> await_suspend(
std::coroutine_handle<task_promise_type<U>> p_self) noexcept
{
// The coroutine is now suspended at the final-suspend point.
// Lookup its continuation in the promise and resume it symmetrically.
//
// Rather than return control back to the application, we continue the
// caller function allowing it to yield when it reaches another suspend
// point. The idea is that prior to this being called, we were executing
// code and thus, when we resume the caller, we are still running code.
// Lets continue to run as much code until we reach an actual suspend
// point.
return p_self.promise().pop_active_coroutine();
}
void await_resume() noexcept
{
}
};
final_awaiter final_suspend() noexcept
{
return {};
}
constexpr async<T> get_return_object() noexcept;
// For non-void return type
template<typename U = T>
void return_value(U&& p_value) noexcept
requires(not std::is_void_v<T>)
{
m_value = std::forward<U>(p_value);
}
auto result()
{
return m_value;
}
void_to_placeholder_t<T> m_value{};
};
/**
 * @brief Promise type for coroutines that return `async<void>`.
 *
 * Mirrors the primary template minus the stored result: `co_return;` lands in
 * `return_void()`, which has nothing to record.
 */
template<>
class task_promise_type<void> : public task_promise_base
{
public:
  using task_promise_base::task_promise_base; // Inherit constructors
  using task_promise_base::operator new;

  // NOTE(review): declared, but no definition is visible here, and
  // task_promise_base has no default constructor to build the base from.
  task_promise_type();

  // Frame memory is not released here; async<void> hands the frame back to the
  // context's memory resource after destroying the coroutine.
  static constexpr void operator delete(void*) noexcept
  {
  }
  static constexpr void operator delete(void*, std::size_t) noexcept
  {
  }

  constexpr void return_void() noexcept
  {
  }

  constexpr async<void> get_return_object() noexcept;

  // Awaited at the final suspend point. Instead of returning to the
  // application, control moves symmetrically to this coroutine's continuation
  // (the coroutine that awaited it), which keeps running until it reaches a
  // genuine suspend point of its own.
  struct final_awaiter
  {
    constexpr bool await_ready() noexcept
    {
      return false;
    }

    template<typename U>
    std::coroutine_handle<> await_suspend(
      std::coroutine_handle<task_promise_type<U>> p_self) noexcept
    {
      auto& finished_promise = p_self.promise();
      return finished_promise.pop_active_coroutine();
    }

    void await_resume() noexcept
    {
    }
  };

  final_awaiter final_suspend() noexcept
  {
    return {};
  }
};
/**
 * @brief Owning handle to a coroutine whose frame lives in an
 *        `async_context`'s memory resource.
 *
 * `co_await`ing it transfers control into the coroutine and, once the
 * coroutine finishes, yields its result. `sync_result()` drives it to
 * completion from ordinary code. Destroying or move-assigning over an `async`
 * destroys the coroutine and returns its frame to the memory resource.
 *
 * @tparam T - result type of the coroutine (`void` for none).
 */
template<typename T = void>
class async
{
public:
  using promise_type = task_promise_type<T>;
  friend promise_type;

  /// Resumes whichever coroutine is currently active in this task's context.
  /// Does nothing for an empty (default-constructed or moved-from) task.
  void resume()
  {
    if (not m_handle) {
      return;
    }
    auto active = m_handle.promise().context().active_handle();
    active.resume();
  }

  /**
   * @brief Resume the context's active coroutine until this task is done.
   * @return the task's result; a default-constructed `T` for an empty task.
   *         Exceptions raised by the promise's `result()` propagate.
   */
  T sync_result()
  {
    if (not m_handle) {
      if constexpr (std::is_void_v<T>) {
        return;
      } else {
        return T{};
      }
    }
    while (not m_handle.done()) {
      auto active = m_handle.promise().context().active_handle();
      active.resume();
    }
    if constexpr (not std::is_void_v<T>) {
      return m_handle.promise().result();
    }
  }

  // Awaiter for when this task is awaited
  struct awaiter
  {
    std::coroutine_handle<promise_type> m_handle;

    explicit awaiter(std::coroutine_handle<promise_type> p_handle) noexcept
      : m_handle(p_handle)
    {
    }

    /// An empty task has nothing to run, so the awaiting coroutine continues.
    [[nodiscard]] constexpr bool await_ready() const noexcept
    {
      return not m_handle;
    }

    /// Records the awaiting coroutine as this task's continuation and
    /// transfers control straight into the task.
    template<typename Promise>
    std::coroutine_handle<> await_suspend(
      std::coroutine_handle<Promise> p_continuation) noexcept
    {
      m_handle.promise().continuation(p_continuation);
      return m_handle;
    }

    /// @return the task's result, or a default-constructed `T` for an empty
    ///         task (the case `await_ready` short-circuits).
    T await_resume()
    {
      if constexpr (not std::is_void_v<T>) {
        if (m_handle) {
          return m_handle.promise().result();
        }
        // Without this an empty task would flow off the end of a non-void
        // function, which is undefined behavior.
        return T{};
      }
    }
  };

  [[nodiscard]] constexpr awaiter operator co_await() const noexcept
  {
    return awaiter{ m_handle };
  }

  async() noexcept = default;

  async(async&& p_other) noexcept
    : m_handle(std::exchange(p_other.m_handle, {}))
  {
  }

  ~async()
  {
    release_frame();
  }

  async& operator=(async&& p_other) noexcept
  {
    if (this != &p_other) {
      // Must mirror the destructor: destroy() alone never returns the frame
      // to the memory resource, because the promise's operator delete is a
      // no-op.
      release_frame();
      m_handle = std::exchange(p_other.m_handle, {});
    }
    return *this;
  }

  auto handle()
  {
    return m_handle;
  }

private:
  // Only the promise creates tasks; the new coroutine becomes the context's
  // active coroutine and, until awaited, continues into nothing.
  explicit async(std::coroutine_handle<promise_type> p_handle)
    : m_handle(p_handle)
  {
    m_handle.promise().continuation(std::noop_coroutine());
    m_handle.promise().context().active_handle(m_handle);
  }

  /// Destroys the owned coroutine (if any), returns its frame to the
  /// context's memory resource, and leaves this task empty.
  void release_frame() noexcept
  {
    if (not m_handle) {
      return;
    }
    // Read everything we need from the promise before destroy() ends its
    // lifetime.
    void* const address = m_handle.address();
    auto* const allocator = m_handle.promise().context().resource();
    auto const frame_size = m_handle.promise().frame_size();
    m_handle.destroy();
    allocator->deallocate(address, frame_size);
    m_handle = nullptr;
  }

  std::coroutine_handle<promise_type> m_handle{};
};
// Wraps this promise's own coroutine handle in the async<T> that owns the
// frame; async's private constructor also makes it the context's active
// coroutine.
template<typename T>
constexpr async<T> task_promise_type<T>::get_return_object() noexcept
{
  using self_handle_t = std::coroutine_handle<task_promise_type<T>>;
  return async<T>{ self_handle_t::from_promise(*this) };
}
// `void` counterpart of the primary template's get_return_object(): hands
// this promise's coroutine handle to the async<void> that will own it.
constexpr async<void> task_promise_type<void>::get_return_object() noexcept
{
  auto const self =
    std::coroutine_handle<task_promise_type<void>>::from_promise(*this);
  return async<void>{ self };
}
class coroutine_stack_memory_resource : public std::pmr::memory_resource
{
public:
constexpr coroutine_stack_memory_resource(std::span<hal::byte> p_memory)
: m_memory(p_memory)
{
if (p_memory.data() == nullptr || p_memory.size() < 32) {
throw std::runtime_error(
"Coroutine stack memory invalid! Must be non-null and size > 32.");
}
}
private:
void* do_allocate(std::size_t p_bytes, std::size_t) override
{
auto* const new_stack_pointer = &m_memory[m_stack_pointer];
m_stack_pointer += p_bytes;
return new_stack_pointer;
}
void do_deallocate(void*, std::size_t p_bytes, std::size_t) override
{
m_stack_pointer -= p_bytes;
}
[[nodiscard]] bool do_is_equal(
std::pmr::memory_resource const& other) const noexcept override
{
return this == &other;
}
std::span<hal::byte> m_memory;
hal::usize m_stack_pointer = 0;
};
} // namespace hal
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment