Last active
February 27, 2016 22:17
-
-
Save Bananattack/1e2d3bbbf80f9ab63779 to your computer and use it in GitHub Desktop.
A C++14 single-header variant type. Allows defining type-safe tagged unions with pattern matching/visiting. Probably has some implementation mistakes. Might use this for wiz, my high-level assembly language project.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef WIZ_VARIANT_H | |
#define WIZ_VARIANT_H | |
#include <cassert>
#include <cstddef>
#include <new>
#include <type_traits>
#include <utility>
namespace wiz {

// Compile-time maximum of a list of sizes (0 for an empty list).
// Used below to size and align the variant's raw storage.
template<std::size_t... Ns>
struct max_value;

template<std::size_t N, std::size_t... Ns>
struct max_value<N, Ns...> {
    static const std::size_t value = max_value<Ns...>::value > N
        ? max_value<Ns...>::value
        : N;
};

template<>
struct max_value<> {
    static const std::size_t value = 0;
};

// Zero-based index of the first occurrence of U in Ts..., or -1 if U is not listed.
template<typename U, typename... Ts>
struct type_tag;

template<typename U, typename T, typename... Ts>
struct type_tag<U, T, Ts...> {
    // Propagate "not found" (-1) unchanged; otherwise shift the index by one.
    static const int value = type_tag<U, Ts...>::value >= 0
        ? type_tag<U, Ts...>::value + 1
        : -1;
};

template<typename U, typename... Ts>
struct type_tag<U, U, Ts...> {
    static const int value = 0;
};

template<typename U>
struct type_tag<U> {
    static const int value = -1;
};

// Maps a runtime tag to the alternative type at that index by walking Ts...
// and comparing `tag` against the compile-time position N of each type.
template<int N, typename... Ts>
struct tag_dispatcher_;

template<int N, typename T, typename... Ts>
struct tag_dispatcher_<N, T, Ts...> {
    // Copy-constructs the tagged alternative from `src` into raw storage `dest`.
    static void copy(int tag, void* dest, const void* src) {
        if(tag == N) {
            // ::new bypasses any class-specific operator new that T might declare.
            ::new (dest) T(*static_cast<const T*>(src));
        } else {
            tag_dispatcher_<N + 1, Ts...>::copy(tag, dest, src);
        }
    }

    // Move-constructs the tagged alternative from `src` into raw storage `dest`.
    static void move(int tag, void* dest, void* src) {
        if(tag == N) {
            ::new (dest) T(std::move(*static_cast<T*>(src)));
        } else {
            tag_dispatcher_<N + 1, Ts...>::move(tag, dest, src);
        }
    }

    // Runs the destructor of the tagged alternative living at `data`.
    static void destroy(int tag, void* data) {
        if(tag == N) {
            static_cast<T*>(data)->~T();
        } else {
            tag_dispatcher_<N + 1, Ts...>::destroy(tag, data);
        }
    }

    // Invokes `f` with a const reference to the tagged alternative and returns its result as R.
    template<typename R, typename F>
    static R apply(int tag, F&& f, const void* data) {
        if(tag == N) {
            return std::forward<F>(f)(*static_cast<const T*>(data));
        } else {
            return tag_dispatcher_<N + 1, Ts...>::template apply<R, F>(tag, std::forward<F>(f), data);
        }
    }
};

// End of the type list: reached only when `tag` matched no alternative.
template<int N>
struct tag_dispatcher_<N> {
    static void copy(int, void*, const void*) {}
    static void move(int, void*, void*) {}
    static void destroy(int, void*) {}

    template<typename R, typename F>
    static R apply(int, F&&, const void*) {
        return R();
    }
};

template<typename... Ts>
using tag_dispatcher = tag_dispatcher_<0, Ts...>;

// Combines several callables into one function object whose operator() overload
// set is the union of theirs. Each wrapped operator() is SFINAE-constrained on its
// trailing return type, so only callables invocable with the arguments participate.
template<typename... Fs>
struct overload;

template<typename F>
struct overload<F> {
public:
    overload(F&& f) : fn(std::forward<F>(f)) {}

    // The SFINAE check uses `const F&` because the call below goes through a const
    // member; checking a non-const F could accept callables that then fail to compile.
    template<typename... Ts>
    auto operator()(Ts&&... args) const
        -> decltype(std::declval<const F&>()(std::forward<Ts>(args)...)) {
        return fn(std::forward<Ts>(args)...);
    }

private:
    F fn;
};

template<typename F, typename... Fs>
struct overload<F, Fs...> : overload<F>, overload<Fs...> {
    using overload<F>::operator();
    using overload<Fs...>::operator();

    overload(F&& f, Fs&&... fs) :
        overload<F>(std::forward<F>(f)),
        overload<Fs...>(std::forward<Fs>(fs)...) {}
};

// Tagged union over Ts...: `tag` is the type_tag index of the active alternative and
// `data` is raw storage sized and aligned for the largest alternative. A variant always
// holds a value (no default constructor); values are copied in via the converting constructor.
template<typename... Ts>
class variant {
public:
    variant() = delete;

    template<typename U>
    variant(const U& value)
        : tag(type_tag<U, Ts...>::value) {
        static_assert_valid_type<U>();
        ::new (&data) U(value);
    }

    variant(const variant& other) : tag(other.tag) {
        tag_dispatcher<Ts...>::copy(tag, &data, &other.data);
    }

    variant(variant&& other) : tag(other.tag) {
        tag_dispatcher<Ts...>::move(tag, &data, &other.data);
    }

    ~variant() {
        tag_dispatcher<Ts...>::destroy(tag, &data);
    }

    // Copies `other` into a temporary first, so a throwing copy constructor leaves
    // *this untouched. Self-assignment is a no-op; the old code destroyed the payload
    // and then copied from the destroyed object.
    variant& operator =(const variant& other) {
        if(this != &other) {
            variant temp(other);
            tag_dispatcher<Ts...>::destroy(tag, &data);
            tag = temp.tag;
            tag_dispatcher<Ts...>::move(tag, &data, &temp.data);
        }
        return *this;
    }

    // Self-move-assignment is a no-op (otherwise we would move from a destroyed object).
    variant& operator =(variant&& other) {
        if(this != &other) {
            tag_dispatcher<Ts...>::destroy(tag, &data);
            tag = other.tag;
            tag_dispatcher<Ts...>::move(tag, &data, &other.data);
        }
        return *this;
    }

    // Index of the active alternative within Ts...
    int which() const { return tag; }

    // True if the active alternative is U; U must be one of Ts... (compile-time check).
    template<typename U>
    bool is() const {
        static_assert_valid_type<U>();
        return type_tag<U, Ts...>::value == tag;
    }

    // Unchecked in release builds: the caller must ensure is<U>() holds.
    template<typename U>
    U& get() {
        static_assert_valid_type<U>();
        assert(is<U>() && "variant::get called with a type that is not the active alternative");
        return *reinterpret_cast<U*>(&data);
    }

    template<typename U>
    const U& get() const {
        static_assert_valid_type<U>();
        assert(is<U>() && "variant::get called with a type that is not the active alternative");
        return *reinterpret_cast<const U*>(&data);
    }

    // Calls `f` with a const reference to the active alternative; the result is converted to R.
    template<typename R, typename F>
    R apply(F&& f) const {
        return tag_dispatcher<Ts...>::template apply<R, F>(tag, std::forward<F>(f), &data);
    }

    // Pattern-matching form: the callables are merged into one overload set and the
    // one accepting the active alternative is invoked.
    template<typename R, typename F, typename... Fs>
    R apply(F&& f, Fs&&... fs) const {
        using overload_type = overload<F, Fs...>;
        return tag_dispatcher<Ts...>::template apply<R, overload_type>(
            tag,
            overload_type(std::forward<F>(f), std::forward<Fs>(fs)...),
            &data);
    }

private:
    template<typename U>
    static void static_assert_valid_type() {
        static_assert(type_tag<U, Ts...>::value >= 0, "variant does not support the provided type.");
    }

    using data_type = typename std::aligned_storage<
        max_value<sizeof(Ts)...>::value,
        max_value<alignof(Ts)...>::value>::type;

    int tag;
    data_type data;
};

}
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// A small example program that uses the variant to represent a simple expression tree.
#include <memory> | |
#include <iostream> | |
#include <wiz/variant.h> | |
struct binary_operator_expression; | |
struct number_expression; | |
typedef wiz::variant<binary_operator_expression, number_expression> expression_variant; | |
struct binary_operator_expression { | |
enum class operation_type { | |
add, | |
subtract, | |
multiply, | |
divide, | |
modulo | |
}; | |
binary_operator_expression(operation_type operation, const std::shared_ptr<expression_variant>& left, const std::shared_ptr<expression_variant>& right) | |
: operation(operation), left(left), right(right) {} | |
operation_type operation; | |
std::shared_ptr<expression_variant> left; | |
std::shared_ptr<expression_variant> right; | |
}; | |
// Leaf node of the expression tree holding an unsigned literal.
struct number_expression {
    // Deliberately non-explicit: a std::size_t converts implicitly, as before.
    number_expression(std::size_t literal) : value{literal} {}

    std::size_t value;
};
void dump(std::ostream& out, expression_variant& expr) { | |
expr.apply<void>( | |
[&](const binary_operator_expression& expr) { | |
out << "("; | |
dump(out, *expr.left); | |
out << " "; | |
switch(expr.operation) { | |
case binary_operator_expression::operation_type::add: out << "+"; break; | |
case binary_operator_expression::operation_type::subtract: out << "-"; break; | |
case binary_operator_expression::operation_type::multiply: out << "*"; break; | |
case binary_operator_expression::operation_type::divide: out << "/"; break; | |
case binary_operator_expression::operation_type::modulo: out << "%"; break; | |
default: out << "???"; break; | |
} | |
out << " "; | |
dump(out, *expr.right); | |
out << ")"; | |
}, | |
[&](const number_expression& expr) { | |
out << expr.value; | |
}); | |
} | |
int evaluate(expression_variant& expr) { | |
return expr.apply<int>( | |
[](const binary_operator_expression& expr) { | |
int left = evaluate(*expr.left); | |
int right = evaluate(*expr.right); | |
switch(expr.operation) { | |
case binary_operator_expression::operation_type::add: return left + right; | |
case binary_operator_expression::operation_type::subtract: return left - right; | |
case binary_operator_expression::operation_type::multiply: return left * right; | |
case binary_operator_expression::operation_type::divide: return left / right; | |
case binary_operator_expression::operation_type::modulo: return left % right; | |
default: return 0; | |
} | |
}, | |
[](const number_expression& expr) { | |
return expr.value; | |
}); | |
} | |
int main() { | |
auto expr = expression_variant(binary_operator_expression( | |
binary_operator_expression::operation_type::subtract, | |
std::make_shared<expression_variant>(binary_operator_expression( | |
binary_operator_expression::operation_type::add, | |
std::make_shared<expression_variant>(number_expression(2)), | |
std::make_shared<expression_variant>( | |
binary_operator_expression( | |
binary_operator_expression::operation_type::multiply, | |
std::make_shared<expression_variant>(number_expression(240)), | |
std::make_shared<expression_variant>(number_expression(90)))))), | |
std::make_shared<expression_variant>(number_expression(4)))); | |
std::cout << "input: "; | |
dump(std::cout, expr); | |
std::cout << std::endl; | |
std::cout << "output: " << evaluate(expr) << std::endl; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment