Created August 30, 2024 09:03
Save mpickering/8ce7a3017113d32d0342718e298d3029 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskellQuotes #-} | |
{-# LANGUAGE GADTs #-} | |
module Pat3 where | |
import qualified Language.Haskell.TH as TH (Code, Q,) | |
import qualified Language.Haskell.TH.Syntax as TH (Lift(..)) | |
import Data.Functor.Identity | |
-- | Typed Template Haskell code in the 'TH.Q' monad: the representation of
-- generated (staged) expressions used by 'match', 'pat' and the
-- @Lang Code@ instance.
type Code = TH.Code TH.Q
-- | A pattern over values of type @input@, indexed by the right-hand side it
-- feeds.
--
-- Read @Pat m input f res@ as: matching @input@ consumes an RHS of type @f@
-- and, on success, produces a @res@.  Each variable the pattern binds adds
-- one argument to @f@ (see 'VarPat'); patterns that bind nothing leave the
-- RHS type unchanged (@f ~ res@).
--
-- @m@ is the representation of embedded expressions.  Only literals carry
-- an @m@-value; the pattern structure itself is deliberately not wrapped in
-- @m@ (do not use 'Code' here), so matchers can inspect it directly.
data Pat m input f res where
  -- | Match an 'Int' equal to the given (represented) literal.
  LitPat :: m Int -> Pat m Int r r
  -- | Bind the scrutinee: the RHS receives it as its next argument.
  VarPat :: Pat m x (x -> r) r
  -- | Match the empty list.
  NilPat :: Pat m [x] r r
  -- | Match a non-empty list.  The head pattern feeds the RHS first; the
  -- tail pattern then consumes whatever the head pattern left over.
  ConsPat :: Pat m x f g -> Pat m [x] g r -> Pat m [x] f r
-- | One alternative of a match: a pattern together with its right-hand side.
--
-- The pattern is plain data (not wrapped in @m@).  The RHS is wrapped in
-- @m@: it is a represented function taking one argument per variable bound
-- by the pattern and returning @res@.  The RHS type @f@ is existential.
data Case m input res where
  Case :: Pat m input f res -> m f -> Case m input res
-- | The pattern @(x : xs)@, binding both the head and the tail.
v1 :: Pat Code [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat
-- | Staged alternative @(a1 : a2) -> Just (a1, a2)@.
defn1 :: Case Code [a] (Maybe (a,[a]))
defn1 = Case v1 [|| \hd tl -> Just (hd, tl) ||]
-- | Staged alternative @[] -> Nothing@.
defn2 :: Case Code [a] (Maybe x)
defn2 = Case NilPat [|| Nothing ||]
-- | Staged alternative @(1 : a2) -> Just (1, a2)@.
defn3 :: Case Code [Int] (Maybe (Int, [Int]))
defn3 = Case (ConsPat (LitPat [|| 1 ||]) VarPat) [|| \rest -> Just (1, rest) ||]
-- | Generate code that tries each alternative in order and returns the
-- result of the first one whose pattern matches.  If none matches, the
-- generated code calls 'error' with "no matches".
--
-- The scrutinee is bound once with a @let@, and every alternative refers to
-- that binding.  'match_loop' splices its scrutinee argument into each
-- alternative unchanged.  Passing the caller's expression straight through
-- would therefore duplicate it once per alternative, and a non-trivial
-- scrutinee would be re-evaluated for every case that fails.
match :: Code a -> [Case Code a b] -> Code b
match scrut cases =
  [|| let s = $$scrut in $$(match_loop [|| s ||] cases) ||]
-- | Chain the alternatives.  Each alternative is tried via 'pat'; on
-- 'Nothing', control falls through to the code generated for the remaining
-- ones.  After the last alternative the generated code calls
-- @error "no matches"@.  The scrutinee code is spliced into every
-- alternative as-is.
match_loop :: Code a -> [Case Code a b] -> Code b
match_loop scrut cases = case cases of
  [] -> [|| error "no matches" ||]
  Case pt rhs : rest ->
    [|| case $$(pat scrut pt rhs) of
          Just res -> res
          Nothing  -> $$(match_loop scrut rest)
      ||]
-- | Generate code that matches the scrutinee against a single pattern.
--
-- On success the generated expression is @Just@ the RHS applied to every
-- variable the pattern binds; on failure it is @Nothing@.  The pattern is
-- inspected here, at generation time, one equation per constructor.
pat :: Code a -> Pat Code a f res -> Code f -> Code (Maybe res)
pat scrut (LitPat lit) rhs =
  [|| case $$scrut of
        n | n == $$lit -> Just $$rhs
        _              -> Nothing
    ||]
pat scrut VarPat rhs =
  [|| case $$scrut of
        v -> Just ($$rhs v)
    ||]
pat scrut NilPat rhs =
  [|| case $$scrut of
        [] -> Just $$rhs
        _  -> Nothing
    ||]
pat scrut (ConsPat hdPat tlPat) rhs =
  [|| case $$scrut of
        (hd : tl) ->
          -- The head pattern partially applies the RHS; the tail pattern
          -- finishes the job with whatever is left.
          case $$(pat [|| hd ||] hdPat rhs) of
            Just partial -> $$(pat [|| tl ||] tlPat [|| partial ||])
            Nothing      -> Nothing
        [] -> Nothing
    ||]
-- A staged and unstaged interpreter

-- | A small object language for writing the matcher once and interpreting it
-- in several ways.  'Identity' runs it directly (unstaged); 'Code' and
-- 'Code2' produce Template Haskell code (staged).
--
-- Object-language functions are built from host-language functions in
-- '_lam' (higher-order abstract syntax).
class Lang l where
  -- Literals and primitive operations
  _int :: Int -> l Int
  _eq :: l Int -> l Int -> l Bool
  _if :: l Bool -> l x -> l x -> l x
  _error :: String -> l x

  -- Tuples
  _tup :: l x -> l y -> l (x, y)

  -- Lists: constructors, and an eliminator taking the cons case then the
  -- nil case
  _nil :: l [x]
  _cons :: l x -> l [x] -> l [x]
  _list :: l [x] -> l (x -> [x] -> r) -> l r -> l r

  -- Maybe: constructors, and an eliminator taking the Just case then the
  -- Nothing case
  _nothing :: l (Maybe x)
  _just :: l x -> l (Maybe x)
  _maybe :: l (Maybe x) -> l (x -> y) -> l y -> l y

  -- Functions and let-binding
  _lam :: (l x -> l y) -> l (x -> y)
  _app :: l (x -> y) -> l x -> l y
  _let :: l x -> l (x -> y) -> l y
-- | Unstaged interpretation: each construct is interpreted directly as the
-- corresponding host-language value.
instance Lang Identity where
  _int = Identity
  _eq (Identity x) (Identity y) = Identity (x == y)
  _if (Identity c) t e = if c then t else e
  _error msg = Identity (error msg)
  _tup (Identity x) (Identity y) = Identity (x, y)
  _nil = Identity []
  _cons (Identity x) (Identity xs) = Identity (x : xs)
  _list (Identity scrut) (Identity onCons) onNil = case scrut of
    []       -> onNil
    (x : xs) -> Identity (onCons x xs)
  _nothing = Identity Nothing
  _just (Identity x) = Identity (Just x)
  _maybe (Identity scrut) (Identity onJust) onNothing =
    maybe onNothing (Identity . onJust) scrut
  _lam f = Identity (runIdentity . f . Identity)
  _app (Identity f) (Identity x) = Identity (f x)
  _let (Identity x) (Identity k) = Identity (k x)
-- | Staged interpretation: each construct produces the corresponding Typed
-- Template Haskell expression.  '_lam' emits a real lambda, so '_app' of a
-- '_lam' leaves a beta-redex in the generated code.
instance Lang Code where
  _int = TH.liftTyped
  _eq lhs rhs = [|| $$lhs == $$rhs ||]
  _if c t e = [|| if $$c then $$t else $$e ||]
  _error msg = [|| error $$(TH.liftTyped msg) ||]
  _tup x y = [|| ($$x, $$y) ||]
  _nil = [|| [] ||]
  _cons x xs = [|| $$x : $$xs ||]
  _list scrut onCons onNil =
    [|| case $$scrut of
          (hd : tl) -> $$onCons hd tl
          []        -> $$onNil
      ||]
  _nothing = [|| Nothing ||]
  _just x = [|| Just $$x ||]
  _maybe scrut onJust onNothing =
    [|| case $$scrut of
          Just v  -> $$onJust v
          Nothing -> $$onNothing
      ||]
  _lam body = [|| \v -> $$(body [|| v ||]) ||]
  _app f x = [|| $$f $$x ||]
  _let bound k = [|| let j = $$bound in $$k j ||]
-- | Staged representation in which functions of the target language are
-- represented by functions of the host language.
--
-- A value is either ordinary generated code ('C') or a host-level function
-- ('F').  Applying an 'F' (see 'a') simply calls the host function, so the
-- application is reduced during generation instead of being emitted.
data Code2 t where
  C :: Code t -> Code2 t
  F :: (Code2 x -> Code2 y) -> Code2 (x -> y)
-- | Interpretation function: turn a 'Code2' value into plain generated code.
-- A host-level function ('F') becomes a real lambda by applying it to a
-- fresh variable.
i :: Code2 t -> Code t
i code = case code of
  C c -> c
  F fun -> [|| \arg -> $$(i (fun (C [|| arg ||]))) ||]
-- | Application in 'Code2'.  A function known at generation time ('F') is
-- applied immediately (beta-reduction).  Otherwise an application node is
-- generated.
a :: Code2 (x -> y) -> Code2 x -> Code2 y
a fn arg = case fn of
  F host -> host arg
  C c -> C [|| $$c $$(i arg) ||]
-- | Two-argument application in 'Code2': two successive uses of 'a'.
a2 :: Code2 (a -> b -> c) -> Code2 a -> Code2 b -> Code2 c
a2 fn x = a (a fn x)
-- | Staged interpretation with generation-time beta-reduction.  The
-- eliminators ('_maybe', '_list', '_let') apply their continuation with 'a'.
-- When that continuation was built by '_lam' (i.e. is an 'F'), its body is
-- inlined instead of being emitted as a lambda.
instance Lang Code2 where
  -- named 'n' so the top-level interpreter 'i' is not shadowed
  _int n = C (TH.liftTyped n)
  _eq lhs rhs = C [|| $$(i lhs) == $$(i rhs) ||]
  _if c t e = C [|| if $$(i c) then $$(i t) else $$(i e) ||]
  _error msg = C [|| error $$(TH.liftTyped msg) ||]
  _tup x y = C [|| ($$(i x), $$(i y)) ||]
  _nil = C [|| [] ||]
  _cons x xs = C [|| $$(i x) : $$(i xs) ||]
  _list scrut onCons onNil =
    C [|| case $$(i scrut) of
            (hd : tl) -> $$(i (a2 onCons (C [|| hd ||]) (C [|| tl ||])))
            []        -> $$(i onNil)
        ||]
  _nothing = C [|| Nothing ||]
  _just x = C [|| Just $$(i x) ||]
  _maybe scrut onJust onNothing =
    C [|| case $$(i scrut) of
            Just v  -> $$(i (a onJust (C [|| v ||])))
            Nothing -> $$(i onNothing)
        ||]
  _lam = F
  _app = a
  _let bound k = C [|| let j = $$(i bound) in $$(i (a k (C [|| j ||]))) ||]
-- | Tests that nesting lambdas works: @(\x y -> x == y) 0 1@, which is
-- 'False'.  Under 'Code2' both applications are reduced during generation,
-- leaving only the comparison.
test_reduce :: Lang l => l Bool
test_reduce = _app (_app (_lam (\x -> _lam (_eq x))) (_int 0)) (_int 1)
-- | The pattern @(x : xs)@ for any representation @l@.
v1L :: Pat l [a] (a -> [a] -> r) r
v1L = ConsPat VarPat VarPat
-- | Alternative @(x1 : x2) -> Just (x1, x2)@ in the 'Lang' object language.
defn1L :: Lang l => Case l [a] (Maybe (a,[a]))
defn1L = Case v1L (_lam (\hd -> _lam (_just . _tup hd)))
-- | Alternative @[] -> Nothing@ in the 'Lang' object language.  'NilPat'
-- binds nothing, so the RHS is the result itself rather than a function.
defn2L :: Lang l => Case l [a] (Maybe x)
defn2L = Case NilPat _nothing
-- | Alternative @(1 : a2) -> Just (1, a2)@ in the 'Lang' object language.
defn3L :: Lang l => Case l [Int] (Maybe (Int, [Int]))
defn3L = Case (ConsPat (LitPat (_int 1)) VarPat) (_lam (_just . _tup (_int 1)))
-- | Try each alternative in order and fall through to the next on failure.
-- After the last one, produce @_error "no matches"@.  Every alternative
-- reuses the scrutinee as given.
matchL :: Lang l => l a -> [Case l a b] -> l b
matchL scrut cases = case cases of
  [] -> _error "no matches"
  Case pt rhs : rest ->
    _maybe (patL scrut pt rhs) (_lam id) (matchL scrut rest)
-- | Match one pattern against the scrutinee in the object language.  The
-- result is @Just@ the RHS applied to the bound variables, or @Nothing@.
patL :: Lang l => l a -> Pat l a f res -> l f -> l (Maybe res)
patL scrut (LitPat lit) rhs = _if (_eq scrut lit) (_just rhs) _nothing
patL scrut VarPat rhs = _just (_app rhs scrut)
patL scrut NilPat rhs =
  _list scrut (_lam (\_ -> _lam (\_ -> _nothing))) (_just rhs)
patL scrut (ConsPat hdPat tlPat) rhs =
  _list scrut
    (_lam (\hd -> _lam (\tl ->
       -- head pattern partially applies the RHS; tail pattern finishes it
       _maybe (patL hd hdPat rhs)
              (_lam (patL tl tlPat))
              _nothing)))
    _nothing
-- Defining interpreters

-- | Unstaged matcher: 'matchL' run directly in the 'Identity' interpretation.
unstaged :: a -> [Case Identity a b] -> b
unstaged x = runIdentity . matchL (Identity x)
-- | Staged matcher: 'matchL' interpreted in 'Code'.  The object-language
-- lambdas (e.g. the @_lam id@ that 'matchL' passes to '_maybe') remain as
-- beta-redexes in the output.
staged :: Code a -> [Case Code a b] -> Code b
staged scrut cases = matchL scrut cases
-- | Staged matcher through 'Code2', which reduces object-language
-- applications during generation, then converts the result with 'i'.
staged_beta :: Code a -> [Case Code2 a b] -> Code b
staged_beta scrut cases = i (matchL (C scrut) cases)
-- With failure continuation

-- | Continuation-passing variant of 'matchL'.  The code for the remaining
-- alternatives is bound once with '_let'.  That binding is passed to
-- 'patL2' as its failure continuation, so no intermediate 'Maybe' is built.
matchL2 :: Lang l => l a -> [Case l a b] -> l b
matchL2 scrut cases = case cases of
  [] -> _error "no matches"
  Case pt rhs : rest ->
    _let (matchL2 scrut rest) (_lam (patL2 scrut pt rhs id))
-- | Match one pattern in continuation-passing style.
--
-- @onMatch@ receives the RHS applied to the bound variables; @onFail@ is
-- used on every failure path.  @onFail@ can appear several times in the
-- result, so it should be cheap to duplicate (e.g. a variable, as
-- 'matchL2' arranges via '_let').
patL2 :: Lang l => l a -> Pat l a f r1 -> l f -> (l r1 -> l res) -> l res -> l res
patL2 scrut (LitPat lit) rhs onMatch onFail =
  _if (_eq scrut lit) (onMatch rhs) onFail
patL2 scrut VarPat rhs onMatch _ =
  onMatch (_app rhs scrut)
patL2 scrut NilPat rhs onMatch onFail =
  _list scrut (_lam (\_ -> _lam (\_ -> onFail))) (onMatch rhs)
patL2 scrut (ConsPat hdPat tlPat) rhs onMatch onFail =
  _list scrut
    (_lam (\hd -> _lam (\tl ->
       patL2 hd hdPat rhs
             (\partial -> patL2 tl tlPat partial onMatch onFail)
             onFail)))
    onFail
-- | Unstaged CPS matcher: 'matchL2' in the 'Identity' interpretation.
unstagedCPS :: a -> [Case Identity a b] -> b
unstagedCPS x = runIdentity . matchL2 (Identity x)
-- | Staged CPS matcher: 'matchL2' interpreted in 'Code'.
stagedCPS :: Code a -> [Case Code a b] -> Code b
stagedCPS scrut cases = matchL2 scrut cases
-- | Staged CPS matcher through 'Code2' (generation-time beta-reduction),
-- converted to plain code with 'i'.
staged_betaCPS :: Code a -> [Case Code2 a b] -> Code b
staged_betaCPS scrut cases = i (matchL2 (C scrut) cases)
-- Exercise: define a three-stage program where the first stage specialises the
-- interpreter based on the language, without using type classes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.