Skip to content

Instantly share code, notes, and snippets.

@ramunas
Last active February 22, 2023 13:20
Show Gist options
  • Save ramunas/b23026c6c3b32ca73e5a5293a10f211e to your computer and use it in GitHub Desktop.
Save ramunas/b23026c6c3b32ca73e5a5293a10f211e to your computer and use it in GitHub Desktop.
Sum/Either type in C++11
#include <algorithm>
#include <functional>
#include <iostream>
#include <utility>
using namespace std;
template <typename A, typename B>
class Either {
private:
    // Church encoding of a disjoint union.
    //
    // In a language with lambdas, a disjoint union can be implemented using
    // Church's encoding in a fairly straightforward way. In this encoding, the
    // constructors may be implemented as
    //
    //   left(x)  := λ l. λ r. l(x), and
    //   right(x) := λ l. λ r. r(x).
    //
    // Then, the destructor/choice is
    //
    //   choice(d, f, g) := d(f, g)
    //
    // For example, choice(left(1), f, g) = f(1) and choice(right(2), f, g) = g(2),
    // which is exactly the defining property of a disjoint union.
    //
    // We would like left and right to be polymorphic in the result type.
    // However, in C++11 they cannot be given a polymorphic type. E.g., left
    // would need the type (with A and B the class's type parameters)
    //
    //   template <class Ret> Ret(std::function<Ret(A)>, std::function<Ret(B)>)
    //
    // which is not a type, as it depends on the Ret type parameter.
    //
    // To avoid this dependency we instead use continuations (return type void)
    // and recover the result type in choice(), where the continuations store
    // their result into a local of the requested type.
    typedef std::function<void(A)> LeftK;
    typedef std::function<void(B)> RightK;

    // The encoded value: invokes exactly one of the two continuations it is
    // given. It is an empty std::function for a default-constructed Either.
    std::function<void(LeftK, RightK)> coprod;

    explicit Either(std::function<void(LeftK, RightK)> f) : coprod(f) {}

public:
    // A default-constructed Either holds neither alternative; every query on it
    // throws std::bad_function_call because coprod is empty.
    // Copy and move are the compiler-generated defaults (Rule of Zero).
    Either() = default;

    // Injects a value into the left alternative (the value is copied into the
    // encoding).
    static Either left(A data) {
        return Either([data](LeftK on_left, RightK) { on_left(data); });
    }

    // Injects a value into the right alternative (the value is copied into the
    // encoding).
    static Either right(B data) {
        return Either([data](LeftK, RightK on_right) { on_right(data); });
    }

    // Case analysis: returns f(a) if this holds left(a), g(b) if it holds right(b).
    // Requires C to be default-constructible, assignable and copy/move-
    // constructible: the chosen continuation assigns into a local result.
    // Throws std::bad_function_call if the branch actually taken is empty.
    template <typename C>
    C choice(std::function<C(A)> f, std::function<C(B)> g) const {
        C result = C(); // value-initialised, never left indeterminate
        coprod([&result, &f](A a) { result = f(a); },
               [&result, &g](B b) { result = g(b); });
        return result;
    }

    // Side-effecting case analysis: invokes f on a left value or g on a right value.
    void choice(std::function<void(A)> f, std::function<void(B)> g) const {
        coprod(f, g);
    }

    // True iff this was built with left().
    bool is_left() const {
        return choice<bool>([](A) { return true; }, [](B) { return false; });
    }

    // True iff this was built with right().
    bool is_right() const {
        return !is_left();
    }

    // Returns the left value; throws std::bad_function_call if this holds a
    // right value (the right branch passed to choice is empty).
    A get_left() const {
        return choice<A>([](A x) -> A { return x; }, nullptr);
    }

    // Returns the right value; throws std::bad_function_call if this holds a
    // left value (the left branch passed to choice is empty).
    B get_right() const {
        return choice<B>(nullptr, [](B x) -> B { return x; });
    }
};
// An n-ary sum type over Types..., encoded as a right-nested chain of Either:
// Sum<T1, T2, T3> stores Either<T1, Either<T2, T3>>. Alternatives are selected
// by a 0-based compile-time index N.
template <typename ...Types> struct Sum {
    // IteratedSum<T1, ..., Tn>::type is Either<T1, IteratedSum<T2, ..., Tn>::type>;
    // a single type T encodes as T itself.
    template <typename ...Ts> struct IteratedSum;
    template <typename A> struct IteratedSum<A> { typedef A type; };
    template <typename A, typename ...Ts> struct IteratedSum<A, Ts...> {
        typedef Either<A, typename IteratedSum<Ts...>::type> type;
    };

    // NthType<N, Ts...>::type is the N-th (0-based) type of Ts.
    template <int N, typename ...Ts> struct NthType;
    template <typename T, typename ...Ts> struct NthType<0, T, Ts...> { typedef T type; };
    template <int N, typename T, typename ...Ts> struct NthType<N, T, Ts...> {
        typedef typename NthType<N-1, Ts...>::type type;
    };

    // SumInject<N, Ts...>::inj encodes a value of the N-th type: N right()
    // wrappers followed by a left(), except that the last alternative has no left().
    template <int N, typename ...Ts> struct SumInject;
    template <typename T, typename S, typename ...Ts> struct SumInject<0, T, S, Ts...> {
        static typename IteratedSum<T, S, Ts...>::type inj(T a) {
            return Either<T, typename IteratedSum<S, Ts...>::type>::left(a);
        }
    };
    template <typename T> struct SumInject<0, T> {
        static T inj(T a) {
            return a;
        }
    };
    template <int N, typename T, typename ...Ts> struct SumInject<N, T, Ts...> {
        static typename IteratedSum<T, Ts...>::type inj(typename NthType<N, T, Ts...>::type a) {
            return Either<T, typename IteratedSum<Ts...>::type>::right(SumInject<N-1, Ts...>::inj(a));
        }
    };

    // SumIs<N, Ts...>::is tests whether an encoded value holds alternative N.
    // The && short-circuits, so get_right() is only called on a right value.
    template <int N, typename ...Ts> struct SumIs;
    template <typename T>
    struct SumIs<0, T> {
        static bool is(T) {
            return true;
        }
    };
    template <typename T, typename S, typename ...Ts>
    struct SumIs<0, T, S, Ts...> {
        static bool is(typename IteratedSum<T, S, Ts...>::type sum) {
            return sum.is_left();
        }
    };
    template <int N, typename T, typename ...Ts>
    struct SumIs<N, T, Ts...> {
        static bool is(typename IteratedSum<T, Ts...>::type sum) {
            return sum.is_right() && SumIs<N-1, Ts...>::is(sum.get_right());
        }
    };

    // SumGet<N, Ts...>::get extracts alternative N by peeling N right() layers.
    template <int N, typename ...Ts> struct SumGet;
    template <typename T>
    struct SumGet<0, T> {
        static T get(T x) {
            return x;
        }
    };
    template <typename T, typename S, typename ...Ts>
    struct SumGet<0, T, S, Ts...> {
        static T get(typename IteratedSum<T, S, Ts...>::type sum) {
            return sum.get_left();
        }
    };
    template <int N, typename T, typename ...Ts>
    struct SumGet<N, T, Ts...> {
        static typename NthType<N, T, Ts...>::type get(typename IteratedSum<T, Ts...>::type sum) {
            return SumGet<N-1, Ts...>::get(sum.get_right());
        }
    };

    typedef typename IteratedSum<Types...>::type SumType;
    SumType sum;

    // Builds a Sum holding `a` as alternative N.
    template <int N> static Sum<Types...> inj(typename NthType<N, Types...>::type a) {
        return Sum { SumInject<N, Types...>::inj(a) };
    }

    // True iff this Sum holds alternative N. Only copies `sum`, so it is const.
    template <int N> bool is() const {
        static_assert(N >= 0 && N < static_cast<int>(sizeof...(Types)),
                      "Sum::is<N>: N must index one of the Sum's types");
        return SumIs<N, Types...>::is(sum);
    }

    // Returns alternative N. Asking for an alternative that is not held ends
    // in Either's get_left/get_right on the wrong side.
    template <int N> typename NthType<N, Types...>::type get() const {
        return SumGet<N, Types...>::get(sum);
    }
};
/*
*
* Examples
*
*/
// A binary tree whose leaves carry a T, encoded as Either<T, Node> where a
// Node is the pair (first subtree, second subtree).
template <typename T>
struct BinaryTree {
    typedef std::pair<BinaryTree<T>, BinaryTree<T>> Node;
    typedef Either<T, Node> Tree;
    Tree tree;

    // Internal nodes live in the right alternative.
    bool is_node() { return tree.is_right(); }
    Node get_node() { return tree.get_right(); }
    static BinaryTree node(BinaryTree a, BinaryTree b) { return BinaryTree { Tree::right(Node(a, b)) }; }

    // Leaves live in the left alternative; get_leaf returns the stored T.
    bool is_leaf() { return tree.is_left(); }
    T get_leaf() { return tree.get_left(); }
    static BinaryTree leaf(T a) { return BinaryTree { Tree::left(a) }; }

    // Number of levels: 1 for a leaf, otherwise 1 + the depth of the deeper subtree.
    int depth() {
        if (is_leaf())
            return 1;
        Node children = get_node();
        return std::max(children.first.depth(), children.second.depth()) + 1;
    }
};
void tree_example() {
auto t = BinaryTree<int>::node(BinaryTree<int>::leaf(1), BinaryTree<int>::node(BinaryTree<int>::leaf(2), BinaryTree<int>::leaf(2)));
cout << "Tree depth " << t.depth() << endl;
}
struct Arith {
typedef int Const;
typedef pair<Arith,Arith> Add;
typedef pair<Arith,Arith> Sub;
typedef Sum<Const, Add, Sub> Expr;
Expr expr;
static Arith konst(int x) { return Arith { Expr::inj<0>(x) }; }
bool is_const() { return expr.is<0>(); }
Const get_const() { return expr.get<0>(); }
static Arith add(Arith a, Arith b) { return Arith { Expr::inj<1>(Add(a,b)) }; }
bool is_add() { return expr.is<1>(); }
Add get_add() { return expr.get<1>(); }
static Arith sub(Arith a, Arith b) { return Arith { Expr::inj<2>(Sub(a,b)) }; }
bool is_sub() { return expr.is<2>(); }
Add get_sub() { return expr.get<2>(); }
int eval() {
if (is_const())
return get_const();
if (is_add()) {
auto p = get_add();
return p.first.eval() + p.second.eval();
}
if (is_sub()) {
auto p = get_sub();
return p.first.eval() - p.second.eval();
}
return 0;
}
};
ostream& operator<<(ostream &out, Arith a)
{
if (a.is_const()) {
out << a.get_const();
} else if (a.is_add()) {
out << "(" << a.get_add().first << " + " << a.get_add().second << ")";
} else if (a.is_sub()) {
out << "(" << a.get_sub().first << " - " << a.get_sub().second << ")";
}
return out;
}
void arith_example()
{
auto expr = Arith::add(Arith::sub(Arith::konst(10), Arith::konst(20)), Arith::konst(15));
cout << "Expression " << expr << " evaluates to " << expr.eval() << endl;
}
// Lambda-calculus-style terms encoded as a three-way Sum:
// alternative 0 is Var, 1 is App, 2 is Fun. Variables are plain ints.
struct Lam {
    typedef int Var;
    typedef std::pair<Lam, Lam> App;   // (f, e) as passed to app()
    typedef std::pair<int, Lam> Fun;   // (x, e) as passed to lam()
    typedef Sum<Var, App, Fun> LamExpr;

    LamExpr expr;

    // Constructors, one per alternative.
    static Lam var(int x) { return Lam{ LamExpr::inj<0>(x) }; }
    static Lam app(Lam f, Lam e) { return Lam{ LamExpr::inj<1>(App(f, e)) }; }
    static Lam lam(int x, Lam e) { return Lam{ LamExpr::inj<2>(Fun(x, e)) }; }

    // Tag tests.
    bool is_var() { return expr.is<0>(); }
    bool is_app() { return expr.is<1>(); }
    bool is_lam() { return expr.is<2>(); }

    // Accessors; call only after the matching tag test succeeds.
    int get_var() { return expr.get<0>(); }
    App get_app() { return expr.get<1>(); }
    Fun get_lam() { return expr.get<2>(); }
};
int main() {
tree_example();
arith_example();
auto sum = Sum<int, int, double>::inj<0>(232);
cout << sum.is<2>() << endl;
cout << sum.get<0>() << endl;
auto var = Lam::lam(2, Lam::var(2));
auto x1 = Either<int, Either<double, int>>::right( Either<double, int>::right(86) );
x1.choice<int>([](int v) -> int {
cout << "New Left value " << v << endl;
return 0;
}, [](Either<double, int> v) -> int {
cout << "New Right value " << endl;
return 1;
});
if (x1.is_left()) {
cout << "I am taking left " << x1.get_left() << endl;
} else {
cout << "I am taking right " << endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment