# Source: GitHub Gist NicolasT/357840 by @NicolasT (created April 6, 2010 17:19).
# Callable arrow implementation.
# Arrow type class: a wrapper around a one-argument callable that can be
# composed left-to-right with the >> operator.
class arr(object):
    '''Arrow type constructor: lift a one-argument callable into an arrow.'''

    def __init__(self, fun):
        # The lifted callable; __call__ delegates to it.
        self.fun = fun

    def __call__(self, x):
        '''Run the arrow on *x* by calling the wrapped callable.'''
        return self.fun(x)

    def __rshift__(self, other):
        '''>> combinator: (self >> other)(x) == other(self(x)).

        *other* may be an arr or any plain callable, which makes this
        similar to Haskell's >>^.
        '''
        def composed(value):
            intermediate = self(value)
            return other(intermediate)
        return arr(composed)

    def __rrshift__(self, other):
        '''>> combinator used when the left operand does not define >>.

        A plain callable on the left is first lifted with arr(...) (similar
        to Haskell's ^>>); the result runs it before this arrow.
        '''
        if isinstance(other, arr):
            lifted = other
        else:
            lifted = arr(other)
        return arr.__rshift__(lifted, self)
def first(fun):
    '''Apply *fun* to the first component of a pair: (x, y) -> (fun(x), y).

    This is Haskell's ``first``. The pair is unpacked in the body because
    the ``lambda (x, y):`` tuple-parameter syntax does not exist in
    Python 3 (PEP 3113). Unpacking still requires exactly two components.
    '''
    def apply_first(pair):
        x, y = pair
        return (fun(x), y)
    return arr(apply_first)


def _swap_pair(pair):
    '''Return the 2-tuple *pair* with its components exchanged.'''
    x, y = pair
    return (y, x)


# Arrow that swaps the two components of a pair: (x, y) -> (y, x).
swapA = arr(_swap_pair)


def second(fun):
    '''Apply *fun* to the second component of a pair: (x, y) -> (x, fun(y)).

    This is Haskell's ``second``, built as swapA >> first(fun) >> swapA.
    '''
    return swapA >> first(fun) >> swapA


def split(f, g):
    '''This is the *** operator in Haskell: (x, y) -> (f(x), g(y)).'''
    return first(f) >> second(g)


# Arrow that duplicates its input: x -> (x, x).
fanoutA = arr(lambda x: (x, x))


def fanout(f, g):
    '''This is the &&& operator in Haskell: x -> (f(x), g(x)).'''
    return fanoutA >> split(f, g)


# The identity arrow.
returnA = arr(lambda x: x)
# ArrowChoice type class functions
class Either(object):
    '''Base class for a value tagged as Left or Right.

    Two Either values are equal when they have the same concrete class and
    equal payloads.
    '''

    def __init__(self, obj):
        # The wrapped payload.
        self.obj = obj

    def __eq__(self, other):
        # Anything that is not an Either is left to the other operand.
        if not isinstance(other, Either):
            return NotImplemented
        if type(self) != type(other):
            return False
        return self.obj == other.obj

    def __ne__(self, other):
        # __eq__ is a bound method, so only *other* is passed. A
        # NotImplemented result is propagated so Python can still try the
        # reflected comparison.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    def __hash__(self):
        # Keeps hashing consistent with __eq__. Without this, Python 3 makes
        # instances unhashable because __eq__ is defined. Raises TypeError
        # if the payload itself is unhashable.
        return hash((type(self), self.obj))

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.obj)


class Left(Either):
    '''Either value tagged Left.'''


class Right(Either):
    '''Either value tagged Right.'''
def or_(f, g):
    '''Merge two branches into one result (analogous to Haskell's |||).

    The returned arrow calls *f* on the payload of a Left and *g* on the
    payload of anything else (expected to be a Right). The result is not
    re-wrapped in Left/Right.
    '''
    def choose(either):
        if isinstance(either, Left):
            return f(either.obj)
        return g(either.obj)
    return arr(choose)


def and_(f, g):
    '''Map each branch separately (analogous to Haskell's +++).

    Left(x) -> Left(f(x)); any other value v -> Right(g(v.obj)).
    '''
    return or_(lambda x: Left(f(x)), lambda x: Right(g(x)))


def left(f):
    '''Apply *f* inside a Left; a Right is re-wrapped with the same payload.'''
    return and_(f, returnA)


def mirror(x):
    '''Swap the tag of an Either value: Left(v) -> Right(v), else -> Left(v).'''
    if isinstance(x, Left):
        return Right(x.obj)
    return Left(x.obj)


# Arrow form of mirror.
mirrorA = arr(mirror)


def right(f):
    '''Apply *f* inside a Right; a Left is re-wrapped with the same payload.'''
    return mirrorA >> left(f) >> mirrorA
# ArrowApply type class functions
def app(pair):
    '''Apply a function to an argument supplied together as a pair.

    app((f, x)) == f(x). The pair is unpacked in the body because the
    Python-2-only ``lambda (f, x):`` tuple-parameter syntax was removed in
    Python 3 (PEP 3113). A pair without exactly two elements raises
    ValueError.
    '''
    f, x = pair
    return f(x)
def fun1(x):
    '''Sample function: add 1 to each of the first two items of *x*.'''
    head, following = x[0], x[1]
    return (head + 1, following + 1)


def fun2(x):
    '''Sample function: add 3.'''
    return x + 3


def fun3(x):
    '''Sample function: add 4.'''
    return x + 4
# Self-checks for the Arrow combinators. They run when the module is
# executed, and are skipped under python -O like any assert.
assert returnA(1) == 1                          # identity
assert arr(fun2)(1) == 4                        # 1 + 3
assert arr(fun3)(1) == 5                        # 1 + 4
assert (arr(fun2) >> arr(fun3))(1) == 8         # (1 + 3) + 4
assert first(fun2)((1, 1)) == (4, 1)            # only the first item changes
assert second(fun2)((1, 1)) == (1, 4)           # only the second item changes
assert split(fun2, fun3)((1, 1)) == (4, 5)      # (fun2(x), fun3(y))
assert fanout(fun2, fun3)(1) == (4, 5)          # (fun2(x), fun3(x))
# Arrow-wrapped versions of the sample functions.
fun1A = arr(fun1)
fun2A = arr(fun2)
fun3A = arr(fun3)

# Trace for input 0:
#   fun2A -> 3, fanout -> (6, 7), fun1A -> (7, 8), split -> (10, 12)
a = (
    fun2A
    >> fanout(fun2, fun3)
    >> fun1A
    >> split(fun2, fun3)
)
assert a(0) == (10, 12)
# Self-checks for the ArrowChoice combinators.
assert left(lambda x: x + 1)(Left(1)) == Left(2)        # Left is mapped
assert left(lambda x: x + 1)(Right(1)) == Right(1)      # Right passes through
assert right(lambda x: x + 1)(Left(1)) == Left(1)       # Left passes through
assert right(lambda x: x + 1)(Right(1)) == Right(2)     # Right is mapped
assert and_(lambda x: x + 1, lambda x: x + 2)(Left(1)) == Left(2)
assert and_(lambda x: x + 1, lambda x: x + 2)(Right(1)) == Right(3)
assert or_(lambda x: x + 1, lambda x: x + 2)(Left(1)) == 2     # unwrapped
assert or_(lambda x: x + 1, lambda x: x + 2)(Right(1)) == 3    # unwrapped
# Complex example: keep the even numbers, square them, then sum the squares.
square = arr(lambda xs: (x ** 2 for x in xs))
remove_odd = arr(lambda xs: (x for x in xs if x % 2 == 0))
# The builtin sum (start value 0) replaces reduce(), which is not a builtin
# on Python 3.
sum_ = arr(sum)
arrow = remove_odd >> square >> sum_
# A single-argument print(...) produces the same output on Python 2 and 3.
print('Sum of squares of even numbers from 1 to 5: %s' % arrow(range(1, 6)))
# FizzBuzz expressed with the ArrowChoice combinators. Each replace_* stage
# returns Left(n) to keep the number or Right(word) to replace it.
# or_(stage, Right) runs the next stage only on values still tagged Left,
# and the final or_(returnA, returnA) unwraps whichever side remains.
def replace_fizzbuzz(x):
    '''Return Right('fizzbuzz') for multiples of 15, else Left(x).'''
    return Left(x) if x % 15 != 0 else Right('fizzbuzz')


def replace_fizz(x):
    '''Return Right('fizz') for multiples of 3, else Left(x).'''
    return Left(x) if x % 3 != 0 else Right('fizz')


def replace_buzz(x):
    '''Return Right('buzz') for multiples of 5, else Left(x).'''
    return Left(x) if x % 5 != 0 else Right('buzz')


# replace_fizzbuzz is a plain function, so the first >> goes through
# arr.__rrshift__.
fizzbuzz = (replace_fizzbuzz
            >> or_(replace_fizz, Right)
            >> or_(replace_buzz, Right)
            >> or_(returnA, returnA))

# A single-argument print(...) and range() behave the same on Python 2 and 3.
# NOTE(review): range(100) covers 0..99, so the first line is "0 fizzbuzz";
# classic FizzBuzz counts 1..100. Confirm which range is intended.
print('\n'.join('%d %s' % (i, fizzbuzz(i)) for i in range(100)))
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.