Created
October 27, 2012 17:56
-
-
Save splinterofchaos/3965514 to your computer and use it in GitHub Desktop.
Monad
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
// Tag types naming the two categories recognised below: sequences (anything
// std::begin accepts) and pointer-like objects (dereferenceable and
// comparable with nullptr).
struct sequence_tag {};
struct pointer_tag {};

// Classification overloads. They are only ever named inside decltype, so they
// are declared but never defined.

// Fallback: a type matching neither pattern below is its own category.
template< class X >
X category( ... );

// Selected when std::begin(s) is well-formed.
template< class S >
auto category( const S& s ) -> decltype( std::begin(s), sequence_tag() );

// Selected when both *p and p==nullptr are well-formed.
template< class Ptr >
auto category( const Ptr& p ) -> decltype( *p, p==nullptr, pointer_tag() );

// Category<T>::type is the tag chosen by the overloads above, or T itself
// when only the fallback applies.
template< class T > struct Category {
    using type = decltype( category<T>(std::declval<T>()) );
};

// Function references are their own category. Without this specialization a
// function would satisfy the pointer-like test through function-to-pointer
// conversion.
template< class R, class ... X > struct Category< R(&)(X...) > {
    using type = R(&)(X...);
};

// Shorthand for typename Category<T>::type.
template< class T >
using Cat = typename Category<T>::type;
template< class... > struct Functor; | |
template< class F, class FX, class Fun=Functor< Cat<FX> > > | |
auto fmap( F&& f, FX&& fx ) | |
-> decltype( Fun::fmap( std::declval<F>(), std::declval<FX>() ) ) | |
{ | |
return Fun::fmap( std::forward<F>(f), std::forward<FX>(fx) ); | |
} | |
// Composition of two callables: Composition<F,G>{f,g}(x) evaluates f(g(x)).
// It is an aggregate (no constructors), so it is built with brace
// initialization.
template< class F, class G >
struct Composition {
    F f;  // outer function, applied to g's result
    G g;  // inner function, applied to the argument first

    // Perfect-forwards x into g and passes g's result to f.
    template< class X >
    auto operator () ( X&& x ) -> decltype( f(g(std::declval<X>())) ) {
        return f(g(std::forward<X>(x)));
    }
};
// General case: composition | |
template< class Function > struct Functor<Function> { | |
template< class F, class G, class C = Composition<F,G> > | |
static C fmap( F f, G g ) { | |
C( std::move(f), std::move(g) ); | |
} | |
}; | |
// Sequence case: apply f to every element and collect the results.
// S is deduced as the container template and X as its element type. The
// result is an S<R>, where R is what f returns for an X; S<R> must provide
// reserve() and push_back(), both of which are used below.
template<> struct Functor< sequence_tag > {
    template< class F, template<class...>class S, class X,
              class R = typename std::result_of<F(X)>::type >
    static S<R> fmap( F&& f, const S<X>& s ) {
        S<R> r;
        r.reserve( s.size() );  // one allocation for the whole output
        std::transform( std::begin(s), std::end(s),
                        std::back_inserter(r),
                        std::forward<F>(f) );
        return r;
    }
};
// Pointer case: apply f to the pointee when there is one.
// Returns a freshly allocated Ptr<R> holding f(*p) when p is non-null and a
// null Ptr<R> otherwise. Ptr<R> must be constructible from both an R* and
// nullptr.
template<> struct Functor< pointer_tag > {
    template< class F, template<class...>class Ptr, class X,
              class R = typename std::result_of<F(X)>::type >
    static Ptr<R> fmap( F&& f, const Ptr<X>& p )
    {
        return p != nullptr
            ? Ptr<R>( new R( std::forward<F>(f)(*p) ) )
            : nullptr;
    }
};
template< class ... > struct Monad; | |
template< class F, class M, class ...N, class Mo=Monad<Cat<M>> > | |
auto mbind( F&& f, M&& m, N&& ...n ) | |
-> decltype( Mo::mbind(std::declval<F>(), | |
std::declval<M>(),std::declval<N>()...) ) | |
{ | |
return Mo::mbind( std::forward<F>(f), | |
std::forward<M>(m), std::forward<N>(n)... ); | |
} | |
// mdo( f, m ): dispatches to Monad< Cat<M> >::mdo, perfect-forwarding both
// arguments. Despite its name, f is passed as the first monadic value (see
// the Monad<...>::mdo signatures); the implementation is selected from the
// category of the second argument, m.
// The pack N is not used by this function; it is kept so the template
// parameter list stays as it was.
template< class F, class M, class ...N, class Mo=Monad<Cat<M>> >
auto mdo( F&& f, M&& m )
    -> decltype( Mo::mdo(std::declval<F>(), std::declval<M>()) )
{
    return Mo::mdo( std::forward<F>(f), std::forward<M>(m) );
}
// The first template argument must be explicit!
// mreturn<M>( x ): wraps x in the monadic type M, e.g.
// mreturn<std::vector<int>>(1). Dispatches to Monad< Cat<M> >::mreturn<M>.
template< class M, class X, class Mo = Monad<Cat<M>> >
M mreturn( X&& x ) {
    return Mo::template mreturn<M>( std::forward<X>(x) );
}

// Overload taking the monad as a template, e.g. mreturn<std::unique_ptr>(1):
// the element type is deduced from x, so the result is M<X>.
template< template<class...>class M, class X, class Mo = Monad<Cat<M<X>>> >
M<X> mreturn( const X& x ) {
    return Mo::template mreturn<M<X>>( x );
}
// Also has explicit template argument.
// mfail<M>(): the failure value of the monadic type M, obtained from
// Monad< Cat<M> >::mfail<M>.
template< class M, class Mo = Monad<Cat<M>> >
M mfail() {
    return Mo::template mfail<M>();
}
// Monad instance for pointer-like types: a null pointer stands for "no
// value" and a non-null pointer holds one.
template< > struct Monad< pointer_tag > {
    // mbind( f, p ): f(*p) when p is non-null, otherwise nullptr.
    // R, the type f returns, must be convertible from nullptr.
    template< class F, template<class...>class Ptr, class X,
              class R = typename std::result_of<F(X)>::type >
    static R mbind( F&& f, const Ptr<X>& p ) {
        return p ? std::forward<F>(f)( *p ) : nullptr;
    }

    // Two-value mbind: f(*p, *q) when both are non-null, otherwise nullptr.
    template< class F, template<class...>class Ptr,
              class X, class Y,
              class R = typename std::result_of<F(X,Y)>::type >
    static R mbind( F&& f, const Ptr<X>& p, const Ptr<Y>& q ) {
        return p and q ? std::forward<F>(f)( *p, *q ) : nullptr;
    }

    // mdo( mx, my ): a fresh copy of my's value when both mx and my are
    // non-null, otherwise nullptr. The value held by mx is ignored.
    template< template< class... > class M, class X, class Y >
    static M<Y> mdo( const M<X>& mx, const M<Y>& my ) {
        return mx ? (my ? mreturn<M<Y>>(*my) : nullptr)
                  : nullptr;
    }

    // mreturn<M>( x ): allocates an M::element_type constructed from x and
    // hands ownership to a new M.
    template< class M, class X >
    static M mreturn( X&& x ) {
        using Y = typename M::element_type;
        return M( new Y(std::forward<X>(x)) );
    }

    // mfail<M>(): the null pointer.
    template< class M >
    static M mfail() { return nullptr; }
};
// Monad instance for sequences: an empty sequence is the failure value and
// mbind concatenates the sequences produced for each element.
template< > struct Monad< sequence_tag > {
    // mbind( f, xs ): calls f on each element of xs (f must return a
    // sequence R) and moves every produced element onto the end of the result.
    template< class F, template<class...>class S, class X,
              class R = typename std::result_of<F(X)>::type >
    static R mbind( F&& f, const S<X>& xs ) {
        R r;
        for( const X& x : xs ) {
            auto ys = std::forward<F>(f)( x );
            std::move( std::begin(ys), std::end(ys), std::back_inserter(r) );
        }
        return r;
    }

    // Two-sequence mbind: calls f(x, y) for every pair, with xs as the outer
    // loop and ys as the inner loop, concatenating the results in that order.
    template< class F, template<class...>class S,
              class X, class Y,
              class R = typename std::result_of<F(X,Y)>::type >
    static R mbind( F&& f, const S<X>& xs, const S<Y>& ys ) {
        R r;
        for( const X& x : xs ) {
            for( const Y& y : ys ) {
                auto zs = std::forward<F>(f)( x, y );
                std::move( std::begin(zs), std::end(zs),
                           std::back_inserter(r) );
            }
        }
        return r;
    }

    // mdo( mx, my ): my repeated once for every element of mx. The element
    // values of mx are ignored; only their count matters. An empty mx gives
    // an empty result and a one-element mx gives a copy of my.
    template< template< class... > class S, class X, class Y >
    static S<Y> mdo( const S<X>& mx, const S<Y>& my ) {
        S<Y> r;
        for( auto n = mx.size(); n > 0; --n )
            std::copy( std::begin(my), std::end(my), std::back_inserter(r) );
        return r;
    }

    // mreturn<S>( x ): an S holding the single element x.
    template< class S, class X >
    static S mreturn( X&& x ) {
        return S{ std::forward<X>(x) }; // Construct an S of one element.
    }

    // mfail<S>(): the empty sequence.
    template< class S >
    static S mfail() { return S{}; }
};
template< class M > | |
M addM( const M& a, const M& b ) { | |
return mbind ( | |
[&]( int x ) { | |
return mbind ( | |
[=]( int y ) { return mreturn<M>(x+y); }, | |
b | |
); | |
}, | |
a | |
); | |
} | |
// addM2( a, b ): same result as addM, but the inner step uses fmap, so the
// sum does not need to be re-wrapped with mreturn.
// The lambdas take int, so M must hold values convertible to int.
template< class M >
M addM2( const M& a, const M& b ) {
    return mbind (
        [&]( int x ) {
            return fmap (
                [=]( int y ) { return x + y; },
                b
            );
        },
        a
    );
}
// m >>= f is mbind( f, m ): the operands appear in the order "value, then
// function". The trailing return type removes this overload from
// consideration whenever mbind( f, m ) is not well-formed.
template< class M, class F >
auto operator >>= ( M&& m, F&& f )
    -> decltype( mbind(std::declval<F>(),std::declval<M>()) )
{
    return mbind( std::forward<F>(f), std::forward<M>(m) );
}
// m >> f is mdo( m, f ). Here f is the second monadic value, not a function.
// The trailing return type removes this overload from consideration whenever
// mdo( m, f ) is not well-formed.
template< class M, class F >
auto operator >> ( M&& m, F&& f )
    -> decltype( mdo(std::declval<M>(),std::declval<F>()) )
{
    return mdo( std::forward<M>(m), std::forward<F>(f) );
}
template< class F, template<class...>class M, | |
class X, class Y, | |
class R = typename std::result_of<F(X,Y)>::type > | |
M<R> liftM( F&& f, const M<X>& a, const M<Y>& b ) { | |
return a >>= [&]( const X& x ) { | |
return b >>= [&]( const Y& y ) { | |
return mreturn<M>( std::forward<F>(f)(x,y) ); | |
}; | |
}; | |
}; | |
/* | |
* guard<M>(b) = (return True) or mfail(). | |
* | |
* guard prematurely halts an execution based on some bool, b. Note that: | |
* p >> q = q | |
* mfail() >> p = mfail() | |
* nullptr >> p = nullptr -- where p is a unique_ptr. | |
* {} >> v = {} -- where v is a vector. | |
*/ | |
template< template< class... > class M > | |
M<bool> guard( bool b ) { | |
return b ? mreturn<M>(b) : mfail<M<bool>>(); | |
} | |
/* | |
* The above version of guard creates a junk Monad. This may be costly. | |
* This version is an optimal shorthand for guard(b) >> m. | |
*/ | |
template< template< class... > class M, class F, | |
class R = typename std::result_of<F()>::type > | |
M<R> guard( bool b, F&& f ) { | |
return b ? mreturn<M>( std::forward<F>(f)() ) : mfail<M<R>>(); | |
} | |
template< template< class... > class M, class X > | |
M< std::pair<X,X> > uniquePairs( const M<X>& m ) { | |
return mbind ( | |
[]( int x, int y ) -> M< std::pair<X,X> > { | |
// This is a very Haskell-like use of guard. | |
return guard<M>( x != y ) >> mreturn<M>( std::make_pair(x,y) ); | |
}, m, m | |
); | |
} | |
/* alias for mreturn<unique_ptr> */ | |
template< class X > | |
auto Just( X&& x ) -> decltype( mreturn<std::unique_ptr>(std::declval<X>()) ) { | |
return mreturn<std::unique_ptr>( std::forward<X>(x) ); | |
} | |
#include <cmath> | |
// Safe square root. | |
std::unique_ptr<float> sqrt( float x ) { | |
// The more optimized C++-guard. | |
return guard<std::unique_ptr>( x >= 0, [x]{ return std::sqrt(x); } ); | |
// Equivalently, | |
return x >= 0 ? Just( std::sqrt(x) ) : nullptr; | |
} | |
// Safe quadratic root. | |
std::unique_ptr<std::pair<float,float>> qroot( float a, float b, float c ) { | |
return fmap ( | |
[=]( float r /*root*/ ) { | |
return std::make_pair( (-b + r)/(2*a), (-b - r)/(2*a) ); | |
}, | |
sqrt( b*b - 4*a*c ) | |
); | |
} | |
// Streams a pair as "(first,second)" and returns the stream for chaining.
template< class X, class Y >
std::ostream& operator << ( std::ostream& os, const std::pair<X,Y>& p ) {
    os << '(' << p.first << ',' << p.second << ')';
    return os;
}
// Streams a unique_ptr in Maybe notation: "Just <value>" when it owns an
// object, "Nothing" when it is null. Returns the stream for chaining.
template< class X >
std::ostream& operator << ( std::ostream& os, const std::unique_ptr<X>& p ) {
    if( p )
        os << "Just " << *p;
    else
        os << "Nothing";
    return os;
}
int main() { | |
std::unique_ptr<int> p( new int(5) ); | |
auto f = []( int x ) { return Just(-x); }; | |
std::unique_ptr<int> q = mbind( f, p ); | |
std::cout << "q = " << q << std::endl; | |
std::cout << "p+q = " << addM2( p, q ) << std::endl; | |
std::cout << "p+q = " << liftM( std::plus<int>(), p, q ) << std::endl; | |
std::vector<int> v={1,2}, w={3,4}; | |
std::cout << "v+w = { "; | |
auto vw = addM(v,w); | |
std::copy ( | |
std::begin(vw), std::end(vw), | |
std::ostream_iterator<int>(std::cout, " ") | |
); | |
std::cout << '}' << std::endl; | |
{ | |
std::vector<int> v = {1,2,3}; | |
using V = std::vector<std::pair<int,int>>; | |
auto ps = uniquePairs( v ); | |
std::cout << "Unique pairs of [1,2,3]:\n\t"; | |
for( const auto& p : ps ) | |
std::cout << p << ' '; | |
std::cout << std::endl; | |
std::cout << "Unique pairs of Just 5:\n\t" << uniquePairs(p) << std::endl; | |
} | |
std::cout << "The quadratic root of (1,3,-4) = " << qroot(1,3,-4) << std::endl; | |
std::cout << "The quadratic root of (1,0,4) = " << qroot(1,0,4) << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment