Skip to content

Instantly share code, notes, and snippets.

@mpickering
Created August 30, 2024 09:03
Show Gist options
  • Save mpickering/8ce7a3017113d32d0342718e298d3029 to your computer and use it in GitHub Desktop.
Save mpickering/8ce7a3017113d32d0342718e298d3029 to your computer and use it in GitHub Desktop.
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE GADTs #-}
module Pat3 where
import qualified Language.Haskell.TH as TH (Code, Q,)
import qualified Language.Haskell.TH.Syntax as TH (Lift(..))
import Data.Functor.Identity
-- | Typed Template Haskell code in the 'TH.Q' monad; the representation used
-- by the staged matcher and the staged 'Lang' instance.
type Code = TH.Code TH.Q
-- | A typed pattern over values of type @input@.
--
-- @f@ is the type the right-hand side must have, and @res@ is what remains
-- once every variable bound by the pattern has been supplied to it. Each
-- 'VarPat' contributes one argument: for example, 'VarPat' `'ConsPat'` 'VarPat'
-- has @f ~ (a -> [a] -> r)@ and @res ~ r@. The @m@ parameter is the
-- representation of the literal carried by 'LitPat'.
data Pat m input f res where
  -- | Match an 'Int' equal to the given literal; binds no variables.
  LitPat :: m Int -> Pat m Int res res
  -- Do not use Code here: a variable pattern carries no @m@-value; it only
  -- passes the matched value on to the right-hand side as its next argument.
  VarPat :: Pat m a (a -> r) r
  -- | Match the empty list; binds no variables.
  NilPat :: Pat m [a] r r
  -- | Match a non-empty list: the head against the first pattern, then the
  -- tail against the second, threading the right-hand-side type through both.
  ConsPat :: Pat m a f g -> Pat m [a] g r -> Pat m [a] f r
-- | One alternative of a match: a pattern together with its right-hand side.
-- The pattern's @f@ index is existential, so cases that bind different
-- numbers of variables can share one list as long as they agree on @res@.
data Case m input res where
  Case ::
    -- Do not use Code here
    Pat m input f r
    -- Wrap RHS in m: the right-hand side is a function of the pattern's
    -- bound variables, represented in @m@.
    -> m f
    -> Case m input r
-- | The pattern @(x : xs)@, binding both the head and the tail.
v1 :: Pat Code [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat
-- | @(x : xs) -> Just (x, xs)@ as a staged case.
defn1 :: Case Code [a] (Maybe (a,[a]))
defn1 = Case v1 [|| \hd tl -> Just (hd, tl) ||]
-- | @[] -> Nothing@ as a staged case.
defn2 :: Case Code [a] (Maybe x)
defn2 = NilPat `Case` [|| Nothing ||]
-- | @(1 : xs) -> Just (1, xs)@ as a staged case.
defn3 :: Case Code [Int] (Maybe (Int, [Int]))
defn3 = Case (ConsPat (LitPat [|| 1 ||]) VarPat) [|| \rest -> Just (1, rest) ||]
-- | Compile a list of cases into staged code that matches the scrutinee
-- against each case in order.
--
-- The scrutinee is let-bound once in the generated code, and every case refers
-- to that binding. Calling 'match_loop' directly instead splices the scrutinee
-- expression into each case separately, so a non-trivial scrutinee would be
-- duplicated (and possibly recomputed) once per case. The binding is lazy, so
-- an unused scrutinee (e.g. no cases) is still never evaluated.
match :: Code a -> [Case Code a b] -> Code b
match scrut cases = [|| let s = $$scrut in $$(match_loop [|| s ||] cases) ||]
-- | Try each case in order, generating a nested @case@ on the result of
-- 'pat'. If no case matches, the generated code calls 'error'.
match_loop :: Code a -> [Case Code a b] -> Code b
match_loop scrut cases = case cases of
  [] -> [|| error "no matches" ||]
  Case p rhs : rest ->
    [|| case $$(pat scrut p rhs) of
          Just res -> res
          Nothing -> $$(match_loop scrut rest) ||]
-- | Stage a single pattern: the generated code yields 'Just' the right-hand
-- side (applied to every variable the pattern binds) when the scrutinee
-- matches, and 'Nothing' otherwise.
pat :: Code a -> Pat Code a f res -> Code f -> Code (Maybe res)
pat scrut (LitPat lit) rhs =
  [|| case $$scrut of
        v | v == $$lit -> Just $$rhs
        _ -> Nothing ||]
pat scrut VarPat rhs =
  [|| case $$scrut of
        v -> Just ($$rhs v) ||]
pat scrut NilPat rhs =
  [|| case $$scrut of
        [] -> Just $$rhs
        _ -> Nothing ||]
pat scrut (ConsPat headPat tailPat) rhs =
  [|| case $$scrut of
        (h : t) ->
          case $$(pat [|| h ||] headPat rhs) of
            Just mid -> $$(pat [|| t ||] tailPat [|| mid ||])
            Nothing -> Nothing
        [] -> Nothing ||]
-- A staged and unstaged interpreter: 'Lang' abstracts the constructs that the
-- generic matcher ('matchL', 'matchL2') emits, so the same matcher can run
-- directly (the 'Identity' instance) or generate code (the 'Code' and
-- 'Code2' instances).
class Lang l where
  -- | @_if c t e@: @t@ when @c@ is true, otherwise @e@.
  _if :: l Bool -> l a -> l a -> l a
  -- | Equality on 'Int's.
  _eq :: l Int -> l Int -> l Bool
  -- | Embed an 'Int' literal.
  _int :: Int -> l Int
  -- | @_maybe m j n@: apply @j@ to the payload of a 'Just', or yield @n@.
  _maybe :: l (Maybe a) -> l (a -> b) -> l b -> l b
  -- | @_list xs cons nil@: apply @cons@ to the head and tail of a non-empty
  -- list, or yield @nil@ for the empty list.
  _list :: l [a] -> l (a -> [a] -> r) -> l r -> l r
  -- | Build a pair.
  _tup :: l a -> l b -> l (a, b)
  -- | List cons.
  _cons :: l a -> l [a] -> l [a]
  -- | The empty list.
  _nil :: l [a]
  -- | Wrap a value in 'Just'.
  _just :: l a -> l (Maybe a)
  -- | 'Nothing'.
  _nothing :: l (Maybe a)
  -- | Function application.
  _app :: l (a -> b) -> l a -> l b
  -- | Turn a host-language function on representations into a
  -- target-language function.
  _lam :: (l a -> l b) -> l (a -> b)
  -- | Fail by calling 'error' with the given message.
  _error :: String -> l a
  -- | @_let x k@: bind @x@ and pass it to @k@. The code-generating instances
  -- emit a @let@, so @x@ is shared rather than copied.
  _let :: l a -> l (a -> b) -> l b
-- | Direct (unstaged) interpretation: terms are ordinary Haskell values
-- wrapped in 'Identity'. Used by 'unstaged' and 'unstagedCPS'.
instance Lang Identity where
  _if (Identity cond) onTrue onFalse = if cond then onTrue else onFalse
  _eq (Identity x) (Identity y) = Identity (x == y)
  _int = Identity
  _maybe (Identity mx) (Identity onJust) onNothing =
    case mx of
      Nothing -> onNothing
      Just x -> Identity (onJust x)
  _list (Identity xs) (Identity onCons) onNil =
    case xs of
      [] -> onNil
      y : ys -> Identity (onCons y ys)
  _tup (Identity x) (Identity y) = Identity (x, y)
  _cons (Identity x) (Identity xs) = Identity (x : xs)
  _nil = Identity []
  _just (Identity x) = Identity (Just x)
  _nothing = Identity Nothing
  _app (Identity fun) (Identity arg) = Identity (fun arg)
  _lam body = Identity (runIdentity . body . Identity)
  _error msg = Identity (error msg)
  _let (Identity bound) (Identity k) = Identity (k bound)
-- | Staged interpretation: every construct produces typed Template Haskell
-- code. Applications are emitted as-is, so applying a '_lam' leaves a real
-- lambda application in the generated code.
instance Lang Code where
  _if cond onTrue onFalse = [|| if $$cond then $$onTrue else $$onFalse ||]
  _eq x y = [|| $$x == $$y ||]
  _int = TH.liftTyped
  _maybe scrut onJust onNothing =
    [|| case $$scrut of
          Just v -> $$onJust v
          Nothing -> $$onNothing ||]
  _list scrut onCons onNil =
    [|| case $$scrut of
          (h : t) -> $$onCons h t
          [] -> $$onNil ||]
  _tup x y = [|| ($$x, $$y) ||]
  _cons x xs = [|| $$x : $$xs ||]
  _nil = [|| [] ||]
  _just x = [|| Just $$x ||]
  _nothing = [|| Nothing ||]
  _app fun arg = [|| $$fun $$arg ||]
  _lam body = [|| \v -> $$(body [|| v ||]) ||]
  _error msg = [|| error $$(TH.liftTyped msg) ||]
  _let rhs body = [|| let bound = $$rhs in $$body bound ||]
-- Functions in the target language are represented by functions in the
-- host language, so that applying a known function (see 'a') can be reduced
-- while generating code instead of appearing in the generated program.
data Code2 t where
  -- | An arbitrary piece of generated code.
  C :: Code t -> Code2 t
  -- | A target-language function kept as a host-language function.
  F :: (Code2 t -> Code2 r) -> Code2 (t -> r)
-- Interpretation function: turn a 'Code2' back into plain generated code.
-- A host-level function ('F') becomes a generated lambda whose body is the
-- host function applied to the lambda's bound variable.
i :: Code2 t -> Code t
i code2 = case code2 of
  C plain -> plain
  F host -> [|| \v -> $$(i (host (C [|| v ||]))) ||]
-- Application. A host-level function ('F') is applied immediately, which
-- beta-reduces at generation time; otherwise an application is emitted into
-- the generated code.
a :: Code2 (a -> b) -> Code2 a -> Code2 b
a fun arg = case fun of
  F host -> host arg
  C code -> C [|| $$code $$(i arg) ||]
-- | Apply a curried two-argument 'Code2' function to both arguments.
a2 :: Code2 (a -> b -> c) -> Code2 a -> Code2 b -> Code2 c
a2 fun x = a (a fun x)
-- | Staged interpretation with beta-reduction: '_lam' keeps the host function
-- ('F') and '_app' is 'a', so applying a '_lam' is reduced during code
-- generation instead of being emitted.
instance Lang Code2 where
  _if cond onTrue onFalse =
    C [|| if $$(i cond) then $$(i onTrue) else $$(i onFalse) ||]
  _eq x y = C [|| $$(i x) == $$(i y) ||]
  _int n = C (TH.liftTyped n)
  _maybe scrut onJust onNothing =
    C [|| case $$(i scrut) of
            Just v -> $$(i (a onJust (C [|| v ||])))
            Nothing -> $$(i onNothing) ||]
  _list scrut onCons onNil =
    C [|| case $$(i scrut) of
            (h : t) -> $$(i (a2 onCons (C [|| h ||]) (C [|| t ||])))
            [] -> $$(i onNil) ||]
  _tup x y = C [|| ($$(i x), $$(i y)) ||]
  _cons x xs = C [|| $$(i x) : $$(i xs) ||]
  _nil = C [|| [] ||]
  _just x = C [|| Just $$(i x) ||]
  _nothing = C [|| Nothing ||]
  _app = a
  _lam = F
  _error msg = C [|| error $$(TH.liftTyped msg) ||]
  _let rhs body =
    C [|| let bound = $$(i rhs)
          in $$(i (a body (C [|| bound ||]))) ||]
-- Testing that nesting lambdas works: the target-language term
-- @(\x y -> x == y) 0 1@.
test_reduce :: Lang l => l Bool
test_reduce = _app (_app (_lam (\x -> _lam (_eq x))) (_int 0)) (_int 1)
-- | The pattern @(x : xs)@ for any representation @l@.
v1L :: Pat l [a] (a -> [a] -> r) r
v1L = ConsPat VarPat VarPat
-- | @(x : xs) -> Just (x, xs)@ written against 'Lang'.
defn1L :: Lang l => Case l [a] (Maybe (a,[a]))
defn1L = Case v1L (_lam (\hd -> _lam (_just . _tup hd)))
-- | @[] -> Nothing@ written against 'Lang'.
defn2L :: Lang l => Case l [a] (Maybe x)
defn2L = NilPat `Case` _nothing
-- | @(1 : xs) -> Just (1, xs)@ written against 'Lang'.
defn3L :: Lang l => Case l [Int] (Maybe (Int, [Int]))
defn3L = Case (ConsPat (LitPat (_int 1)) VarPat) (_lam (_just . _tup (_int 1)))
-- | Match a scrutinee against the cases in order, for any 'Lang'. Each case
-- yields an @l (Maybe b)@ via 'patL'. On 'Nothing' the remaining cases are
-- tried; when none is left the result is an '_error'.
matchL :: Lang l => l a -> [Case l a b] -> l b
matchL scrut = go
  where
    go [] = _error "no matches"
    go (Case p rhs : rest) = _maybe (patL scrut p rhs) (_lam id) (go rest)
-- | 'Lang'-generic version of 'pat': yields '_just' the right-hand side
-- (applied to every bound variable) on a match, '_nothing' otherwise.
patL :: Lang l => l a -> Pat l a f res -> l f -> l (Maybe res)
patL scrut (LitPat lit) rhs = _if (_eq scrut lit) (_just rhs) _nothing
patL scrut VarPat rhs = _just (_app rhs scrut)
patL scrut NilPat rhs =
  _list scrut (_lam (\_ -> _lam (const _nothing))) (_just rhs)
patL scrut (ConsPat headPat tailPat) rhs =
  _list
    scrut
    (_lam (\h -> _lam (\t ->
      _maybe (patL h headPat rhs) (_lam (patL t tailPat)) _nothing)))
    _nothing
-- Defining interpreters

-- | Run the generic matcher directly on a Haskell value.
unstaged :: a -> [Case Identity a b] -> b
unstaged x = runIdentity . matchL (Identity x)
-- | The generic matcher instantiated at 'Code'.
staged :: Code a -> [Case Code a b] -> Code b
staged scrut cases = matchL scrut cases
-- | The generic matcher instantiated at 'Code2', then converted back to
-- plain 'Code' with 'i'.
staged_beta :: Code a -> [Case Code2 a b] -> Code b
staged_beta scrut cases = i (matchL (C scrut) cases)
-- With failure continuation: each case is compiled by 'patL2' with an explicit
-- failure continuation instead of producing a 'Maybe'. The fall-through (the
-- remaining cases) is bound once with '_let' and that binding is passed as the
-- failure continuation.
matchL2 :: Lang l => l a -> [Case l a b] -> l b
matchL2 scrut = go
  where
    go [] = _error "no matches"
    go (Case p rhs : rest) = _let (go rest) (_lam (patL2 scrut p rhs id))
-- | Continuation-passing version of 'patL': @onMatch@ receives the right-hand
-- side applied to every bound variable; @onFail@ is used on every path that
-- does not match.
patL2 :: Lang l => l a -> Pat l a f r1 -> l f -> (l r1 -> l res) -> l res -> l res
patL2 scrut (LitPat lit) rhs onMatch onFail =
  _if (_eq scrut lit) (onMatch rhs) onFail
patL2 scrut VarPat rhs onMatch _ = onMatch (_app rhs scrut)
patL2 scrut NilPat rhs onMatch onFail =
  _list scrut (_lam (\_ -> _lam (const onFail))) (onMatch rhs)
patL2 scrut (ConsPat headPat tailPat) rhs onMatch onFail =
  _list
    scrut
    (_lam (\h -> _lam (\t ->
      patL2 h headPat rhs (\mid -> patL2 t tailPat mid onMatch onFail) onFail)))
    onFail
-- | Run the continuation-passing matcher directly on a Haskell value.
unstagedCPS :: a -> [Case Identity a b] -> b
unstagedCPS x = runIdentity . matchL2 (Identity x)
-- | The continuation-passing matcher instantiated at 'Code'.
stagedCPS :: Code a -> [Case Code a b] -> Code b
stagedCPS scrut cases = matchL2 scrut cases
-- | The continuation-passing matcher instantiated at 'Code2', then converted
-- back to plain 'Code' with 'i'.
staged_betaCPS :: Code a -> [Case Code2 a b] -> Code b
staged_betaCPS scrut cases = i (matchL2 (C scrut) cases)
-- Exercise: define a three-stage program in which the first stage specialises
-- the interpreter based on the language (without using type classes).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment