Skip to content

Instantly share code, notes, and snippets.

@ericniebler
Last active November 12, 2024 21:00
Show Gist options
  • Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
A toy implementation of P2300, the std::execution proposal, for teaching purposes
/*
* Copyright 2022-2024 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This is a toy implementation of the core parts of the C++ std::execution
// proposal (aka, Senders, http://wg21.link/P2300). It is intended as a
// learning tool only. THIS CODE IS NOT SUITABLE FOR ANY USE.
#include <atomic>
#include <concepts>
#include <condition_variable>
#include <cstddef>
#include <cstdio>
#include <exception>
#include <functional>
#include <locale>
#include <mutex>
#include <optional>
#include <sstream>
#include <stop_token>
#include <string>
#include <system_error>
#include <thread>
#include <tuple>
#include <type_traits>
#include <utility>
// Some utility code
///////////////////////////////////////////
// Renders the calling thread's id as text. The stream is imbued with the
// classic "C" locale, then "0x", the std::hex manipulator and the id are
// inserted in that order.
std::string get_thread_id() {
  std::ostringstream formatted;
  formatted.imbue(std::locale::classic());
  formatted << "0x";
  formatted << std::hex;
  formatted << std::this_thread::get_id();
  return formatted.str();
}
// Base/mixin that pins an object in place. Deleting the move constructor also
// leaves the implicit copy operations deleted; default construction stays
// available.
struct immovable {
  immovable(immovable&&) = delete;
  immovable() = default;
};
// Empty placeholder type. In this file it is the result type of senders that
// carry no payload (just_stopped() and run_loop's schedule sender).
struct none {};
// A sink type: constructible from any list of arguments and assignable from
// any single value. Everything it is given is discarded.
struct ignore {
  template <class... Args>
  ignore(Args&&...) noexcept {}

  template <class Arg>
  void operator=(Arg&&) {}
};
// Some tuple utilities used by when_all
// Index sequence 0..N-1, where N is the tuple size of CvTuple after its
// reference and cv-qualifiers have been stripped.
template <class CvTuple>
using _tuple_indices = std::make_index_sequence<
  std::tuple_size<std::remove_cv_t<std::remove_reference_t<CvTuple>>>::value>;
// Yields Element& when AddReference is true and Element unchanged otherwise.
template <bool AddReference, class Element>
using _maybe_add_reference =
  typename std::conditional<AddReference, Element&, Element>::type;
// Type of the I-th element of a (possibly cv/ref-qualified) tuple type. When
// CvTuple is a reference type the element type comes back as an lvalue
// reference; otherwise it is the plain element type.
template <std::size_t I, class CvTuple>
using _tuple_element_t = _maybe_add_reference<
  std::is_reference<CvTuple>::value,
  typename std::tuple_element<I, std::remove_reference_t<CvTuple>>::type>;
// Map error types to an exception_ptr
// Generic case: an arbitrary error object is stored in a new exception_ptr.
template <class Error>
inline std::exception_ptr as_eptr(Error error) noexcept {
  return std::make_exception_ptr(error);
}
// An exception_ptr is already in the target form; hand it back as-is.
inline std::exception_ptr as_eptr(std::exception_ptr eptr) noexcept {
  return eptr;
}
// An error_code is wrapped in a std::system_error carrying that code.
inline std::exception_ptr as_eptr(std::error_code ec) noexcept {
  return std::make_exception_ptr(std::system_error{ec});
}
// Nothrow emplace into an optional
// Emplaces `value` into `opt` without letting an exception escape.
// Returns a null exception_ptr on success; if the construction throws, the
// in-flight exception is captured and returned instead.
template <class T, class U = T>
inline std::exception_ptr try_emplace(std::optional<T>& opt, U&& value) noexcept {
  try {
    opt.emplace(std::forward<U>(value));
  } catch (...) {
    return std::current_exception();
  }
  return std::exception_ptr();
}
// In this toy implementation, a sender can only complete with a single value.
// The value type a sender completes with: its nested `result_t`, looked up
// after any reference has been stripped from CvSndr.
template <class CvSndr>
using sender_result_t = typename std::remove_reference<CvSndr>::type::result_t;
// The operation-state type produced by calling `.connect(rcvr)` on a sender
// of type CvSndr (value category preserved) with a receiver of type Rcvr.
template <class CvSndr, class Rcvr>
using connect_result_t =
  decltype(std::declval<CvSndr&&>().connect(std::declval<Rcvr&&>()));
// Most operation states are immovable, so putting them in optionals and tuples
// is tricky. We use connect_emplace together with C++17's guaranteed copy
// elision to emplace the opstates in the containers via a conversion operator.
template <class CvSndr, class Rcvr>
struct connect_emplace {
  CvSndr&& sndr_;  // the sender to connect; only referenced, never copied here
  Rcvr rcvr_;      // the receiver handed to connect()

  // Converting to the operation-state type performs the actual connect. The
  // prvalue returned by connect() initializes the destination object directly.
  operator connect_result_t<CvSndr, Rcvr>() {
    return static_cast<CvSndr&&>(sndr_).connect(rcvr_);
  }
};
// Deduce CvSndr as a forwarding reference so the sender's value category is
// carried into the connect() call.
template <class CvSndr, class Rcvr>
connect_emplace(CvSndr&&, Rcvr) -> connect_emplace<CvSndr, Rcvr>;
// The three ways an operation can complete:
//   value   - completion is delivered through set_value
//   error   - completion is delivered through set_error
//   stopped - completion is delivered through set_stopped
enum class disposition { value, error, stopped };
///////////////////////////////////////////
// environment queries
///////////////////////////////////////////
constexpr struct get_stop_token_t {
auto operator()(const auto& env) const noexcept -> decltype(env.query(*this)) {
return env.query(*this);
}
auto operator()(ignore) const noexcept {
return std::stop_token();
}
} get_stop_token{};
// Query object: get_scheduler(env) forwards to env.query(get_scheduler_t{}).
// There is no fallback; the call is only viable for environments that answer
// the query.
struct get_scheduler_t {
  template <class Env>
  auto operator()(const Env& env) const noexcept -> decltype(env.query(*this)) {
    return env.query(*this);
  }
};
constexpr get_scheduler_t get_scheduler{};
//////////////////////////////////////////////////////////////////
// just(Value), just_error(Error), just_stopped() sender factories
//////////////////////////////////////////////////////////////////
// Operation state for the just* senders. Starting it completes the receiver
// immediately, through the channel selected at compile time by Disp.
template <class Rcvr, disposition Disp, class T>
struct just_operation {
  void start() noexcept {
    if constexpr (Disp == disposition::value) {
      // Hand the stored value to the receiver.
      rcvr_.set_value(std::move(value_));
    } else if constexpr (Disp == disposition::error) {
      // The stored object is the error; convert it to an exception_ptr first.
      rcvr_.set_error(as_eptr(std::move(value_)));
    } else {
      static_assert(Disp == disposition::stopped);
      rcvr_.set_stopped();
    }
  }
  Rcvr rcvr_;
  T value_;
};
// Sender that stores one object and, once connected and started, completes
// through the channel named by Disp.
template <disposition Disp, class T>
struct just_sender {
  using result_t = T;

  // Copies the receiver and the stored object into a new operation state.
  template <class Rcvr>
  auto connect(Rcvr rcvr) {
    using operation_t = just_operation<Rcvr, Disp, T>;
    return operation_t{rcvr, value_};
  }

  T value_;
};
template <class T>
inline auto just(T t) {
return just_sender<disposition::value, T>{t};
}
template <class T>
inline auto just_error(T t) {
return just_sender<disposition::error, T>{t};
}
inline auto just_stopped() {
return just_sender<disposition::stopped, none>{};
}
///////////////////////////////////////////
// then(Sender, Function) sender adaptor
///////////////////////////////////////////
// State shared between a then_operation and the then_receiver it hands to its
// child operation: the downstream receiver and the function to apply.
// Members are initialized positionally (receiver first, then function).
template <class Fun, class Rcvr>
struct then_state {
Rcvr rcvr_; // downstream receiver that gets the transformed result
Fun fun_; // callable invoked with the child's value
};
// Receiver connected to then()'s child sender. It refers to the state owned
// by the enclosing then_operation.
template <class Fun, class Rcvr>
struct then_receiver {
  // Apply the function to the child's value and pass the result downstream.
  // Any exception thrown along the way becomes a downstream error.
  template <class U>
  void set_value(U&& val) noexcept {
    try {
      state_.rcvr_.set_value(state_.fun_(std::forward<U>(val)));
    } catch (...) {
      state_.rcvr_.set_error(std::current_exception());
    }
  }

  // Errors pass through untouched.
  void set_error(std::exception_ptr eptr) noexcept {
    state_.rcvr_.set_error(std::move(eptr));
  }

  // So does a stopped completion.
  void set_stopped() noexcept {
    state_.rcvr_.set_stopped();
  }

  // The child sees the downstream receiver's environment.
  decltype(auto) get_env() const noexcept {
    return state_.rcvr_.get_env();
  }

  then_state<Fun, Rcvr>& state_;
};
template <class Rcvr, class CvSndr, class Fun>
struct then_operation : private immovable {
then_operation(CvSndr&& sndr, Fun fn, Rcvr rcvr)
: state_{rcvr, fn}
, child_operation_(std::forward<CvSndr>(sndr).connect(then_receiver{state_}))
{}
void start() noexcept {
child_operation_.start();
}
then_state<Fun, Rcvr> state_;
connect_result_t<CvSndr, then_receiver<Fun, Rcvr>> child_operation_;
};
template <class Rcvr, class CvSndr, class Fun>
then_operation(CvSndr&&, Fun, Rcvr) -> then_operation<Rcvr, CvSndr, Fun>;
template <class Sndr, class Fun>
struct then_sender {
using result_t = std::invoke_result_t<Fun, sender_result_t<Sndr>>;
template <class Rcvr>
auto connect(Rcvr rcvr) {
return then_operation{sndr_, fun_, rcvr};
}
Sndr sndr_;
Fun fun_;
};
template <class Sndr, class Fun>
auto then(Sndr sndr, Fun fun) {
return then_sender<Sndr, Fun>{std::move(sndr), fun};
}
///////////////////////////////////////////
// when_all(Senders...) sender adaptor
///////////////////////////////////////////
// Used to broadcast a stop request from when_all's parent
// to when_all's children:
struct broadcast_stop_source : private std::stop_source {
// Called with the predecessor's stop token (obtained from the
// receiver connected to the when_all sender).
void register_with(std::stop_token stok) noexcept {
on_stopped_.emplace(stok, on_stopped_fn{this});
}
void unregister() noexcept {
on_stopped_.reset();
}
using std::stop_source::request_stop;
using std::stop_source::stop_requested;
using std::stop_source::get_token;
private:
struct on_stopped_fn {
std::stop_source* self_;
void operator()() const noexcept {
self_->request_stop();
}
};
std::optional<std::stop_callback<on_stopped_fn>> on_stopped_{};
};
// This struct has storage for all the child senders' values and errors.
// It also holds metadata used to track which child operations have
// completed and how. Additionally, it stores a broadcast_stop_source
// to fan out a stop request from the parent to all the children.
template <class Rcvr, class... Vals>
struct when_all_state {
  // Builds the tuple delivered to the parent receiver from the children's
  // stored values.
  //
  // The tuple type is spelled out instead of relying on class template
  // argument deduction: with a single child whose value is itself a tuple or
  // pair, `std::tuple(x)` would select the copy/pair deduction candidate and
  // yield the wrong type (not tuple<Vals...>). The values are moved out
  // because the state is read exactly once, at completion.
  static auto make_result(std::optional<Vals>&... vals) {
    return std::tuple<Vals...>(std::move(*vals)...);
  }

  // Called once by every child receiver after it has recorded its outcome.
  // The call that brings the counter to zero completes the parent receiver.
  // An exception thrown while assembling the result is turned into an error
  // completion.
  void complete() noexcept try {
    if (0 == --count_) {
      // unregister the stop callback:
      ssource_.unregister();
      // Complete the parent's receiver according to the disposition.
      switch(disp_) {
      case disposition::value: return rcvr_.set_value(std::apply(make_result, vals_));
      case disposition::error: return rcvr_.set_error(eptr_);
      case disposition::stopped: return rcvr_.set_stopped();
      }
    }
  } catch(...) {
    rcvr_.set_error(std::current_exception());
  }

  // The disp_ member indicates how the when_all operation should complete.
  // If any of the child senders complete in error, the when_all operation will
  // complete with the first such error. Otherwise, if any of the child senders
  // have stopped early, the when_all operation will complete with "stopped".
  // Otherwise, all the child operations have completed successfully, so the
  // when_all operation will complete successfully with a tuple of the child
  // senders' values.
  Rcvr rcvr_;
  broadcast_stop_source ssource_{};
  std::atomic<int> count_{sizeof...(Vals)}; // zero when all child operations have completed.
  std::atomic<disposition> disp_{disposition::value};
  std::tuple<std::optional<Vals>...> vals_{}; // one slot per child, filled by set_value
  std::exception_ptr eptr_{};                 // the error reported when disp_ is error
};
// This is the environment for the receivers used to connect when_all's child senders.
// It exposes the when_all sender's stop_token, used to communicate stop requests
// to the child operations.
template <class Rcvr, class... Vals>
struct when_all_environment {
when_all_state<Rcvr, Vals...>* state_;
// respond to the get_stop_token query with the stop token of when_all's operation.
auto query(get_stop_token_t) const noexcept -> std::stop_token {
return state_->ssource_.get_token();
}
// forward all other queries to the parent receiver's environment
auto query(auto tag) const noexcept -> decltype(tag(state_->rcvr_.get_env())) {
return tag(state_->rcvr_.get_env());
}
};
template <std::size_t I, class Rcvr, class... Vals>
struct when_all_receiver {
template <class U>
void set_value(U&& val) noexcept {
// Emplace the result of the I-th child operation
if (auto eptr = try_emplace(std::get<I>(state_.vals_), std::forward<U>(val)))
return set_error(std::move(eptr));
state_.complete();
}
void set_error(std::exception_ptr eptr) noexcept {
// If there hasn't already been an error in another child...
if (disposition::error != state_.disp_.exchange(disposition::error)) {
state_.eptr_ = std::move(eptr); // save the error, and
state_.ssource_.request_stop(); // ask the other operations to stop early.
}
state_.complete();
}
void set_stopped() noexcept {
disposition expected = disposition::value;
// Set the disposition to stopped if it isn't already stopped or error:
if (state_.disp_.compare_exchange_strong(expected, disposition::stopped))
state_.ssource_.request_stop(); // Ask the other children to stop early
state_.complete();
}
auto get_env() const noexcept {
return when_all_environment<Rcvr, Vals...>{&state_};
}
when_all_state<Rcvr, Vals...>& state_;
};
// Primary template: by default the index pack has one index per child sender.
template <class Rcvr, class CvSndrs, class Is = _tuple_indices<CvSndrs>>
struct when_all_operation;
// Operation state for when_all: owns the shared bookkeeping state and one
// child operation per child sender.
template <class Rcvr, class CvSndrs, std::size_t... Is>
struct when_all_operation<Rcvr, CvSndrs, std::index_sequence<Is...>> : private immovable {
  // Connect all the child senders with receivers that perform bookkeeping for the when_all operation.
  // The child operations are immovable, so they are emplaced into the tuple
  // through connect_emplace's conversion operator.
  when_all_operation(CvSndrs&& sndrs, Rcvr rcvr)
    : state_{rcvr}
    , child_operations_(connect_emplace{std::get<Is>(std::forward<CvSndrs>(sndrs)), child_receiver<Is>{state_}}...)
  {}
  void start() noexcept {
    // Register a stop callback with the receiver's stop token
    state_.ssource_.register_with(get_stop_token(state_.rcvr_.get_env()));
    // If stop was already requested on the receiver's stop token, then the callback
    // has just been executed and state_.ssource_.request_stop() has been called.
    if (state_.ssource_.stop_requested()) {
      // No need to start the child operations. Drop the stop callback before
      // completing (as the normal completion path does) and send a stopped signal.
      state_.ssource_.unregister();
      state_.rcvr_.set_stopped();
    } else if constexpr (sizeof...(Is) == 0) {
      // With no children nobody would ever report completion, so the
      // operation would never finish. Complete right away with an empty tuple.
      state_.ssource_.unregister();
      state_.rcvr_.set_value(std::tuple<>{});
    } else {
      // start all child operations in order:
      std::apply([](auto&... ops) { (ops.start(), ...); }, child_operations_);
    }
  }
private:
  // Type of the I-th child sender, with the value category of CvSndrs applied.
  template <std::size_t I>
  using child_sender = _tuple_element_t<I, CvSndrs>;
  // Receiver type connected to the I-th child sender.
  template <std::size_t I>
  using child_receiver = when_all_receiver<I, Rcvr, sender_result_t<child_sender<Is>>...>;
  // The state_ member is where values and errors get written, and where the metadata
  // lives for the when_all operation. It is declared before child_operations_ because
  // the child receivers refer to it.
  when_all_state<Rcvr, sender_result_t<child_sender<Is>>...> state_;
  // This tuple holds the operation states for all the child senders.
  std::tuple<connect_result_t<child_sender<Is>, child_receiver<Is>>...> child_operations_;
};
// Deduce CvSndrs as a forwarding reference so the tuple's value category is kept.
template <class CvSndrs, class Rcvr>
when_all_operation(CvSndrs&&, Rcvr) -> when_all_operation<Rcvr, CvSndrs>;
// The when_all sender just stores the child senders and publishes
// the sender's completion result.
template <class... Sndrs>
struct when_all_sender {
using result_t = std::tuple<sender_result_t<Sndrs>...>;
template <class Rcvr>
auto connect(Rcvr rcvr) {
return when_all_operation{sndrs_, rcvr};
}
std::tuple<Sndrs...> sndrs_;
};
template <class... Sndrs>
auto when_all(Sndrs... sndrs) {
return when_all_sender<Sndrs...>{{sndrs...}};
}
///////////////////////////////////////////
// run_loop execution context
///////////////////////////////////////////
// A mutex-protected FIFO queue of tasks that doubles as an execution context.
// Senders obtained from its scheduler enqueue their operation when started,
// and run() executes queued operations on whichever thread calls it.
class run_loop : immovable {
  // Node of an intrusive, circular, singly-linked list. head_ is the sentinel;
  // a freshly constructed task links to itself.
  struct task : private immovable {
    task* next_ = this;
    virtual void execute() {}
  };
  // Operation state produced by connecting the schedule sender to a receiver.
  template <class Rcvr>
  struct operation : task {
    Rcvr rcvr_;
    run_loop* loop_;
    operation(Rcvr rcvr, run_loop* loop)
      : rcvr_(rcvr), loop_(loop) {}
    // Invoked by run() once this operation reaches the front of the queue.
    void execute() override final {
      rcvr_.set_value(none{});
    }
    void start() noexcept {
      // if stop has been requested, don't even enqueue this work
      if (get_stop_token(rcvr_.get_env()).stop_requested())
        rcvr_.set_stopped(); // complete immediately
      else
        loop_->push_back_(this); // enqueue the work for execution
    }
  };
  task head_;            // sentinel: head_.next_ is the front of the queue
  task* tail_ = &head_;  // last queued task, or &head_ when the queue is empty
  bool finish_ = false;  // set by finish(); lets run() return once drained
  std::mutex mtx_;
  std::condition_variable cv_;
  // Appends `op` to the queue and wakes one waiting consumer.
  void push_back_(task* op) {
    std::unique_lock lk(mtx_);
    op->next_ = &head_;
    tail_ = tail_->next_= op;
    cv_.notify_one();
  }
  // Blocks until a task is available or finish() has been called. Returns the
  // front task, or nullptr when the queue is empty and finish() was called.
  task* pop_front_() {
    std::unique_lock lk(mtx_);
    cv_.wait(lk, [this]{ return head_.next_ != &head_ || finish_; });
    if (head_.next_ == &head_)
      return nullptr;
    // The tail must be reset AFTER the wait, when we know which task is being
    // removed. Doing it before the wait left tail_ pointing at the popped task
    // whenever the consumer had blocked on an empty queue; the next push then
    // wrote through that stale pointer and the new task was never linked in.
    if (tail_ == head_.next_)
      tail_ = &head_;
    return std::exchange(head_.next_, head_.next_->next_);
  }
  // Sender returned by scheduler::schedule(); completes with `none` from
  // within run().
  struct sender {
    using result_t = none;
    run_loop* loop_;
    template <class Rcvr>
    operation<Rcvr> connect(Rcvr rcvr) {
      return {rcvr, loop_};
    }
  };
  // Lightweight handle to the loop; two schedulers compare equal when they
  // refer to the same run_loop.
  struct scheduler {
    run_loop* loop_;
    bool operator==(scheduler const&) const = default;
    sender schedule() {
      return {loop_};
    }
  };
public:
  // Executes queued tasks until finish() has been called and the queue is empty.
  void run() {
    while (auto* op = pop_front_())
      op->execute();
  }
  scheduler get_scheduler() {
    return {this};
  }
  // Asks run() to return once the remaining tasks have been executed.
  void finish() {
    std::unique_lock lk(mtx_);
    finish_ = true;
    cv_.notify_all();
  }
};
///////////////////////////////////////////
// sync_wait() sender consumer
///////////////////////////////////////////
// State shared between sync_wait and the receiver it connects.
struct sync_wait_state {
run_loop loop; // driven by sync_wait on the calling thread until the receiver calls finish()
std::exception_ptr eptr; // written by sync_wait_receiver::set_error; rethrown by sync_wait when non-null
};
struct sync_wait_environment {
// sync_wait makes a run_loop scheduler available to child
// operations for scheduling work on the waiting thread.
auto query(get_scheduler_t) const noexcept {
return state_->loop.get_scheduler();
}
sync_wait_state* state_;
};
// Receiver used by sync_wait. It writes the outcome into storage owned by the
// sync_wait call and then tells the run_loop to wind down.
template <class T>
struct sync_wait_receiver {
  sync_wait_state& state_;
  std::optional<T>& value_;

  // Store the result; if storing it throws, report that as an error instead.
  void set_value(T val) noexcept {
    std::exception_ptr failure = try_emplace(value_, val);
    if (failure) {
      set_error(failure);
      return;
    }
    state_.loop.finish();
  }

  // Remember the error so sync_wait can rethrow it.
  void set_error(std::exception_ptr eptr) noexcept {
    state_.eptr = eptr;
    state_.loop.finish();
  }

  // Nothing to record: the result simply stays empty.
  void set_stopped() noexcept {
    state_.loop.finish();
  }

  sync_wait_environment get_env() const noexcept {
    return sync_wait_environment{&state_};
  }
};
template <class Sndr, class T = sender_result_t<Sndr>>
std::optional<T> sync_wait(Sndr snd) {
sync_wait_state state;
std::optional<T> value;
auto op = snd.connect(sync_wait_receiver<T>{state, value});
op.start(); // start the async operation
state.loop.run(); // drive the run_loop until completion
if (state.eptr)
std::rethrow_exception(state.eptr);
return value;
}
///////////////////////////////////////////
// thread_context execution context
///////////////////////////////////////////
// An execution context backed by one worker thread that drives a run_loop.
class thread_context : immovable {
  run_loop loop_;       // declared before thread_ so it exists when the thread starts
  std::thread thread_;
public:
  // Starts the worker thread, which prints its id and then runs the loop.
  thread_context()
    : thread_([this]{
        std::printf("worker thread: %s\n", get_thread_id().c_str());
        loop_.run();
      })
  {}
  // Winds the loop down and waits for the worker if the owner has not done so
  // already. Without this, destroying the context while thread_ is still
  // joinable would call std::terminate. After an explicit finish() + join()
  // this only repeats the finish() and skips the join.
  ~thread_context() {
    loop_.finish();
    if (thread_.joinable())
      thread_.join();
  }
  // Asks the worker's loop to return once its queued work has been executed.
  void finish() {
    loop_.finish();
  }
  // Waits for the worker thread to exit.
  void join() {
    thread_.join();
  }
  // Scheduler whose work runs on the worker thread.
  auto get_scheduler() {
    return loop_.get_scheduler();
  }
};
// //
// // start test code
// //
// int main() {
// thread_context worker;
// std::printf("main thread: %s\n", get_thread_id().c_str());
// auto log = [](auto val) {
// std::printf("Running task on thread: %s\n", get_thread_id().c_str());
// return val;
// };
// auto task1 = worker.get_scheduler().schedule();
// auto task2 = then(task1, [](auto) { return 42; });
// auto task3 = then(task2, log);
// auto task4 = just(42);
// auto task5 = then(task4, log);
// auto [i,j] = sync_wait(when_all(task3, task5)).value();
// std::printf("result: {%d, %d}\n", i, j);
// worker.finish();
// worker.join();
// }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment