// Sum/Either type in C++11
//
// Source: GitHub gist ramunas/b23026c6c3b32ca73e5a5293a10f211e
// (last active February 22, 2023).
#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
using namespace std;
/// A disjoint union holding either a value of type A (the "left"
/// alternative) or a value of type B (the "right" alternative).
///
/// Values are built with the static factories left() and right() and are
/// inspected with choice(), is_left()/is_right() and get_left()/get_right().
///
/// A default-constructed Either holds no alternative: its internal
/// std::function is empty, so every inspecting member throws
/// std::bad_function_call.
template <typename A, typename B>
class Either {
private:
    // In a language with lambdas, a disjoint union can be implemented using
    // Church's encoding in a fairly straightforward way. In this encoding,
    // the constructors may be implemented as
    //
    //   left(x)  := λ l. λ r. l(x), and
    //   right(x) := λ l. λ r. r(x).
    //
    // Then, the destructor/choice is
    //
    //   choice(d, f, g) := d(f, g)
    //
    // For example, we can compute choice(left(1), f, g) = f(1), and
    // choice(right(2), f, g) = g(2), which is exactly the defining property
    // of a disjoint union.
    //
    // We wish left and right to be polymorphic in the result type. However,
    // in C++11 they cannot be given such a type. E.g., left would be required
    // to have, where A and B are type parameters, the type
    //   template <class Ret> T(function<Ret(A)>, function<Ret(B)>)
    // but this is not a type, as it depends on the Ret type parameter.
    //
    // To avoid this dependency, we instead use continuations (return type
    // void) that we put in the right context with the appropriate
    // polymorphic type. The continuations then set the return value. See the
    // choice method.
    using Coproduct =
        std::function<void(std::function<void(A)>, std::function<void(B)>)>;

    Coproduct coprod;

    // Only the left()/right() factories build a populated Either.
    explicit Either(Coproduct f) : coprod(std::move(f)) {}

public:
    Either() = default;
    Either(const Either &) = default;
    Either(Either &&) = default;
    Either& operator=(const Either &) = default;
    Either& operator=(Either &&) = default;

    /// Builds an Either holding `data` as its left alternative.
    static Either<A, B> left(A data) {
        return Either([data](std::function<void(A)> on_left, std::function<void(B)>) {
            on_left(data);
        });
    }

    /// Builds an Either holding `data` as its right alternative.
    static Either<A, B> right(B data) {
        return Either([data](std::function<void(A)>, std::function<void(B)> on_right) {
            on_right(data);
        });
    }

    /// Eliminates the union: returns f(a) if this holds a left value `a`,
    /// or g(b) if it holds a right value `b`.
    ///
    /// C must be default-constructible and assignable, because the result is
    /// written from inside a void continuation. Throws
    /// std::bad_function_call if this Either is empty or if the handler for
    /// the held alternative is empty.
    template <typename C>
    C choice(std::function<C(A)> f, std::function<C(B)> g) const {
        C result{};
        // The continuations run synchronously inside coprod, so capturing
        // the local result and the handlers by reference is safe.
        coprod([&result, &f](A a) { result = f(std::move(a)); },
               [&result, &g](B b) { result = g(std::move(b)); });
        return result;
    }

    /// Side-effect-only variant: runs f on a left value or g on a right one.
    void choice(std::function<void(A)> f, std::function<void(B)> g) const {
        coprod(std::move(f), std::move(g));
    }

    /// True when the left alternative is held.
    bool is_left() const {
        return choice<bool>([](A) { return true; }, [](B) { return false; });
    }

    /// True when the right alternative is held.
    bool is_right() const {
        return !is_left();
    }

    /// Returns the left value. The right handler is deliberately empty, so
    /// calling this on a right value throws std::bad_function_call.
    A get_left() const {
        return choice<A>([](A x) -> A { return x; }, nullptr);
    }

    /// Returns the right value. The left handler is deliberately empty, so
    /// calling this on a left value throws std::bad_function_call.
    B get_right() const {
        return choice<B>(nullptr, [](B x) -> B { return x; });
    }
};
template <typename ...Types> struct Sum { | |
template <typename ...Ts> struct IteratedSum; | |
template <typename A> struct IteratedSum<A> { typedef A type; }; | |
template <typename A, typename ...Ts> struct IteratedSum<A, Ts...> { | |
typedef Either<A, typename IteratedSum<Ts...>::type> type; | |
}; | |
template <int N, typename ...Ts> struct NthType; | |
template <typename T, typename ...Ts> struct NthType<0, T, Ts...> { typedef T type; }; | |
template <int N, typename T, typename ...Ts> struct NthType<N, T, Ts...> { typedef typename NthType<N-1,Ts...>::type type; }; | |
template <int N, typename ...Ts> struct SumInject; | |
template <typename T, typename S, typename ...Ts> struct SumInject<0, T, S, Ts...> { | |
static typename IteratedSum<T, S, Ts...>::type inj(T a) { | |
return Either<T, typename IteratedSum<S, Ts...>::type>::left(a); | |
} | |
}; | |
template <typename T> struct SumInject<0, T> { | |
static T inj(T a) { | |
return a; | |
} | |
}; | |
template <int N, typename T, typename ...Ts> struct SumInject<N, T, Ts...> { | |
static typename IteratedSum<T, Ts...>::type inj(typename NthType<N, T, Ts...>::type a) { | |
return Either<T, typename IteratedSum<Ts...>::type>::right( SumInject<N-1,Ts...>::inj(a) ); | |
} | |
}; | |
template <int N, typename ...Ts> struct SumIs; | |
template <typename T> | |
struct SumIs<0, T> { | |
static bool is(T) { | |
return true; | |
} | |
}; | |
template <typename T, typename S, typename ...Ts> | |
struct SumIs<0, T, S, Ts...> { | |
static bool is(typename IteratedSum<T, S, Ts...>::type sum) { | |
return sum.is_left(); | |
} | |
}; | |
template <int N, typename T, typename ...Ts> | |
struct SumIs<N, T, Ts...> { | |
static bool is(typename IteratedSum<T, Ts...>::type sum) { | |
return sum.is_right() && SumIs<N-1,Ts...>::is(sum.get_right()); | |
} | |
}; | |
template <int N, typename ...Ts> struct SumGet; | |
template <typename T> | |
struct SumGet<0, T> { | |
static T get(T x) { | |
return x; | |
} | |
}; | |
template <typename T, typename S, typename ...Ts> | |
struct SumGet<0, T, S, Ts...> { | |
static T get(typename IteratedSum<T, S, Ts...>::type sum) { | |
return sum.get_left(); | |
} | |
}; | |
template <int N, typename T, typename ...Ts> | |
struct SumGet<N, T, Ts...> { | |
static typename NthType<N, T, Ts...>::type get(typename IteratedSum<T, Ts...>::type sum) { | |
return SumGet<N-1, Ts...>::get(sum.get_right()); | |
} | |
}; | |
typedef typename IteratedSum<Types...>::type SumType; | |
SumType sum; | |
template <int N> static Sum<Types...> inj(typename NthType<N, Types...>::type a) { | |
return Sum { SumInject<N, Types...>::inj(a) }; | |
} | |
template <int N> bool is() { | |
return SumIs<N, Types...>::is(sum); | |
} | |
template <int N> typename NthType<N, Types...>::type get() { | |
return SumGet<N, Types...>::get(sum); | |
} | |
}; | |
/*
 *
 * Examples
 *
 */
template <typename T> | |
struct BinaryTree { | |
typedef pair<BinaryTree<T>, BinaryTree<T>> Node; | |
typedef Either<T, Node> Tree; | |
Tree tree; | |
bool is_node() { return tree.is_right(); } | |
Node get_node() { return tree.get_right(); } | |
static BinaryTree node(BinaryTree a, BinaryTree b) { return BinaryTree { Tree::right(Node(a,b)) }; } | |
bool is_leaf() { return tree.is_left(); } | |
Node get_leaf() { return tree.get_left(); } | |
static BinaryTree leaf(T a) { return BinaryTree { Tree::left(a) }; } | |
int depth() { | |
if (is_leaf()) | |
return 1; | |
auto t = get_node(); | |
return std::max(t.first.depth(), t.second.depth()) + 1; | |
} | |
}; | |
void tree_example() { | |
auto t = BinaryTree<int>::node(BinaryTree<int>::leaf(1), BinaryTree<int>::node(BinaryTree<int>::leaf(2), BinaryTree<int>::leaf(2))); | |
cout << "Tree depth " << t.depth() << endl; | |
} | |
struct Arith { | |
typedef int Const; | |
typedef pair<Arith,Arith> Add; | |
typedef pair<Arith,Arith> Sub; | |
typedef Sum<Const, Add, Sub> Expr; | |
Expr expr; | |
static Arith konst(int x) { return Arith { Expr::inj<0>(x) }; } | |
bool is_const() { return expr.is<0>(); } | |
Const get_const() { return expr.get<0>(); } | |
static Arith add(Arith a, Arith b) { return Arith { Expr::inj<1>(Add(a,b)) }; } | |
bool is_add() { return expr.is<1>(); } | |
Add get_add() { return expr.get<1>(); } | |
static Arith sub(Arith a, Arith b) { return Arith { Expr::inj<2>(Sub(a,b)) }; } | |
bool is_sub() { return expr.is<2>(); } | |
Add get_sub() { return expr.get<2>(); } | |
int eval() { | |
if (is_const()) | |
return get_const(); | |
if (is_add()) { | |
auto p = get_add(); | |
return p.first.eval() + p.second.eval(); | |
} | |
if (is_sub()) { | |
auto p = get_sub(); | |
return p.first.eval() - p.second.eval(); | |
} | |
return 0; | |
} | |
}; | |
ostream& operator<<(ostream &out, Arith a) | |
{ | |
if (a.is_const()) { | |
out << a.get_const(); | |
} else if (a.is_add()) { | |
out << "(" << a.get_add().first << " + " << a.get_add().second << ")"; | |
} else if (a.is_sub()) { | |
out << "(" << a.get_sub().first << " - " << a.get_sub().second << ")"; | |
} | |
return out; | |
} | |
void arith_example() | |
{ | |
auto expr = Arith::add(Arith::sub(Arith::konst(10), Arith::konst(20)), Arith::konst(15)); | |
cout << "Expression " << expr << " evaluates to " << expr.eval() << endl; | |
} | |
struct Lam { | |
typedef int Var; | |
typedef pair<Lam,Lam> App; | |
typedef pair<int,Lam> Fun; | |
typedef Sum<Var, App, Fun> LamExpr; | |
LamExpr expr; | |
static Lam var(int x) { return Lam { LamExpr::inj<0>(x) }; } | |
bool is_var() { return expr.is<0>(); } | |
int get_var() { return expr.get<0>(); } | |
static Lam app(Lam f, Lam e) { return Lam { LamExpr::inj<1>(App(f,e)) }; } | |
bool is_app() { return expr.is<1>(); } | |
App get_app() { return expr.get<1>(); } | |
static Lam lam(int x, Lam e) { return Lam { LamExpr::inj<2>(Fun(x,e)) }; } | |
bool is_lam() { return expr.is<2>(); } | |
Fun get_lam() { return expr.get<2>(); } | |
}; | |
int main() { | |
tree_example(); | |
arith_example(); | |
auto sum = Sum<int, int, double>::inj<0>(232); | |
cout << sum.is<2>() << endl; | |
cout << sum.get<0>() << endl; | |
auto var = Lam::lam(2, Lam::var(2)); | |
auto x1 = Either<int, Either<double, int>>::right( Either<double, int>::right(86) ); | |
x1.choice<int>([](int v) -> int { | |
cout << "New Left value " << v << endl; | |
return 0; | |
}, [](Either<double, int> v) -> int { | |
cout << "New Right value " << endl; | |
return 1; | |
}); | |
if (x1.is_left()) { | |
cout << "I am taking left " << x1.get_left() << endl; | |
} else { | |
cout << "I am taking right " << endl; | |
} | |
return 0; | |
} |
// End of gist. (GitHub page footer: sign up or sign in on GitHub to comment.)