-
-
Save epicallan/6ebcfce1796717d3882ce4e3f58c4ca2 to your computer and use it in GitHub Desktop.
Continuation Passing Style in Haskell
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE InstanceSigs #-} | |
import Control.Applicative | |
import Control.Monad | |
import Control.Monad.Trans.Writer | |
-- | Direct-style Pythagoras: the hypotenuse length for legs @x@ and @y@.
--
-- Two things are left implicit in the body:
--
-- 1. The evaluation order of the two squares is not spelled out.
-- 2. What happens to the final value (its continuation) is not spelled out.
pyth :: (Floating a) => a -> a -> a
pyth x y = sqrt (square x + square y)
  where
    square v = v * v
-- Let us now write Pythagoras in continuation-passing style (CPS).
-- First, we need CPS versions of (+), (*), and sqrt.
-- | CPS addition: computes @x + y@ and hands the sum to the continuation @k@,
-- returning whatever @k@ returns.
addCC :: (Floating a) => a -> a -> (a -> r) -> r
addCC x y k = k total
  where
    total = x + y
-- | CPS multiplication: computes @x * y@ and hands the product to the
-- continuation, returning whatever the continuation returns.
multCC :: (Floating a) => a -> a -> (a -> r) -> r
multCC x y = \k -> k (x * y)
-- | CPS square root: computes @sqrt y@ and hands it to the continuation,
-- returning whatever the continuation returns.
sqrtCC :: (Floating a) => a -> (a -> r) -> r
sqrtCC y = ($ sqrt y)
-- | CPS Pythagoras. The order of the steps is now written out explicitly:
-- square @x@, square @y@, add the two squares, then take the square root and
-- hand the result to the final continuation @k@.
pythCC :: (Floating a) => a -> a -> (a -> r) -> r
pythCC x y k =
  multCC x x $ \xSquared ->
    multCC y y $ \ySquared ->
      addCC xSquared ySquared $ \total ->
        sqrtCC total k
{-
Notice the repeating pattern (a -> r) -> r in the signatures above.
We can abstract it with a new type, Cont.
Cont r a wraps a function that passes an intermediate value of type a to
a continuation and returns a final value of type r.
-}
newtype Cont r a = Cont
  { runCont :: (a -> r) -> r
    -- ^ Run the computation by supplying the continuation that receives
    -- the intermediate value.
  }
-- | 'Cont' @r@ is a 'Functor': the mapped function is applied to the
-- intermediate value just before it reaches the continuation.
instance Functor (Cont r) where
  fmap :: (a -> b) -> Cont r a -> Cont r b
  fmap f (Cont run) = Cont (\k -> run (\a -> k (f a)))
-- | 'Cont' @r@ is an 'Applicative'.
instance Applicative (Cont r) where
  -- | Pass the given value straight to the continuation.
  pure :: a -> Cont r a
  pure a = Cont ($ a)

  -- | Run the left computation to obtain a function, then the right one to
  -- obtain its argument, and pass the application on to the continuation.
  (<*>) :: Cont r (a -> b) -> Cont r a -> Cont r b
  Cont runF <*> Cont runA = Cont $ \k -> runF $ \f -> runA $ \a -> k (f a)
-- | 'Cont' @r@ is a 'Monad'.
instance Monad (Cont r) where
  return :: a -> Cont r a
  return = pure

  -- | Run the first computation; its intermediate value selects the next
  -- computation, which is then run with the outer continuation.
  (>>=) :: Cont r a -> (a -> Cont r b) -> Cont r b
  Cont run >>= f = Cont $ \k -> run $ \a -> runCont (f a) k
-- | Addition in 'Cont': the sum @x + y@ is passed to whatever continuation
-- the resulting computation is run with.
(+&) :: (Floating a) => a -> a -> Cont r a
x +& y = pure (x + y)
-- | Multiplication in 'Cont': the product @x * y@ is passed to whatever
-- continuation the resulting computation is run with.
(*&) :: (Floating a) => a -> a -> Cont r a
x *& y = pure (x * y)
-- | Square root in 'Cont': @sqrt x@ is passed to whatever continuation the
-- resulting computation is run with.
sqrtCC2 :: (Floating a) => a -> Cont r a
sqrtCC2 = pure . sqrt
-- | Pythagoras written with the 'Cont' monad: squares both legs, adds the
-- squares, and takes the square root of the sum.
--
-- The sum is taken over the squared values @x2@ and @y2@. Adding the raw
-- inputs @x@ and @y@ instead would yield @sqrt (x + y)@ and leave both
-- squares unused.
pythCC2 :: (Floating a) => a -> a -> Cont r a
pythCC2 x y = do
  x2 <- x *& x
  y2 <- y *& y
  sumOfSquares <- x2 +& y2
  sqrtCC2 sumOfSquares
-- | Wrap a raw CPS function @(a -> r) -> r@ as a 'Cont' computation.
--
-- NOTE(review): this is just the 'Cont' constructor under another name. The
-- @callCC@ of Control.Monad.Trans.Cont takes a function whose argument is an
-- escape continuation that itself returns a 'Cont' computation, which this
-- signature does not match — confirm whether that was the intent.
callCC :: ((a -> r) -> r) -> Cont r a
callCC f = Cont f
-- | A binary tree: either an empty 'Leaf', or a 'Node' holding a value
-- together with a left and a right subtree.
data Tree a
  = Leaf
  | Node a (Tree a) (Tree a)
  deriving (Show, Eq)
-- | Check whether a tree is height-balanced, as a 'Cont' computation.
--
-- The result is a pair @(balanced, height)@:
--
-- * A 'Leaf' is balanced and has height 0.
-- * A 'Node' is balanced when both subtrees are balanced and their heights
--   differ by at most 1; its height is one more than the taller subtree.
--
-- If the left subtree is already unbalanced, the right subtree is not
-- examined, and the reported 'Int' is the left subtree's height rather than
-- the height of the whole node.
--
-- The element values are never inspected, so no constraint on @a@ is needed.
isBalanced :: Tree a -> Cont r (Bool, Int)
isBalanced Leaf = return (True, 0)
isBalanced (Node _ l r) = do
  (lb, lh) <- isBalanced l
  if lb
    then do
      (rb, rh) <- isBalanced r
      return (rb && abs (lh - rh) <= 1, 1 + max lh rh)
    else return (False, lh)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment