Skip to content

Instantly share code, notes, and snippets.

@mpickering
Created August 29, 2024 12:20
Show Gist options
  • Save mpickering/46123dece82199abf79724e3b6d6ae0b to your computer and use it in GitHub Desktop.
Save mpickering/46123dece82199abf79724e3b6d6ae0b to your computer and use it in GitHub Desktop.
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE GADTs #-}
module Pat3 where
import qualified Language.Haskell.TH as TH (Code, Q,)
import qualified Language.Haskell.TH.Syntax as TH (Lift(..))
import Data.Functor.Identity
-- | Typed Template Haskell code, generated in the 'TH.Q' monad.
type Code = TH.Code TH.Q
-- | Patterns over a scrutinee of type @input@. The parameter @m@ is the
-- interpretation used for the literal in 'LitPat' (this file instantiates it
-- with 'Code' and with an arbitrary 'Lang' instance).
--
-- A pattern turns a continuation of type @f@ into a result of type @res@:
-- each 'VarPat' consumes one argument of @f@, while 'LitPat' and 'NilPat'
-- bind nothing. 'ConsPat' threads the continuation through the head pattern
-- first and then through the tail pattern.
data Pat m input f res where
  -- | Match an 'Int' equal to the given literal.
  LitPat :: m Int -> Pat m Int res res
  -- Do not use Code here
  -- NOTE(review): presumably this means the bound variable reaches the
  -- continuation as a plain @a@ rather than as @m a@ — confirm.
  -- | Bind the scrutinee, passing it to the continuation.
  VarPat :: Pat m a (a -> r) r
  -- | Match the empty list.
  NilPat :: Pat m [a] r r
  -- | Match a non-empty list: the head against the first pattern and the tail
  -- against the second.
  ConsPat :: Pat m a f g -> Pat m [a] g r -> Pat m [a] f r
-- | One alternative of a match: a pattern paired with its right-hand side.
-- The right-hand side is the pattern's continuation (type @f@) wrapped in the
-- interpretation @m@. Because @f@ does not appear in the result type, it is
-- hidden, so alternatives that bind different variables share one type.
data Case m input res where
  Case ::
    -- Do not use Code here
    Pat m input f r
    -- Wrap RHS in m
    -> m f -> Case m input r
-- | Pattern @x : xs@, binding the head and then the tail.
v1 :: Pat Code [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat

-- | Non-empty list: return its head and tail.
defn1 :: Case Code [a] (Maybe (a,[a]))
defn1 = Case v1 [|| \a1 a2 -> Just (a1, a2) ||]

-- | Empty list: no result.
defn2 :: Case Code [a] (Maybe x)
defn2 = Case NilPat [|| Nothing ||]

-- | A list starting with the literal @1@: return @1@ and the tail.
defn3 :: Case Code [Int] (Maybe (Int, [Int]))
defn3 = Case (ConsPat (LitPat [|| 1 ||]) VarPat) [|| \a2 -> Just (1, a2) ||]
-- | Generate code that matches the scrutinee against each 'Case' in order.
-- The generated code evaluates to the right-hand side of the first
-- alternative whose pattern succeeds, or calls @error "no matches"@ if none
-- does.
--
-- The scrutinee is let-bound once in the generated code, and every
-- alternative inspects that variable. Passing @a@ straight to 'match_loop'
-- would instead copy the scrutinee expression into each alternative's test,
-- re-evaluating it per alternative. The binding is lazy, so an empty
-- alternative list still never forces the scrutinee.
match :: Code a -> [Case Code a b] -> Code b
match a b = [|| let scrut = $$a in $$(match_loop [|| scrut ||] b) ||]
-- | Chain the alternatives into nested @case@ expressions. Each alternative
-- tests its pattern against the scrutinee code and, on failure, falls
-- through to the code generated for the remaining alternatives. An empty
-- list becomes @error "no matches"@.
match_loop :: Code a -> [Case Code a b] -> Code b
match_loop c = foldr alternative [|| error "no matches" ||]
  where
    alternative (Case p l) rest =
      [|| case $$(pat c p l) of
            Just res -> res
            Nothing -> $$rest
        ||]
-- | Generate code that matches one pattern against the scrutinee code @a@.
-- On success the generated code yields @Just@ the continuation @l@ applied to
-- the bound variables; otherwise it yields @Nothing@.
pat :: Code a -> Pat Code a f res -> Code f -> Code (Maybe res)
pat a (LitPat i) l =
  [|| case $$a of
        x | x == $$i -> Just $$l
        _ -> Nothing ||]
pat a VarPat l =
  [|| case $$a of
        x -> Just ($$l x) ||]
pat a NilPat l =
  [|| case $$a of
        [] -> Just $$l
        _ -> Nothing ||]
pat a (ConsPat headPat tailPat) l =
  -- The head pattern consumes the first part of the continuation; its result
  -- (r1) becomes the continuation for the tail pattern.
  [|| case $$a of
        (x:xs) ->
          case $$(pat [|| x ||] headPat l) of
            Just r1 -> $$(pat [|| xs ||] tailPat [|| r1 ||])
            Nothing -> Nothing
        [] -> Nothing ||]
-- A small object language with both an unstaged (Identity) and a staged (Code) interpreter
-- | The object language, in tagless-final style. Programs are written against
-- these methods, and each instance of 'Lang' gives them a meaning.
class Lang l where
  -- Function abstraction and application.
  _lam :: (l a -> l b) -> l (a -> b)
  _app :: l (a -> b) -> l a -> l b

  -- Integers and booleans.
  _int :: Int -> l Int
  _eq :: l Int -> l Int -> l Bool
  _if :: l Bool -> l a -> l a -> l a

  -- Data constructors.
  _tup :: l a -> l b -> l (a, b)
  _nil :: l [a]
  _cons :: l a -> l [a] -> l [a]
  _nothing :: l (Maybe a)
  _just :: l a -> l (Maybe a)

  -- Eliminators. Arguments: the scrutinee, a function receiving the fields of
  -- the 'Just' / non-empty case, and the result for the 'Nothing' / empty case.
  _maybe :: l (Maybe a) -> l (a -> b) -> l b -> l b
  _list :: l [a] -> l (a -> [a] -> r) -> l r -> l r

  -- Abort with the given message.
  _error :: String -> l a
-- | Unstaged interpreter: every form is evaluated directly. Matching on the
-- 'Identity' newtype constructor is free and forces nothing.
instance Lang Identity where
  _if (Identity c) t e = if c then t else e
  _eq (Identity x) (Identity y) = Identity (x == y)
  _int = Identity
  _maybe (Identity m) (Identity onJust) onNothing =
    maybe onNothing (Identity . onJust) m
  _list (Identity xs0) (Identity onCons) onNil =
    case xs0 of
      [] -> onNil
      y : ys -> Identity (onCons y ys)
  _tup (Identity x) (Identity y) = Identity (x, y)
  _cons (Identity y) (Identity ys) = Identity (y : ys)
  _nil = Identity []
  _just (Identity x) = Identity (Just x)
  _nothing = Identity Nothing
  _app (Identity f) (Identity x) = Identity (f x)
  _lam f = Identity (runIdentity . f . Identity)
  _error msg = Identity (error msg)
-- | Staged interpreter: every form produces typed Template Haskell code that
-- performs the operation when the generated program runs.
instance Lang Code where
  _if b t f = [|| if $$b then $$t else $$f ||]
  _eq e1 e2 = [|| $$e1 == $$e2 ||]
  -- The integer is known at generation time and is lifted into the code.
  _int i = TH.liftTyped i
  -- The handler code is applied to the bound variable(s). When the handler
  -- was built with '_lam', this leaves a beta-redex in the generated code.
  _maybe m j n = [|| case $$m of
                      Just x -> $$j x
                      Nothing -> $$n ||]
  _list l cons nil = [|| case $$l of
                          (x:xs) -> $$cons x xs
                          [] -> $$nil ||]
  _tup l r = [|| ($$l, $$r) ||]
  _cons x xs = [|| $$x : $$xs ||]
  _nil = [|| [] ||]
  _just x = [|| Just $$x ||]
  _nothing = [|| Nothing ||]
  _app f x = [|| $$f $$x ||]
  -- Higher-order abstract syntax: the meta-level function receives code for
  -- the object-level variable x and builds the lambda body from it.
  _lam f = [|| \x -> $$(f [|| x ||]) ||]
  -- The message is known at generation time and is lifted into the code.
  _error x = [|| error $$(TH.liftTyped x) ||]
-- | Pattern @x : xs@ for any interpretation @l@.
v1L :: Pat l [a] (a -> [a] -> r) r
v1L = ConsPat VarPat VarPat

-- | Non-empty list: return its head and tail.
defn1L :: Lang l => Case l [a] (Maybe (a,[a]))
defn1L = Case v1L $ _lam $ \x1 -> _lam $ \x2 -> _just (_tup x1 x2)

-- | Empty list: no result.
defn2L :: Lang l => Case l [a] (Maybe x)
defn2L = Case NilPat _nothing

-- | A list starting with the literal @1@: return @1@ and the tail.
defn3L :: Lang l => Case l [Int] (Maybe (Int, [Int]))
defn3L = Case (ConsPat (LitPat (_int 1)) VarPat) (_lam (_just . _tup (_int 1)))
-- | Match the scrutinee against each 'Case' in order, for any interpretation
-- @l@. The result is the right-hand side of the first alternative whose
-- pattern succeeds, or @_error "no matches"@ if none does.
--
-- The scrutinee is bound once through '_lam' / '_app', and every alternative
-- inspects the bound variable. In the 'Code' interpreter this means the
-- scrutinee expression appears once in the generated program. Previously it
-- was copied into, and re-evaluated by, the test of every alternative.
-- Under 'Identity' the result is unchanged.
matchL :: Lang l => l a -> [Case l a b] -> l b
matchL scrut cases = _app (_lam (\s -> go s cases)) scrut
  where
    -- Try the alternatives, in order, against the already-bound scrutinee.
    go :: Lang l => l a -> [Case l a b] -> l b
    go _ [] = _error "no matches"
    go s (Case p l : xs) = _maybe (patL s p l) (_lam id) (go s xs)
-- | Match one pattern against a scrutinee, for any interpretation @l@.
-- The result is @Just@ the continuation @l@ applied to the pattern's bound
-- variables when the pattern matches, and @Nothing@ otherwise.
patL :: Lang l => l a -> Pat l a f res -> l f -> l (Maybe res)
patL a (LitPat i) l = _if (_eq a i) (_just l) _nothing
patL a VarPat l = _just (_app l a)
patL a NilPat l = _list a (_lam (const (_lam (const _nothing)))) (_just l)
patL a (ConsPat headPat tailPat) l =
  -- Match the head first, then feed its result (the partially applied
  -- continuation) into the tail pattern.
  _list a
    (_lam $ \hd -> _lam $ \tl ->
      _maybe (patL hd headPat l) (_lam (patL tl tailPat)) _nothing)
    _nothing
-- Defining interpreters
-- | Unstaged interpreter: run the match immediately on a plain value.
unstaged :: a -> [Case Identity a b] -> b
unstaged a = runIdentity . matchL (Identity a)

-- | Staged interpreter: generate code for the match from code for the
-- scrutinee.
staged :: Code a -> [Case Code a b] -> Code b
staged scrut cases = matchL scrut cases
-- Exercise: define a three-stage program in which the first stage specialises the interpreter
-- based on the language (without using type classes).
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module Pat where
import qualified Language.Haskell.TH as TH (Code, Q,)
import qualified Language.Haskell.TH.Syntax as TH (Lift(..))
-- Step 1: Unstaged interpreter
-- | Patterns over a scrutinee of type @input@. A pattern turns a continuation
-- of type @f@ into a result of type @res@: each 'VarPat' supplies one argument
-- to the continuation, while 'LitPat' and 'NilPat' bind nothing.
data Pat input f res where
  -- | Match an 'Int' equal to the given literal.
  LitPat :: Int -> Pat Int res res
  -- | Bind the scrutinee, passing it to the continuation.
  VarPat :: Pat a (a -> r) r
  -- | Match the empty list.
  NilPat :: Pat [a] r r
  -- | Match a non-empty list. The head pattern's bindings are supplied to the
  -- continuation before the tail pattern's.
  ConsPat :: Pat a f g -> Pat [a] g r -> Pat [a] f r
-- | One alternative: a pattern and a right-hand side that takes the pattern's
-- bound variables as arguments. The continuation type @f@ does not appear in
-- the result type, so alternatives binding different variables share a type.
data Case input res where
  Case ::
    Pat input f r
    -> f -> Case input r
-- | Pattern @x : xs@, binding the head and then the tail.
v1 :: Pat [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat

-- | Non-empty list: return its head and tail.
defn1 :: Case [a] (Maybe (a,[a]))
defn1 = Case v1 (curry Just)

-- | Empty list: no result.
defn2 :: Case [a] (Maybe x)
defn2 = Case NilPat Nothing

-- | A list starting with the literal @1@: return @1@ and the tail.
defn3 :: Case [Int] (Maybe (Int, [Int]))
defn3 = Case (ConsPat (LitPat 1) VarPat) (Just . (,) 1)
-- | Try each alternative in order and return the right-hand side of the first
-- one whose pattern matches. Calls @error "no matches"@ if none does.
match :: a -> [Case a b] -> b
match scrut = foldr tryCase (error "no matches")
  where
    tryCase (Case p k) rest =
      case pat scrut p k of
        Nothing -> rest
        Just x -> x
-- | Match one pattern against a value. On success, return the continuation
-- applied to the bound variables; on failure, return 'Nothing'.
pat :: a -> Pat a f r -> f -> Maybe r
pat scrut p k =
  case p of
    LitPat n
      | n == scrut -> Just k
      | otherwise -> Nothing
    VarPat -> Just (k scrut)
    NilPat
      | null scrut -> Just k
      | otherwise -> Nothing
    ConsPat hd tl ->
      case scrut of
        [] -> Nothing
        (x:xs) -> do
          -- The head's result is the continuation for the tail.
          k' <- pat x hd k
          pat xs tl k'
-- Challenge: write a staged interpreter for `Case`.
-- This attempt follows the description in the paper:
-- Safe Pattern Generation for Multi-Stage Programming
-- https://www.cl.cam.ac.uk/~jdy22/papers/safe-pattern-generation-for-multi-stage-programming.pdf
--
-- Clue: Using this definition it is impossible to write a well-typed staged interpreter.
-- | Typed Template Haskell code, generated in the 'TH.Q' monad.
type Code = TH.Code TH.Q
-- | Staged patterns. They mirror 'Pat', except that 'VarPatS' hands the
-- continuation code for the bound value (@Code a@) rather than the value.
data PatS input f res where
  -- | Match an 'Int' equal to the given (generation-time) literal.
  LitPatS :: Int -> PatS Int res res
  -- | Bind the scrutinee, passing code for it to the continuation.
  VarPatS :: PatS a (Code a -> r) r
  -- | Match the empty list.
  NilPatS :: PatS [a] r r
  -- | Match a non-empty list: the head against the first pattern and the tail
  -- against the second.
  ConsPatS :: PatS a f g -> PatS [a] g r -> PatS [a] f r
-- | A staged alternative. The pattern's overall result index is fixed to
-- @Code r@, so the right-hand side continuation must build code.
--
-- NOTE(review): only the outermost index is constrained to @Code _@. The
-- intermediate index @g@ of a 'ConsPatS' sub-pattern is unconstrained, which
-- is why 'patS' cannot recurse into the head pattern.
data CaseS input res where
  CaseS ::
    PatS input f (Code r)
    -> f -> CaseS input r
-- | Pattern @x : xs@; the continuation receives code for the head and tail.
v1s :: PatS [a] (Code a -> Code [a] -> r) r
v1s = ConsPatS VarPatS VarPatS

-- | Non-empty list: build code for @Just (head, tail)@.
defn1s :: CaseS [a] (Maybe (a,[a]))
defn1s = CaseS v1s $ \hd tl -> [|| Just ($$hd, $$tl) ||]

-- | Empty list: build code for @Nothing@.
defn2s :: CaseS [a] (Maybe x)
defn2s = CaseS NilPatS [|| Nothing ||]

-- | A list starting with the literal @1@: build code for @Just (1, tail)@.
defn3s :: CaseS [Int] (Maybe (Int, [Int]))
defn3s = CaseS (ConsPatS (LitPatS 1) VarPatS) $ \rest -> [|| Just (1, $$rest) ||]
-- Challenge: write a staged interpreter for `CaseS`.
-- | Staged matcher for 'CaseS'. It generates code that tries each alternative
-- in order and evaluates to the first successful right-hand side. If no
-- alternative applies, the generated code calls @error "no matches"@.
matchS :: Code a -> [CaseS a b] -> Code b
matchS _ [] = [|| error "no matches" ||]
matchS c (CaseS p l : xs) =
  [|| case $$(patS c p l) of
        Just r -> r
        -- Fall through to the remaining alternatives. Without this branch the
        -- generated case is non-exhaustive and xs is never consulted.
        Nothing -> $$(matchS c xs)
      ||]
-- | Attempted staged matcher for a single 'PatS'.
--
-- Given code for the scrutinee, a pattern, and its continuation @l@ (whose
-- final result is @Code res@), produce code yielding @Just@ the right-hand
-- side on success and @Nothing@ on failure.
--
-- NOTE(review): this definition is deliberately incomplete. As the inline
-- comment explains, the 'ConsPatS' case cannot be typed, and the tail
-- pattern c2 is never used. The enclosing module is therefore not expected
-- to compile as-is.
patS :: Code a -> PatS a f (Code res) -> f -> Code (Maybe res)
patS a p l =
  case p of
    -- i is a generation-time Int used inside the quote, so it is lifted
    -- (via 'TH.Lift') into the generated code.
    LitPatS i -> [|| if $$a == i then Just $$l else Nothing ||]
    -- The continuation receives the scrutinee code and builds the RHS code.
    VarPatS -> [|| Just $$(l a) ||]
    NilPatS -> [|| case $$a of
                    [] -> Just $$l
                    (x:xs) -> Nothing ||]
    ConsPatS c1 c2 -> [|| case $$a of
                           [] -> Nothing
                           (x:xs) ->
                             -- This is impossible because c1 ~ PatS a f g
                             -- so we can't recursively call `patS` because `g` can't
                             -- unify with `Code res`.
                             $$(patS [|| x ||] c1 l) ||]
{-# LANGUAGE TemplateHaskell #-}
module Pat_Test where
import Pat3
-- | Matcher specialised at compile time. The splice runs 'staged' over the
-- listed alternatives and inlines the generated matching code here.
--
-- NOTE(review): the final 'defn3L' can never be selected, because 'defn1L'
-- already matches every non-empty list and 'defn2L' matches the empty list.
-- It is presumably kept to exercise code generation for an extra
-- alternative; confirm before removing.
pt :: [Int] -> Maybe (Int, [Int])
pt x = $$(staged [|| x ||] [defn3L, defn1L, defn2L, defn3L])
--ptU x = unstaged x [defn3L, defn1L, defn2L, defn3L]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment