Created
January 30, 2020 16:40
-
-
Save GaZaTu/00db67a99f0ea9609cd573f48d4ad308 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once
#include <atomic>
#include <coroutine>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
template<typename T> | |
class future { | |
public: | |
class promise_type { | |
public: | |
class final_awaitable { | |
public: | |
bool await_ready() { | |
std::cout << "future<T>::promise_type::final_awaitable::await_ready" << std::endl; | |
return false; | |
} | |
template<typename PROMISE> | |
void await_suspend(std::coroutine_handle<PROMISE> coro) noexcept { | |
std::cout << "future<T>::promise_type::final_awaitable::await_suspend" << std::endl; | |
promise_type& promise = coro.promise(); | |
// Use 'release' memory semantics in case we finish before the | |
// awaiter can suspend so that the awaiting thread sees our | |
// writes to the resulting value. | |
// Use 'acquire' memory semantics in case the caller registered | |
// the continuation before we finished. Ensure we see their write | |
// to m_continuation. | |
if (promise._continuation_state.exchange(true, std::memory_order_acq_rel)) { | |
promise._continuation.resume(); | |
} | |
} | |
void await_resume() { | |
std::cout << "future<T>::promise_type::final_awaitable::await_resume" << std::endl; | |
} | |
}; | |
T _value = nullptr; | |
std::exception_ptr _exception; | |
bool _failed = false; | |
promise_type() noexcept : _continuation_state(false) { } | |
~promise_type() { | |
if (_failed) { | |
_exception.~exception_ptr(); | |
} else { | |
_value.~T(); | |
} | |
} | |
auto get_return_object() { | |
return coro_handle::from_promise(*this); | |
} | |
// custom | |
auto get_awaitable() { | |
return future(get_return_object()); | |
} | |
auto initial_suspend() noexcept { | |
return std::suspend_always(); | |
} | |
auto final_suspend() noexcept { | |
std::cout << "future<T>::promise_type::final_suspend" << std::endl; | |
return final_awaitable(); | |
} | |
void set_continuation(std::coroutine_handle<> continuation) noexcept { | |
_continuation = continuation; | |
} | |
void unhandled_exception() { | |
::new (static_cast<void*>(std::addressof(_exception))) std::exception_ptr(std::current_exception()); | |
_failed = true; | |
} | |
template< | |
typename VALUE, | |
typename = std::enable_if_t<std::is_convertible_v<VALUE&&, T>>> | |
void return_value(VALUE&& value) noexcept(std::is_nothrow_constructible_v<T, VALUE&&>) { | |
std::cout << "future<T>::promise_type::return_void" << std::endl; | |
::new (static_cast<void*>(std::addressof(_value))) T(std::forward<VALUE>(value)); | |
} | |
T& result() & { | |
std::cout << "future<T>::promise_type::result" << std::endl; | |
if (_failed) { | |
std::rethrow_exception(_exception); | |
} else { | |
return _value; | |
} | |
} | |
private: | |
std::coroutine_handle<> _continuation; | |
// Initially false. Set to true when either a continuation is registered | |
// or when the coroutine has run to completion. Whichever operation | |
// successfully transitions from false->true got there first. | |
std::atomic<bool> _continuation_state; | |
}; | |
using coro_handle = std::coroutine_handle<promise_type>; | |
future(coro_handle handle) : _handle(handle) { } | |
future(future&& other) : _handle(other._handle) { | |
other._handle = nullptr; | |
} | |
future(const future&) = delete; | |
future& operator=(const future&) = delete; | |
~future() { | |
_handle.destroy(); | |
} | |
bool done() { | |
return _handle.done(); | |
} | |
// bool resume() { | |
// if (!_handle.done()) | |
// _handle.resume(); | |
// return !_handle.done(); | |
// } | |
T& result() & { | |
std::cout << "future<T>::result" << std::endl; | |
return _handle.promise().result(); | |
} | |
private: | |
coro_handle _handle; | |
}; | |
template<> | |
class future<void> { | |
public: | |
class promise_type { | |
public: | |
class final_awaitable { | |
public: | |
bool await_ready() { | |
std::cout << "future<void>::promise_type::final_awaitable::await_ready" << std::endl; | |
return false; | |
} | |
template<typename PROMISE> | |
void await_suspend(std::coroutine_handle<PROMISE> coro) noexcept { | |
std::cout << "future<void>::promise_type::final_awaitable::await_suspend" << std::endl; | |
promise_type& promise = coro.promise(); | |
// Use 'release' memory semantics in case we finish before the | |
// awaiter can suspend so that the awaiting thread sees our | |
// writes to the resulting value. | |
// Use 'acquire' memory semantics in case the caller registered | |
// the continuation before we finished. Ensure we see their write | |
// to m_continuation. | |
if (promise._continuation_state.exchange(true, std::memory_order_acq_rel)) { | |
promise._continuation.resume(); | |
} | |
} | |
void await_resume() { | |
std::cout << "future<void>::promise_type::final_awaitable::await_resume" << std::endl; | |
} | |
}; | |
std::exception_ptr _exception; | |
bool _failed = false; | |
promise_type() noexcept : _continuation_state(false) {} | |
~promise_type() { | |
if (_failed) { | |
_exception.~exception_ptr(); | |
} | |
} | |
auto get_return_object() { | |
return coro_handle::from_promise(*this); | |
} | |
// custom | |
auto get_awaitable() { | |
return future(get_return_object()); | |
} | |
auto initial_suspend() noexcept { | |
return std::suspend_always(); | |
} | |
auto final_suspend() noexcept { | |
std::cout << "future<void>::promise_type::final_suspend" << std::endl; | |
return final_awaitable(); | |
} | |
bool try_set_continuation(std::coroutine_handle<> continuation) { | |
_continuation = continuation; | |
return !_continuation_state.exchange(true, std::memory_order_acq_rel); | |
} | |
void unhandled_exception() { | |
::new (static_cast<void*>(std::addressof(_exception))) std::exception_ptr(std::current_exception()); | |
_failed = true; | |
} | |
void return_void() { | |
std::cout << "future<void>::promise_type::return_void" << std::endl; | |
} | |
void result() { | |
std::cout << "future<void>::promise_type::result" << std::endl; | |
if (_failed) { | |
std::rethrow_exception(_exception); | |
} | |
} | |
private: | |
std::coroutine_handle<> _continuation; | |
// Initially false. Set to true when either a continuation is registered | |
// or when the coroutine has run to completion. Whichever operation | |
// successfully transitions from false->true got there first. | |
std::atomic<bool> _continuation_state; | |
}; | |
using coro_handle = std::coroutine_handle<promise_type>; | |
class awaitable_base { | |
public: | |
coro_handle _coroutine; | |
awaitable_base(coro_handle coroutine) noexcept : _coroutine(coroutine) {} | |
bool await_ready() const noexcept { | |
return !_coroutine || _coroutine.done(); | |
} | |
bool await_suspend(std::coroutine_handle<> awaitingCoroutine) noexcept { | |
// NOTE: We are using the bool-returning version of await_suspend() here | |
// to work around a potential stack-overflow issue if a coroutine | |
// awaits many synchronously-completing tasks in a loop. | |
// | |
// We first start the task by calling resume() and then conditionally | |
// attach the continuation if it has not already completed. This allows us | |
// to immediately resume the awaiting coroutine without increasing | |
// the stack depth, avoiding the stack-overflow problem. However, it has | |
// the down-side of requiring a std::atomic to arbitrate the race between | |
// the coroutine potentially completing on another thread concurrently | |
// with registering the continuation on this thread. | |
// | |
// We can eliminate the use of the std::atomic once we have access to | |
// coroutine_handle-returning await_suspend() on both MSVC and Clang | |
// as this will provide ability to suspend the awaiting coroutine and | |
// resume another coroutine with a guaranteed tail-call to resume(). | |
_coroutine.resume(); | |
return _coroutine.promise().try_set_continuation(awaitingCoroutine); | |
} | |
}; | |
future(coro_handle handle) : _handle(handle) { } | |
future(future&& other) : _handle(other._handle) { | |
other._handle = nullptr; | |
} | |
future(const future&) = delete; | |
future& operator=(const future&) = delete; | |
~future() { | |
_handle.destroy(); | |
} | |
bool done() { | |
return _handle.done(); | |
} | |
void result() { | |
std::cout << "future<void>::result" << std::endl; | |
_handle.promise().result(); | |
} | |
auto operator co_await() const & noexcept { | |
class awaitable : public awaitable_base { | |
public: | |
using awaitable_base::awaitable_base; | |
decltype(auto) await_resume() { | |
if (!_coroutine) { | |
// throw broken_promise{}; | |
} | |
return _coroutine.promise().result(); | |
} | |
}; | |
return awaitable(_handle); | |
} | |
auto operator co_await() const && noexcept { | |
class awaitable : public awaitable_base { | |
public: | |
using awaitable_base::awaitable_base; | |
decltype(auto) await_resume() { | |
if (!_coroutine) { | |
// throw broken_promise{}; | |
} | |
return std::move(_coroutine.promise()).result(); | |
} | |
}; | |
return awaitable(_handle); | |
} | |
private: | |
coro_handle _handle; | |
}; | |
template<typename T> | |
class callback_future { | |
public: | |
class promise_type { | |
public: | |
class completion_awaitable { | |
public: | |
bool await_ready() const noexcept { | |
return false; | |
} | |
void await_suspend(std::coroutine_handle<promise_type> coroutine) const noexcept { | |
coroutine.promise()._callback(); | |
} | |
void await_resume() noexcept {} | |
}; | |
promise_type() noexcept { } | |
void start(std::function<void()> callback) { | |
_callback = callback; | |
coro_handle::from_promise(*this).resume(); | |
} | |
auto get_return_object() { | |
return coro_handle::from_promise(*this); | |
} | |
// custom | |
auto get_awaitable() { | |
return callback_future(get_return_object()); | |
} | |
auto initial_suspend() noexcept { | |
return std::suspend_always(); | |
} | |
auto final_suspend() noexcept { | |
return completion_awaitable(); | |
} | |
auto yield_value(T&& result) noexcept { | |
_result = std::addressof(result); | |
return final_suspend(); | |
} | |
void return_void() noexcept { | |
// The coroutine should have either yielded a value or thrown | |
// an exception in which case it should have bypassed return_void(). | |
// assert(false); | |
} | |
void unhandled_exception() { | |
_exception = std::current_exception(); | |
} | |
T&& result() { | |
if (_exception) { | |
std::rethrow_exception(_exception); | |
} else { | |
return static_cast<T&&>(*_result); | |
} | |
} | |
private: | |
std::remove_reference_t<T>* _result; | |
std::exception_ptr _exception; | |
std::function<void()> _callback; | |
}; | |
using coro_handle = std::coroutine_handle<promise_type>; | |
callback_future(coro_handle handle) : _handle(handle) { } | |
callback_future(callback_future&& other) : _handle(other._handle) { | |
other._handle = nullptr; | |
} | |
callback_future(const callback_future&) = delete; | |
callback_future& operator=(const callback_future&) = delete; | |
~callback_future() { | |
_handle.destroy(); | |
} | |
void start(std::function<void()> callback) noexcept { | |
_handle.promise().start(callback); | |
} | |
decltype(auto) result() { | |
return _handle.promise().result(); | |
} | |
private: | |
coro_handle _handle; | |
}; | |
template<> | |
class callback_future<void> { | |
public: | |
class promise_type { | |
public: | |
class completion_awaitable { | |
public: | |
bool await_ready() const noexcept { | |
return false; | |
} | |
void await_suspend(std::coroutine_handle<promise_type> coroutine) const noexcept { | |
coroutine.promise()._callback(); | |
} | |
void await_resume() noexcept {} | |
}; | |
promise_type() noexcept { } | |
void start(std::function<void()> callback) { | |
_callback = callback; | |
coro_handle::from_promise(*this).resume(); | |
} | |
auto get_return_object() { | |
return coro_handle::from_promise(*this); | |
} | |
// custom | |
auto get_awaitable() { | |
return callback_future(get_return_object()); | |
} | |
auto initial_suspend() noexcept { | |
return std::suspend_always(); | |
} | |
auto final_suspend() noexcept { | |
return completion_awaitable(); | |
} | |
void return_void() noexcept { } | |
void unhandled_exception() { | |
_exception = std::current_exception(); | |
} | |
void result() { | |
if (_exception) { | |
std::rethrow_exception(_exception); | |
} | |
} | |
private: | |
std::exception_ptr _exception; | |
std::function<void()> _callback; | |
}; | |
using coro_handle = std::coroutine_handle<promise_type>; | |
callback_future(coro_handle handle) : _handle(handle) { } | |
callback_future(callback_future&& other) : _handle(std::move(other._handle)) { | |
other._handle = nullptr; | |
} | |
callback_future& operator=(callback_future&& other) { | |
_handle = std::move(other._handle); | |
other._handle = nullptr; | |
return *this; | |
} | |
callback_future(const callback_future&) = delete; | |
callback_future& operator=(const callback_future&) = delete; | |
~callback_future() { | |
if (_handle) { | |
_handle.destroy(); | |
} | |
} | |
void start(std::function<void()> callback) noexcept { | |
std::cout << "callback_future<void>::start" << std::endl; | |
_handle.promise().start(callback); | |
} | |
decltype(auto) result() { | |
return _handle.promise().result(); | |
} | |
private: | |
coro_handle _handle; | |
}; | |
template< | |
typename AWAITABLE, | |
typename RESULT = void, | |
std::enable_if_t<std::is_void_v<RESULT>, int> = 0> | |
callback_future<void> make_callback_future(AWAITABLE&& awaitable) { | |
co_await std::forward<AWAITABLE>(awaitable); | |
} | |
template<typename AWAITABLE> | |
void callback_wait(AWAITABLE&& awaitable, std::function<void()> callback) { | |
std::cout << "callback_wait1" << std::endl; | |
auto task = reinterpret_cast<callback_future<void>*>(operator new(sizeof(callback_future<void>))); | |
std::cout << "callback_wait2" << std::endl; | |
*task = std::move(make_callback_future(std::forward<AWAITABLE>(awaitable))); | |
std::cout << "callback_wait3" << std::endl; | |
task->start([task, callback]() { | |
callback(); | |
task->result(); | |
delete task; | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment