Last active
October 8, 2023 23:50
-
-
Save mtao/567130fa412eb5ef13fdc599d16e6d8f to your computer and use it in GitHub Desktop.
Generic caching of multiple return types when switching between compile-time and runtime polymorphism
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstddef>
#include <functional>
#include <map>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

#include <spdlog/spdlog.h>
// Say you have a few heavy data structures derived from a single base class
// that stores a value/enum identifying which derived class an object is
// (e.g. a simplicial complex class with different derived classes for
// dimensions 2, 3 and 4).
// Furthermore, say you have to call some functor on each one, but the functor
// returns a different type for each derived class (e.g. an operation like an
// edge split/collapse).
// Finally, say you need to cache all of these results so someone can replay
// them later.
// This file contains a generic-ish way of handling this mess.
// (Motivating use case: a user wants to identify the "reference vertex" of an
// operation in a particular simplicial complex (SC) in order to update an
// attribute on it.)
// Fallback for the C++20 traits std::unwrap_reference / std::unwrap_ref_decay,
// used only when the standard library does not advertise them through the
// __cpp_lib_unwrap_ref feature-test macro. This follows the "possible
// implementation" from
// https://en.cppreference.com/w/cpp/utility/functional/unwrap_reference
//
//   unwrap_reference<T>::type -> T, or U& when T is reference_wrapper<U>
//   unwrap_ref_decay_t<T>     -> unwrap_reference<decay_t<T>>::type
//
// NOTE(review): adding declarations to namespace std is formally undefined
// behaviour; it is done here so the rest of the file can spell these names the
// same way with or without C++20 library support.
#if !defined(__cpp_lib_unwrap_ref)
namespace std {
template <typename Type>
struct unwrap_reference {
    using type = Type;
};

template <typename Referent>
struct unwrap_reference<reference_wrapper<Referent>> {
    using type = Referent&;
};

template <typename Type>
struct unwrap_ref_decay : unwrap_reference<decay_t<Type>> {};

template <typename Type>
using unwrap_ref_decay_t = typename unwrap_ref_decay<Type>::type;
}  // namespace std
#endif
// Base type shared by every concrete input kind.
//   type - discriminator identifying the derived class (see as_variant)
//   id   - identifier given to the object at construction
// Both default to -1, meaning "not set".
struct Input {
    int type = -1;
    // Defaulted like `type` so a default-initialized Input never holds an
    // indeterminate value.
    int id = -1;
};
// Some example derived types | |
struct A : public Input { | |
A(int id) : Input{0, id} {} | |
}; | |
struct B : public Input { | |
B(int id) : Input{1, id} {} | |
}; | |
struct C : public Input { | |
C(int id) : Input{2, id} {} | |
}; | |
// My target application's "Input" class is quite heavy and the Input objects | |
// persist for long periods of time relative to what this is being used for, so | |
// I want to use a variant of references rather than values | |
// | |
// Here's a helper definition for making variants of references | |
template <typename... T> | |
using ReferenceWrapperVariant = std::variant<std::reference_wrapper<T>...>; | |
// The reference class for this type | |
using InputVariant = ReferenceWrapperVariant<A, B, C>; | |
// Recovers the concrete type of `value` at run time and returns a non-owning
// reference to it, selected by `value.type`: 0 -> A, 1 -> B, 2 -> C.
//
// The downcasts assume `value.type` matches the object's dynamic type (the
// A/B/C constructors set it); if it does not, the static_cast is undefined
// behavior.
//
// Throws the string literal "InvalidInput" (a const char*) for any other type
// value. NOTE(review): a std::exception subclass would be more conventional;
// the thrown type is kept so existing `catch (const char*)` handlers work.
InputVariant as_variant(Input& value) {
    switch (value.type) {
        case 0:
            return std::reference_wrapper(static_cast<A&>(value));
        case 1:
            return std::reference_wrapper(static_cast<B&>(value));
        case 2:
            return std::reference_wrapper(static_cast<C&>(value));
        default:
            break;
    }
    // Throwing after the switch (instead of inside `default`) makes every path
    // visibly return or throw, so no unreachable dummy return (previously a
    // reinterpret_cast) is needed to keep compilers quiet.
    throw "InvalidInput";
}
// A helper class for specifying per-type return types from an input functor | |
// Assumes the argument is the variant type being selected form, all other | |
// arguments are passed in as const references | |
template <typename Functor, typename... Ts> | |
struct ReturnVariantHelper {}; | |
template <typename Functor, typename... VTs, typename... Ts> | |
struct ReturnVariantHelper<Functor, std::variant<VTs...>, Ts...> { | |
// For a specific type in the variant, get the return type | |
template <typename T> | |
using ReturnType = | |
std::decay_t<std::invoke_result_t<Functor, std::unwrap_ref_decay_t<T>&, | |
const Ts&...>>; | |
// Get an overall variant for the types | |
using type = std::variant<ReturnType<VTs>...>; | |
}; | |
// Interface for reading off the return values from data | |
template <typename Functor, typename... OtherArgumentTypes> | |
class ReturnDataStore { | |
public: | |
using TypeHelper = | |
ReturnVariantHelper<Functor, InputVariant, OtherArgumentTypes...>; | |
using ReturnVariant = typename TypeHelper::type; | |
// a pointer to an input and some other arguments | |
using KeyType = std::tuple<const Input*, OtherArgumentTypes...>; | |
auto get_id(const Input& input, const OtherArgumentTypes&... ts) const { | |
// other applications might use a fancier version of get_id | |
return KeyType(&input, ts...); | |
} | |
// Add new data by giving the InputType | |
// InputType is used to make sure the pair of Input/Output is valid and to | |
// extract an id | |
template <typename InputType, typename ReturnType> | |
void add(const InputType& input, ReturnType&& return_data, | |
const OtherArgumentTypes&... args) { | |
using ReturnType_t = std::decay_t<ReturnType>; | |
static_assert(!std::is_same_v<std::decay_t<InputType>, Input>, | |
"Don't pass in a input, use variant/visitor to get its " | |
"derived type"); | |
// if the user passed in a input class lets try re-invoking with a | |
// derived type | |
auto id = get_id(input, args...); | |
using ExpectedReturnType = | |
typename TypeHelper::template ReturnType<InputType>; | |
static_assert(std::is_convertible_v<ReturnType_t, ExpectedReturnType>, | |
"Second argument should be the return value of a Functor " | |
"(or convertible at " | |
"least) "); | |
m_data.emplace(id, | |
ReturnVariant(std::in_place_type_t<ExpectedReturnType>{}, | |
std::forward<ReturnType>(return_data))); | |
} | |
// let user get the variant for a specific Input derivate | |
const auto& get_variant(const Input& input, | |
const OtherArgumentTypes&... ts) const { | |
auto id = get_id(input, ts...); | |
return m_data.at(id); | |
} | |
// get the type specific input | |
template <typename InputType> | |
auto get(const InputType& input, const OtherArgumentTypes&... ts) const { | |
static_assert(!std::is_same_v<std::decay_t<InputType>, Input>, | |
"Don't pass in a input, use variant/visitor to get its " | |
"derived type"); | |
using ExpectedReturnType = | |
typename TypeHelper::template ReturnType<InputType>; | |
return std::get<ExpectedReturnType>(get_variant(input, ts...)); | |
} | |
private: | |
std::map<KeyType, ReturnVariant> m_data; | |
}; | |
template <typename Functor, typename... OtherTypes> | |
class Runner { | |
public: | |
Runner(Functor&& f) : func(f) {} | |
Runner(Functor&& f, std::tuple<OtherTypes...>) : func(f) {} | |
void run(Input& input, const OtherTypes&... ts) { | |
const int id = input.id; | |
auto var = as_variant(input); | |
std::visit( | |
[&](auto& t) { | |
auto& v = t.get(); | |
return_data.add(v, func(v, ts...), ts...); | |
}, | |
var); | |
} | |
ReturnDataStore<Functor, OtherTypes...> return_data; | |
private: | |
const Functor& func; | |
}; | |
template <typename Functor> | |
ReturnDataStore(Functor&& f) -> ReturnDataStore<std::decay_t<Functor>>; | |
template <typename Functor, typename... Ts> | |
Runner(Functor&& f, std::tuple<Ts...>) -> Runner<Functor, std::decay_t<Ts>...>; | |
template <typename Functor> | |
Runner(Functor&& f) -> Runner<Functor>; | |
struct TestFunctor { | |
template <typename T> | |
auto operator()(T& input) const { | |
using TT = std::unwrap_ref_decay_t<T>; | |
return std::tuple<TT, int>(input, input.id); | |
}; | |
}; | |
struct TestFunctor2Args { | |
template <typename T> | |
auto operator()(T& input, int data) const { | |
using TT = std::unwrap_ref_decay_t<T>; | |
return std::tuple<TT, int>(input, input.id * data); | |
}; | |
}; | |
int main(int argc, char* argv[]) { | |
A a(0); | |
B b(2); | |
C c(4); | |
// test calling the functor once | |
{ | |
auto [ap, i] = TestFunctor{}(a); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
// create a mono arg | |
Runner r(TestFunctor{}); | |
r.run(a); | |
r.run(b); | |
r.run(c); | |
{ | |
auto [ap, i] = r.return_data.get(a); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
{ | |
auto [ap, i] = r.return_data.get(b); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
{ | |
auto [ap, i] = r.return_data.get(c); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
// try using 2 args | |
Runner r2(TestFunctor2Args{}, std::tuple<int>{}); | |
r2.run(a, 3); | |
r2.run(b, 5); | |
r2.run(c, 7); | |
{ | |
auto [ap, i] = r2.return_data.get(a, 3); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
{ | |
auto [ap, i] = r2.return_data.get(b, 5); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
{ | |
auto [ap, i] = r2.return_data.get(c, 7); | |
spdlog::info("{},{} = {}", ap.type, ap.id, i); | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment