-- Playing around with recursion schemes.
--
-- Originally a GitHub gist by @DarinM223
-- (gist fdc22e19f17d4bc2c7440d41697466ba), created January 22, 2021.
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module Main where
import Data.Foldable (fold)
import Data.Functor.Foldable (cata, embed, hylo)
import Data.Functor.Foldable.TH (MakeBaseFunctor (makeBaseFunctor))
import Data.Monoid (Sum (Sum, getSum))
-- | A small arithmetic expression language over literals of type @a@.
--
-- @e ':*' es@ pairs a head expression with a list of further expressions;
-- see 'evalExpr' for how it is interpreted.
data Expr a
  = Lit a                    -- ^ A literal leaf.
  | Add (Expr a) (Expr a)    -- ^ Addition.
  | Sub (Expr a) (Expr a)    -- ^ Subtraction.
  | Mult (Expr a) (Expr a)   -- ^ Multiplication.
  | Divide (Expr a) (Expr a) -- ^ Division.
  | Expr a :* [Expr a]       -- ^ Head expression with a list of operands.
  deriving (Show)

-- Generate the base functor 'ExprF' (constructors 'LitF', 'AddF', 'SubF',
-- 'MultF', 'DivideF' and ':*$') together with the instances that let
-- 'cata' fold an 'Expr' and 'embed' rebuild one from a single layer.
makeBaseFunctor ''Expr
-- | Count the 'Lit' leaves of an expression using a catamorphism.
--
-- Only the 'LitF' layer needs its own case: it contributes one. Every other
-- layer already carries the counts of its children, and 'fold' (through the
-- base functor's 'Foldable' instance) just adds them up. New recursive
-- constructors therefore need no extra equations here.
countLits :: Expr a -> Int
countLits = getSum . cata count
  where
    count layer = case layer of
      LitF _ -> Sum 1
      _ -> fold layer
-- | For comparison: 'countLits' written with explicit structural recursion
-- instead of a catamorphism.
--
-- Every constructor has to be handled by hand, even though all the binary
-- ones do exactly the same thing. This is the duplication that 'cata'
-- removes.
countLits' :: Expr a -> Int
countLits' expr =
  case expr of
    Lit _ -> 1
    Add l r -> both l r
    Sub l r -> both l r
    Mult l r -> both l r
    Divide l r -> both l r
    hd :* rest -> countLits' hd + sum (map countLits' rest)
  where
    -- Shared case for every binary constructor.
    both l r = countLits' l + countLits' r
-- | Evaluate an integer expression bottom-up with a catamorphism.
--
-- Each layer arrives with its children already reduced to 'Int's, so the
-- algebra only says how to combine them.
--
-- * 'Divide' uses 'quot', which truncates toward zero and throws on a zero
--   divisor.
-- * @e ':*' es@ evaluates to @e@ multiplied by the sum of @es@.
evalExpr :: Expr Int -> Int
evalExpr = cata $ \layer -> case layer of
  LitF n -> n
  AddF x y -> x + y
  SubF x y -> x - y
  MultF x y -> x * y
  DivideF x y -> x `quot` y
  hd :*$ rest -> hd * sum rest
-- | Turn every 'Divide' node into a 'Sub' node with the same operands and
-- leave all other constructors as they are.
--
-- The catamorphism handles children before their parent, so divisions at
-- any depth are rewritten. 'embed' puts back every layer that is not a
-- division without changing it.
replaceDividesWithSubs :: Expr a -> Expr a
replaceDividesWithSubs = cata $ \layer -> case layer of
  DivideF lhs rhs -> Sub lhs rhs
  _ -> embed layer
-- | A binary tree whose shape mirrors the recursive calls of a naive
-- Fibonacci: 'Num' is a leaf holding a value, and 'Rec' joins two subtrees.
--
-- 'fib' never builds a 'Fib' value itself. It only uses the generated base
-- functor 'FibF' (constructors 'NumF' and 'RecF').
data Fib a
  = Num a
  | Rec (Fib a) (Fib a)
  deriving (Show)

makeBaseFunctor ''Fib
-- | The @n@-th Fibonacci number, computed as a hylomorphism.
--
-- The coalgebra @build@ unfolds an argument into one layer of recursive
-- calls (@k - 1@ and @k - 2@), bottoming out at 0 and 1. The algebra @drain@
-- folds each layer back up by summing. 'hylo' fuses the unfold and the
-- fold, but the recursion tree is still exponential in @n@. This is a
-- demonstration, not an efficient Fibonacci.
--
-- Negative arguments use the negafibonacci identity
-- @fib (-k) = (-1)^(k+1) * fib k@. Without this case @build@ would keep
-- counting down past the base cases and never terminate.
fib :: Int -> Int
fib n
  | n < 0 = sign * fib (negate n)
  | otherwise = hylo drain build n
  where
    -- F(-k) is negative exactly when k is even (and n = -k has the same
    -- parity as k).
    sign = if even n then -1 else 1
    -- Anamorphism step: expand a value into one layer of sub-problems.
    build 0 = NumF 0
    build 1 = NumF 1
    build k = RecF (k - 1) (k - 2)
    -- Catamorphism step: collapse one layer of results into a value.
    drain (NumF k) = k
    drain (RecF a b) = a + b
-- | Run every example on one composite expression and print the results.
main :: IO ()
main = do
  -- Named 'expr' rather than 'exp' so the binding does not shadow
  -- 'Prelude.exp' (a -Wname-shadowing warning under -Wall).
  let expr = Add (Sub (Lit 1) (Lit 3)) (Mult (Divide (Lit 5) (Lit 7)) (Lit 10))
      extras = [Sub (Lit 1) (Lit 1), Add (Divide (Lit 3) (Lit 3)) (Lit 5), Lit 10]
      combined = expr :* extras
  print $ countLits combined -- 11
  print $ countLits' combined -- 11, same as the catamorphism version
  print $ evalExpr combined -- (1 - 3 + 0 * 10) * (0 + 6 + 10) = -32
  print $ replaceDividesWithSubs combined
  print $ fib 20 -- 6765
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment