Skip to content

Instantly share code, notes, and snippets.

@vittorioromeo
Created March 24, 2017 14:10
Show Gist options
  • Save vittorioromeo/17cfd79d9b832e9748a7fa41b89b5678 to your computer and use it in GitHub Desktop.
Save vittorioromeo/17cfd79d9b832e9748a7fa41b89b5678 to your computer and use it in GitHub Desktop.
#include <cassert>
#include <iostream>
#include <memory>
#include <new>
#include <tuple>
#include <type_traits>
#include <utility>
using namespace std;
/// An overload set built by inheriting from every callable in `Ts...`.
///
/// Each `Ts` is a base class; `using Ts::operator()...` brings all of their
/// call operators into one scope so ordinary overload resolution picks among
/// them. `unoverload_into` moves each base subobject back out into
/// caller-provided objects.
template <typename... Ts>
struct overloader : Ts...
{
    /// Initializes each base `Ts` from the matching argument in `xs...`.
    ///
    /// Constrained to exactly one argument per base. Without the constraint
    /// this forwarding constructor is a better match than the implicit copy
    /// constructor for a non-const lvalue `overloader`, and copying an
    /// `overloader` with two or more bases then failed to compile
    /// (mismatched pack lengths in the member-initializer).
    template <typename... TArgs,
              typename = enable_if_t<sizeof...(TArgs) == sizeof...(Ts)>>
    constexpr overloader(TArgs&&... xs) : Ts{forward<TArgs>(xs)}...
    {
    }

    using Ts::operator()...;

    /// Moves each base subobject into the object pointed to by the matching
    /// pointer in `targets` (one pointer per base, in `Ts...` order).
    /// Afterwards the bases of `*this` are in a moved-from state.
    constexpr void unoverload_into(Ts*... targets)
    {
        // Each target is handled independently and in order, so one
        // non-assignable type no longer forces destroy/reconstruct on all.
        (move_base_into<Ts>(targets), ...);
    }

private:
    // Transfers the `T` base of `*this` into `*target`: by move assignment
    // when `T` supports it, otherwise (e.g. C++17 closure types, whose copy
    // assignment is deleted) by destroying `*target` and move-constructing a
    // new `T` in its storage.
    template <typename T>
    constexpr void move_base_into(T* target)
    {
        if constexpr(is_move_assignable_v<T>)
        {
            *target = static_cast<T&&>(*this);
        }
        else
        {
            // NOTE(review): if T's move constructor throws, *target is left
            // destroyed. If T has const or reference members (closures that
            // capture by reference may), the caller's original name for
            // *target may need std::launder before reuse — confirm.
            target->~T();
            // Global placement new on void* bypasses any class-specific
            // operator new that T might declare.
            ::new (static_cast<void*>(target)) T(static_cast<T&&>(*this));
        }
    }
};
template <typename... Ts>
constexpr auto overload(Ts&&... xs)
{
return overloader<decay_t<Ts>...>{forward<Ts>(xs)...};
}
/// Calls a set of existing callables as if they formed one overload set,
/// without taking ownership of them.
///
/// Only pointers to the original callables are stored. Each call moves the
/// originals into a temporary `overloader` (via `overload`), dispatches the
/// call through ordinary overload resolution, then moves the possibly
/// mutated state back into the originals via `overloader::unoverload_into`.
///
/// NOTE(review): the referenced callables must outlive this object.
/// NOTE(review): if the dispatched call throws, the state is not moved back
/// and the originals stay in their moved-from state. A RAII guard or
/// try/catch would fix this, but neither is permitted in a C++17 constexpr
/// function, and this type is used in constant expressions.
template <typename... Ts>
class temporary_overloader
{
private:
    // Non-owning pointers to the caller's callables, in `Ts...` order.
    tuple<Ts*...> _original_fns;

    // Invokes `f` with the stored pointers unpacked as separate arguments.
    template <typename TF>
    constexpr decltype(auto) with_original_fns(TF&& f)
    {
        return apply(forward<TF>(f), _original_fns);
    }

    // Moves the state held by the temporary overload set `o` back into the
    // original callables.
    template <typename TO>
    constexpr void restore_originals(TO& o)
    {
        with_original_fns([&o](auto&&... ys){ o.unoverload_into(ys...); });
    }

    // Moves the originals into a temporary overloader, performs one call
    // through it, then restores the originals before returning the result.
    template <typename... TArgs>
    constexpr decltype(auto) call_as_if_overloaded(TArgs&&... xs)
    {
        auto o = with_original_fns([](auto&&... ys){ return overload(move(*ys)...); });
        using result_type = decltype(o(forward<TArgs>(xs)...));

        if constexpr(is_void_v<result_type>)
        {
            o(forward<TArgs>(xs)...);
            restore_originals(o);
            return;
        }
        else if constexpr(is_reference_v<result_type>)
        {
            // Re-cast so an rvalue-reference result stays an rvalue: a plain
            // `return result;` names an lvalue and does not compile for `T&&`
            // in C++17.
            // NOTE(review): a reference into the callable's own state points
            // into the local `o` and dangles once this function returns.
            result_type result = o(forward<TArgs>(xs)...);
            restore_originals(o);
            return static_cast<result_type>(result);
        }
        else
        {
            result_type result = o(forward<TArgs>(xs)...);
            restore_originals(o);
            return result;
        }
    }

public:
    /// Records the address of each callable; nothing is copied or moved.
    /// `std::addressof` is used so callables that overload unary `&` work.
    constexpr temporary_overloader(Ts&... fns) : _original_fns{std::addressof(fns)...}
    {
    }

    /// Calls the overload set formed by the referenced callables.
    template <typename... TArgs>
    constexpr decltype(auto) operator()(TArgs&&... xs)
    {
        return call_as_if_overloaded(forward<TArgs>(xs)...);
    }
};
template <typename... Ts>
constexpr auto ref_overload(Ts&&... xs)
{
return temporary_overloader<decay_t<Ts>...>{xs...};
}
int main()
{
    // Two stateful callables. Each prints a tag and returns whether it had
    // already been invoked before this call (false the first time).
    auto l0 = [called = false](char) mutable {
        cout << "CHAR";
        return exchange(called, true);
    };
    auto l1 = [called = false](int) mutable {
        cout << "INT";
        return exchange(called, true);
    };

    // `dispatcher` refers to l0 and l1 instead of owning copies of them.
    // Written out by hand, the two calls below amount to:
    //
    //   auto manual = overload(l0, l1);
    //   manual('a');                       // prints "CHAR"
    //   manual(0);                         // prints "INT"
    //   manual.unoverload_into(&l0, &l1);  // move state back into l0/l1
    auto dispatcher = ref_overload(l0, l1);
    dispatcher('a'); // resolves to l0, prints "CHAR"
    dispatcher(0);   // resolves to l1, prints "INT"

    // The flags set through `dispatcher` were written back into l0 and l1.
    assert(l0('a') == true);
    assert(l1(0) == true);

    // The same round trip inside a constant expression, with a stateless
    // callable.
    struct test { constexpr int operator()(int){ return 1; } };
    test t;
    static_assert(ref_overload(t)(0) == 1);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment