Last active
November 12, 2024 21:00
-
-
Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
A toy implementation of P2300, the std::execution proposal, for teaching purposes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2022-2024 NVIDIA Corporation | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
// This is a toy implementation of the core parts of the C++ std::execution | |
// proposal (aka, Senders, http://wg21.link/P2300). It is intended as a | |
// learning tool only. THIS CODE IS NOT SUITABLE FOR ANY USE. | |
#include <atomic>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <exception>
#include <functional>
#include <locale>
#include <mutex>
#include <optional>
#include <sstream>
#include <stop_token>
#include <string>
#include <system_error>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
// Some utility code | |
/////////////////////////////////////////// | |
// Returns the calling thread's id rendered as text with a "0x" prefix.
std::string get_thread_id() {
  std::ostringstream text;
  // Format with the classic locale, whatever the global locale is.
  text.imbue(std::locale::classic());
  text << "0x";
  text << std::hex << std::this_thread::get_id();
  return text.str();
}
// Marker type that can be default-constructed but neither moved nor copied:
// declaring the move constructor as deleted also suppresses the implicit
// copy operations. Types that derive from it inherit the restriction.
struct immovable {
  immovable(immovable&&) = delete;
  immovable() = default;
};
// Empty placeholder value type, used as the result type of senders that
// carry no payload (just_stopped and the run_loop sender).
struct none {};

// Sink type: constructible from any arguments and assignable from anything;
// whatever it is given is discarded.
struct ignore {
  template <class... Args>
  ignore(Args&&...) noexcept {}

  template <class Arg>
  void operator=(Arg&&) {}
};
// Some tuple utilities used by when_all

// The index sequence 0..N-1, where N is the tuple size of CvTuple once its
// reference and cv-qualifiers have been stripped.
template <class CvTuple>
using _tuple_indices = std::make_index_sequence<
    std::tuple_size<std::remove_cv_t<std::remove_reference_t<CvTuple>>>::value>;
// Element& when AddReference is true, otherwise Element unchanged.
template <bool AddReference, class Element>
using _maybe_add_reference =
    typename std::conditional<AddReference, Element&, Element>::type;
// Type of the I-th element of CvTuple. If CvTuple is a reference type the
// result is an lvalue reference to the element (keeping the tuple's
// constness); otherwise it is the plain element type.
template <std::size_t I, class CvTuple>
using _tuple_element_t = std::conditional_t<
    std::is_reference<CvTuple>::value,
    std::tuple_element_t<I, std::remove_reference_t<CvTuple>>&,
    std::tuple_element_t<I, std::remove_reference_t<CvTuple>>>;
// Map error types to an exception_ptr.

// Generic case: the error object itself becomes the stored exception.
template <class Error>
inline std::exception_ptr as_eptr(Error error) noexcept {
  return std::make_exception_ptr(error);
}

// An exception_ptr is already in the target form and is passed through.
inline std::exception_ptr as_eptr(std::exception_ptr eptr) noexcept {
  return eptr;
}

// An error_code is wrapped in a std::system_error.
inline std::exception_ptr as_eptr(std::error_code ec) noexcept {
  std::system_error wrapped(ec);
  return std::make_exception_ptr(wrapped);
}
// Nothrow emplace into an optional: constructs the optional's value from
// `value` and reports failure through the return value instead of throwing.
// Returns a null exception_ptr on success, otherwise the exception raised
// while emplacing.
template <class T, class U = T>
inline std::exception_ptr try_emplace(std::optional<T>& opt, U&& value) noexcept {
  try {
    opt.emplace(std::forward<U>(value));
  } catch (...) {
    return std::current_exception();
  }
  return std::exception_ptr();
}
// In this toy implementation, a sender can only complete with a single value.
// sender_result_t names that value type: the sender's nested `result_t`,
// looked up after stripping any reference from CvSndr.
template <class CvSndr>
using sender_result_t = typename std::remove_reference<CvSndr>::type::result_t;
// The operation-state type obtained by calling `.connect(rcvr)` on a sender
// expression of type CvSndr with a receiver of type Rcvr.
template <class CvSndr, class Rcvr>
using connect_result_t =
    decltype(std::declval<CvSndr>().connect(std::declval<Rcvr>()));
// Most operation states are immovable, so putting them in optionals and tuples | |
// is tricky. We use connect_emplace together with C++17's guaranteed copy | |
// elision to emplace the opstates in the containers via a conversion operator. | |
template <class CvSndr, class Rcvr> | |
struct connect_emplace { | |
CvSndr&& sndr_; | |
Rcvr rcvr_; | |
operator connect_result_t<CvSndr, Rcvr> () { | |
return std::forward<CvSndr>(sndr_).connect(rcvr_); | |
} | |
}; | |
template <class CvSndr, class Rcvr> | |
connect_emplace(CvSndr&&, Rcvr) -> connect_emplace<CvSndr, Rcvr>; | |
// The three ways an operation can complete:
enum class disposition {
  value,    // with a result
  error,    // with an error
  stopped,  // stopped before producing a result
};
/////////////////////////////////////////// | |
// environment queries | |
/////////////////////////////////////////// | |
constexpr struct get_stop_token_t { | |
auto operator()(const auto& env) const noexcept -> decltype(env.query(*this)) { | |
return env.query(*this); | |
} | |
auto operator()(ignore) const noexcept { | |
return std::stop_token(); | |
} | |
} get_stop_token{}; | |
// Query object for an environment's scheduler. It is only callable with
// environments that answer env.query(get_scheduler_t); there is no fallback.
struct get_scheduler_t {
  template <class Env>
  auto operator()(const Env& env) const noexcept -> decltype(env.query(*this)) {
    return env.query(*this);
  }
};
constexpr get_scheduler_t get_scheduler{};
////////////////////////////////////////////////////////////////// | |
// just(Value), just_error(Error), just_stopped() sender factories | |
////////////////////////////////////////////////////////////////// | |
template <class Rcvr, disposition Disp, class T> | |
struct just_operation { | |
void start() noexcept requires (Disp == disposition::value) { | |
rcvr_.set_value(std::move(value_)); | |
} | |
void start() noexcept requires (Disp == disposition::error) { | |
rcvr_.set_error(as_eptr(std::move(value_))); | |
} | |
void start() noexcept requires (Disp == disposition::stopped) { | |
rcvr_.set_stopped(); | |
} | |
Rcvr rcvr_; | |
T value_; | |
}; | |
template <disposition Disp, class T> | |
struct just_sender { | |
using result_t = T; | |
template <class Rcvr> | |
auto connect(Rcvr rcvr) { | |
return just_operation<Rcvr, Disp, T>{rcvr, value_}; | |
} | |
T value_; | |
}; | |
template <class T> | |
inline auto just(T t) { | |
return just_sender<disposition::value, T>{t}; | |
} | |
template <class T> | |
inline auto just_error(T t) { | |
return just_sender<disposition::error, T>{t}; | |
} | |
inline auto just_stopped() { | |
return just_sender<disposition::stopped, none>{}; | |
} | |
/////////////////////////////////////////// | |
// then(Sender, Function) sender adaptor | |
/////////////////////////////////////////// | |
// State shared between a then operation and the receiver it connects to its
// child sender.
template <class Fun, class Rcvr>
struct then_state {
  Rcvr rcvr_;  // the downstream receiver
  Fun fun_;    // the function applied to the child's value
};
template <class Fun, class Rcvr> | |
struct then_receiver { | |
template <class U> | |
void set_value(U&& val) noexcept try { | |
state_.rcvr_.set_value(state_.fun_(std::forward<U>(val))); | |
} catch(...) { | |
state_.rcvr_.set_error(std::current_exception()); | |
} | |
void set_error(std::exception_ptr eptr) noexcept { | |
state_.rcvr_.set_error(std::move(eptr)); | |
} | |
void set_stopped() noexcept { | |
state_.rcvr_.set_stopped(); | |
} | |
decltype(auto) get_env() const noexcept { | |
return state_.rcvr_.get_env(); | |
} | |
then_state<Fun, Rcvr>& state_; | |
}; | |
template <class Rcvr, class CvSndr, class Fun> | |
struct then_operation : private immovable { | |
then_operation(CvSndr&& sndr, Fun fn, Rcvr rcvr) | |
: state_{rcvr, fn} | |
, child_operation_(std::forward<CvSndr>(sndr).connect(then_receiver{state_})) | |
{} | |
void start() noexcept { | |
child_operation_.start(); | |
} | |
then_state<Fun, Rcvr> state_; | |
connect_result_t<CvSndr, then_receiver<Fun, Rcvr>> child_operation_; | |
}; | |
template <class Rcvr, class CvSndr, class Fun> | |
then_operation(CvSndr&&, Fun, Rcvr) -> then_operation<Rcvr, CvSndr, Fun>; | |
template <class Sndr, class Fun> | |
struct then_sender { | |
using result_t = std::invoke_result_t<Fun, sender_result_t<Sndr>>; | |
template <class Rcvr> | |
auto connect(Rcvr rcvr) { | |
return then_operation{sndr_, fun_, rcvr}; | |
} | |
Sndr sndr_; | |
Fun fun_; | |
}; | |
template <class Sndr, class Fun> | |
auto then(Sndr sndr, Fun fun) { | |
return then_sender<Sndr, Fun>{std::move(sndr), fun}; | |
} | |
/////////////////////////////////////////// | |
// when_all(Senders...) sender adaptor | |
/////////////////////////////////////////// | |
// Used to broadcast a stop request from when_all's parent | |
// to when_all's children: | |
struct broadcast_stop_source : private std::stop_source { | |
// Called with the predecessor's stop token (obtained from the | |
// receiver connected to the when_all sender). | |
void register_with(std::stop_token stok) noexcept { | |
on_stopped_.emplace(stok, on_stopped_fn{this}); | |
} | |
void unregister() noexcept { | |
on_stopped_.reset(); | |
} | |
using std::stop_source::request_stop; | |
using std::stop_source::stop_requested; | |
using std::stop_source::get_token; | |
private: | |
struct on_stopped_fn { | |
std::stop_source* self_; | |
void operator()() const noexcept { | |
self_->request_stop(); | |
} | |
}; | |
std::optional<std::stop_callback<on_stopped_fn>> on_stopped_{}; | |
}; | |
// This struct has storage for all the child senders' values and errors. | |
// It also holds metadata used to track which child operations have | |
// completed and how. Additionally, it stores a broadcast_stop_source | |
// to fan out a stop request from the parent to all the children. | |
template <class Rcvr, class... Vals> | |
struct when_all_state { | |
static auto make_result(std::optional<Vals>&... vals) { | |
return std::tuple(*vals...); | |
} | |
void complete() noexcept try { | |
if (0 == --count_) { | |
// unregister the stop callback: | |
ssource_.unregister(); | |
// Complete the parent's receiver according to the disposition. | |
switch(disp_) { | |
case disposition::value: return rcvr_.set_value(std::apply(make_result, vals_)); | |
case disposition::error: return rcvr_.set_error(eptr_); | |
case disposition::stopped: return rcvr_.set_stopped(); | |
} | |
} | |
} catch(...) { | |
rcvr_.set_error(std::current_exception()); | |
} | |
// The disp_ member indicates how the when_all operation should complete. | |
// If any of the child senders complete in error, the when_all operation will | |
// complete with the first such error. Otherwise, if any of the child senders | |
// have stopped early, the when_all operation will complete with "stopped". | |
// Otherwise, all the child operations have completed successfully, so the | |
// when_all operation will complete successfully with a tuple of the child | |
// senders' values. | |
Rcvr rcvr_; | |
broadcast_stop_source ssource_{}; | |
std::atomic<int> count_{sizeof...(Vals)}; // zero when all child operations have completed. | |
std::atomic<disposition> disp_{disposition::value}; | |
std::tuple<std::optional<Vals>...> vals_{}; | |
std::exception_ptr eptr_{}; | |
}; | |
// This is environment for the receivers used to connect when_all's child senders. | |
// It exposes the when_all sender's stop_token, used to communicate stop requests | |
// to the child operations. | |
template <class Rcvr, class... Vals> | |
struct when_all_environment { | |
when_all_state<Rcvr, Vals...>* state_; | |
// respond to the get_stop_token query with the stop token of when_all's operation. | |
auto query(get_stop_token_t) const noexcept -> std::stop_token { | |
return state_->ssource_.get_token(); | |
} | |
// forward all other queries to the parent receiver's environment | |
auto query(auto tag) const noexcept -> decltype(tag(state_->rcvr_.get_env())) { | |
return tag(state_->rcvr_.get_env()); | |
} | |
}; | |
template <std::size_t I, class Rcvr, class... Vals> | |
struct when_all_receiver { | |
template <class U> | |
void set_value(U&& val) noexcept { | |
// Emplace the result of the I-th child operation | |
if (auto eptr = try_emplace(std::get<I>(state_.vals_), std::forward<U>(val))) | |
return set_error(std::move(eptr)); | |
state_.complete(); | |
} | |
void set_error(std::exception_ptr eptr) noexcept { | |
// If there hasn't already been an error in another child... | |
if (disposition::error != state_.disp_.exchange(disposition::error)) { | |
state_.eptr_ = std::move(eptr); // save the error, and | |
state_.ssource_.request_stop(); // ask the other operations to stop early. | |
} | |
state_.complete(); | |
} | |
void set_stopped() noexcept { | |
disposition expected = disposition::value; | |
// Set the disposition to stopped if it isn't already stopped or error: | |
if (state_.disp_.compare_exchange_strong(expected, disposition::stopped)) | |
state_.ssource_.request_stop(); // Ask the other children to stop early | |
state_.complete(); | |
} | |
auto get_env() const noexcept { | |
return when_all_environment<Rcvr, Vals...>{&state_}; | |
} | |
when_all_state<Rcvr, Vals...>& state_; | |
}; | |
template <class Rcvr, class CvSndrs, class Is = _tuple_indices<CvSndrs>> | |
struct when_all_operation; | |
template <class Rcvr, class CvSndrs, std::size_t... Is> | |
struct when_all_operation<Rcvr, CvSndrs, std::index_sequence<Is...>> : private immovable { | |
// Connect all the child senders with receivers that perform bookkeeping for the when_all operation | |
when_all_operation(CvSndrs&& sndrs, Rcvr rcvr) | |
: state_{rcvr} | |
, child_operations_(connect_emplace{std::get<Is>(std::forward<CvSndrs>(sndrs)), child_receiver<Is>{state_}}...) | |
{} | |
void start() noexcept { | |
// Register a stop callback with the receiver's stop token | |
state_.ssource_.register_with(get_stop_token(state_.rcvr_.get_env())); | |
// If stop was already requested on the receiver's stop token, then the callback | |
// has just been executed and state_.ssource_.request_stop() has been called. | |
if (state_.ssource_.stop_requested()) { | |
// No need to start the child operations. Send a stopped signal. | |
state_.rcvr_.set_stopped(); | |
} else { | |
// start all child operations in order: | |
std::apply([](auto&... ops) { (ops.start(), ...); }, child_operations_); | |
} | |
} | |
private: | |
template <std::size_t I> | |
using child_sender = _tuple_element_t<I, CvSndrs>; | |
template <std::size_t I> | |
using child_receiver = when_all_receiver<I, Rcvr, sender_result_t<child_sender<Is>>...>; | |
// The state_ member is where values and errors get written, and where the metadata | |
// lives for the when_all operation. | |
when_all_state<Rcvr, sender_result_t<child_sender<Is>>...> state_; | |
// This tuple holds the operation states for all the child senders. | |
std::tuple<connect_result_t<child_sender<Is>, child_receiver<Is>>...> child_operations_; | |
}; | |
template <class CvSndrs, class Rcvr> | |
when_all_operation(CvSndrs&&, Rcvr) -> when_all_operation<Rcvr, CvSndrs>; | |
// The when_all sender just stores the child senders and publishes | |
// the sender's completion result. | |
template <class... Sndrs> | |
struct when_all_sender { | |
using result_t = std::tuple<sender_result_t<Sndrs>...>; | |
template <class Rcvr> | |
auto connect(Rcvr rcvr) { | |
return when_all_operation{sndrs_, rcvr}; | |
} | |
std::tuple<Sndrs...> sndrs_; | |
}; | |
template <class... Sndrs> | |
auto when_all(Sndrs... sndrs) { | |
return when_all_sender<Sndrs...>{{sndrs...}}; | |
} | |
/////////////////////////////////////////// | |
// run_loop execution context | |
/////////////////////////////////////////// | |
class run_loop : immovable { | |
struct task : private immovable { | |
task* next_ = this; | |
virtual void execute() {} | |
}; | |
template <class Rcvr> | |
struct operation : task { | |
Rcvr rcvr_; | |
run_loop* loop_; | |
operation(Rcvr rcvr, run_loop* loop) | |
: rcvr_(rcvr), loop_(loop) {} | |
void execute() override final { | |
rcvr_.set_value(none{}); | |
} | |
void start() noexcept { | |
// if stop has been requested, don't even enqueue this work | |
if (get_stop_token(rcvr_.get_env()).stop_requested()) | |
rcvr_.set_stopped(); // complete immediately | |
else | |
loop_->push_back_(this); // enqueue the work for execution | |
} | |
}; | |
task head_; | |
task* tail_ = &head_; | |
bool finish_ = false; | |
std::mutex mtx_; | |
std::condition_variable cv_; | |
void push_back_(task* op) { | |
std::unique_lock lk(mtx_); | |
op->next_ = &head_; | |
tail_ = tail_->next_= op; | |
cv_.notify_one(); | |
} | |
task* pop_front_() { | |
std::unique_lock lk(mtx_); | |
if (tail_ == head_.next_) | |
tail_ = &head_; | |
cv_.wait(lk, [this]{ return head_.next_ != &head_ || finish_; }); | |
if (head_.next_ == &head_) | |
return nullptr; | |
return std::exchange(head_.next_, head_.next_->next_); | |
} | |
struct sender { | |
using result_t = none; | |
run_loop* loop_; | |
template <class Rcvr> | |
operation<Rcvr> connect(Rcvr rcvr) { | |
return {rcvr, loop_}; | |
} | |
}; | |
struct scheduler { | |
run_loop* loop_; | |
bool operator==(scheduler const&) const = default; | |
sender schedule() { | |
return {loop_}; | |
} | |
}; | |
public: | |
void run() { | |
while (auto* op = pop_front_()) | |
op->execute(); | |
} | |
scheduler get_scheduler() { | |
return {this}; | |
} | |
void finish() { | |
std::unique_lock lk(mtx_); | |
finish_ = true; | |
cv_.notify_all(); | |
} | |
}; | |
/////////////////////////////////////////// | |
// sync_wait() sender consumer | |
/////////////////////////////////////////// | |
struct sync_wait_state { | |
run_loop loop; | |
std::exception_ptr eptr; | |
}; | |
struct sync_wait_environment { | |
// sync_wait makes a run_loop scheduler available to child | |
// operations for scheduling work on the waiting thread. | |
auto query(get_scheduler_t) const noexcept { | |
return state_->loop.get_scheduler(); | |
} | |
sync_wait_state* state_; | |
}; | |
template <class T> | |
struct sync_wait_receiver { | |
sync_wait_state& state_; | |
std::optional<T>& value_; | |
void set_value(T val) noexcept { | |
if (auto eptr = try_emplace(value_, val)) | |
return set_error(eptr); | |
state_.loop.finish(); // tell the run_loop to wind down | |
} | |
void set_error(std::exception_ptr eptr) noexcept { | |
state_.eptr = eptr; | |
state_.loop.finish(); | |
} | |
void set_stopped() noexcept { | |
state_.loop.finish(); | |
} | |
sync_wait_environment get_env() const noexcept { | |
return {&state_}; | |
} | |
}; | |
template <class Sndr, class T = sender_result_t<Sndr>> | |
std::optional<T> sync_wait(Sndr snd) { | |
sync_wait_state state; | |
std::optional<T> value; | |
auto op = snd.connect(sync_wait_receiver<T>{state, value}); | |
op.start(); // start the async operation | |
state.loop.run(); // drive the run_loop until completion | |
if (state.eptr) | |
std::rethrow_exception(state.eptr); | |
return value; | |
} | |
/////////////////////////////////////////// | |
// thread_context execution context | |
/////////////////////////////////////////// | |
class thread_context : immovable { | |
run_loop loop_; | |
std::thread thread_; | |
public: | |
thread_context() | |
: thread_([this]{ | |
std::printf("worker thread: %s\n", get_thread_id().c_str()); | |
loop_.run(); | |
}) | |
{} | |
void finish() { | |
loop_.finish(); | |
} | |
void join() { | |
thread_.join(); | |
} | |
auto get_scheduler() { | |
return loop_.get_scheduler(); | |
} | |
}; | |
// // | |
// // start test code | |
// // | |
// int main() { | |
// thread_context worker; | |
// std::printf("main thread: %s\n", get_thread_id().c_str()); | |
// auto log = [](auto val) { | |
// std::printf("Running task on thread: %s\n", get_thread_id().c_str()); | |
// return val; | |
// }; | |
// auto task1 = worker.get_scheduler().schedule(); | |
// auto task2 = then(task1, [](auto) { return 42; }); | |
// auto task3 = then(task2, log); | |
// auto task4 = just(42); | |
// auto task5 = then(task4, log); | |
// auto [i,j] = sync_wait(when_all(task3, task5)).value(); | |
// std::printf("result: {%d, %d}\n", i, j); | |
// worker.finish(); | |
// worker.join(); | |
// } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment