Created
August 29, 2024 12:20
-
-
Save mpickering/46123dece82199abf79724e3b6d6ae0b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskellQuotes #-} | |
{-# LANGUAGE GADTs #-} | |
module Pat3 where | |
import qualified Language.Haskell.TH as TH (Code, Q,) | |
import qualified Language.Haskell.TH.Syntax as TH (Lift(..)) | |
import Data.Functor.Identity | |
-- | Typed Template Haskell code whose generation runs in the 'TH.Q' monad.
type Code = TH.Code TH.Q
-- | Patterns over a scrutinee of type @input@, with literals held in @m@.
--
-- The index @f@ is the type of the right-hand side before the pattern is
-- matched; @res@ is what is left of it once every variable bound by the
-- pattern has been supplied as an argument.
data Pat m input f res where
  -- | A literal pattern on an 'Int' scrutinee; binds nothing.
  LitPat :: m Int -> Pat m Int k k
  -- Do not use Code here
  -- | A variable pattern; the right-hand side receives the scrutinee.
  VarPat :: Pat m a (a -> k) k
  -- | The empty-list pattern; binds nothing.
  NilPat :: Pat m [a] k k
  -- | A cons pattern built from a head pattern and a tail pattern.
  ConsPat :: Pat m a f g -> Pat m [a] g k -> Pat m [a] f k
-- | One alternative of a match: a pattern together with its right-hand side.
--
-- The type @f@ linking the pattern to the right-hand side is hidden, so
-- alternatives that bind different variables can share one list.
data Case m input res where
  Case
    :: -- Do not use Code here
       Pat m input f r
       -- Wrap RHS in m
    -> m f
    -> Case m input r
-- | The pattern @x : xs@, binding both the head and the tail.
v1 :: Pat Code [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat
-- | Alternative @x : xs -> Just (x, xs)@ with a quoted right-hand side.
defn1 :: Case Code [a] (Maybe (a, [a]))
defn1 = Case v1 [|| \hd tl -> Just (hd, tl) ||]
-- | Alternative @[] -> Nothing@ with a quoted right-hand side.
defn2 :: Case Code [a] (Maybe x)
defn2 = Case NilPat [|| Nothing ||]
-- | Alternative @1 : xs -> Just (1, xs)@ with a quoted right-hand side.
defn3 :: Case Code [Int] (Maybe (Int, [Int]))
defn3 =
  Case
    (ConsPat (LitPat [|| 1 ||]) VarPat)
    [|| \tl -> Just (1, tl) ||]
-- | Entry point of the quote-based staged matcher; defers to 'match_loop'.
match :: Code a -> [Case Code a b] -> Code b
match = match_loop
-- | Generate code that tries each alternative in list order.
--
-- The code for an alternative yields @Just@ on success; on @Nothing@ the
-- generated program continues with the code for the remaining alternatives.
-- With no alternatives left the generated code calls 'error'.
match_loop :: Code a -> [Case Code a b] -> Code b
match_loop scrut alts =
  case alts of
    [] -> [|| error "no matches" ||]
    Case p rhs : rest ->
      [|| case $$(pat scrut p rhs) of
            Just hit -> hit
            Nothing -> $$(match_loop scrut rest)
       ||]
-- | Generate code that matches one pattern against a quoted scrutinee.
--
-- The third argument is the quoted right-hand side. The generated code
-- evaluates to @Just@ of that right-hand side applied to every bound
-- variable when the pattern matches, and to @Nothing@ otherwise.
pat :: Code a -> Pat Code a f res -> Code f -> Code (Maybe res)
pat scrut (LitPat lit) k =
  [|| case $$scrut of
        v | v == $$lit -> Just $$k
        _ -> Nothing
   ||]
pat scrut VarPat k =
  [|| case $$scrut of
        v -> Just ($$k v)
   ||]
pat scrut NilPat k =
  [|| case $$scrut of
        [] -> Just $$k
        _ -> Nothing
   ||]
pat scrut (ConsPat hdPat tlPat) k =
  [|| case $$scrut of
        (hd : tl) ->
          -- Match the head first; its result is the right-hand side that
          -- the tail pattern continues with.
          case $$(pat [|| hd ||] hdPat k) of
            Just k' -> $$(pat [|| tl ||] tlPat [|| k' ||])
            Nothing -> Nothing
        [] -> Nothing
   ||]
-- A staged and unstaged interpreter | |
-- | A small object language in tagless-final style, with terms carried by @l@.
--
-- The matcher below is written once against this class and then run either
-- directly or as a code generator, depending on the instance chosen.
class Lang l where
  -- | Conditional on an object-level 'Bool'.
  _if :: l Bool -> l a -> l a -> l a

  -- | Equality of object-level 'Int's.
  _eq :: l Int -> l Int -> l Bool

  -- | Embed a host 'Int' as an object-level literal.
  _int :: Int -> l Int

  -- | Eliminate a 'Maybe': the function handles @Just@, the last argument
  -- handles @Nothing@.
  _maybe :: l (Maybe a) -> l (a -> b) -> l b -> l b

  -- | Eliminate a list: the function handles a cons cell (head, then tail),
  -- the last argument handles the empty list.
  _list :: l [a] -> l (a -> [a] -> r) -> l r -> l r

  -- | Pair construction.
  _tup :: l a -> l b -> l (a, b)

  -- | List cons.
  _cons :: l a -> l [a] -> l [a]

  -- | The empty list.
  _nil :: l [a]

  -- | The @Just@ constructor.
  _just :: l a -> l (Maybe a)

  -- | The @Nothing@ constructor.
  _nothing :: l (Maybe a)

  -- | Object-level function application.
  _app :: l (a -> b) -> l a -> l b

  -- | Object-level lambda, given as a host function on terms.
  _lam :: (l a -> l b) -> l (a -> b)

  -- | A term that fails with the given message.
  _error :: String -> l a
-- | The unstaged interpretation: every term is simply its value.
instance Lang Identity where
  _if (Identity cond) yes no = if cond then yes else no
  _eq lhs rhs = (==) <$> lhs <*> rhs
  _int = Identity
  _maybe (Identity m) (Identity onJust) onNothing =
    case m of
      Just v -> Identity (onJust v)
      Nothing -> onNothing
  _list (Identity l) (Identity onCons) onNil =
    case l of
      (hd : tl) -> Identity (onCons hd tl)
      [] -> onNil
  _tup l r = (,) <$> l <*> r
  _cons (Identity hd) (Identity tl) = Identity (hd : tl)
  _nil = Identity []
  _just (Identity v) = Identity (Just v)
  _nothing = Identity Nothing
  _app (Identity f) (Identity v) = Identity (f v)
  -- Wrap the argument, run the host function, then unwrap its result.
  _lam body = Identity (\v -> runIdentity (body (Identity v)))
  _error msg = Identity (error msg)
-- | The staged interpretation: every term is a typed quotation, so running
-- a 'Lang' program produces code instead of a value.
instance Lang Code where
  _if cond yes no = [|| if $$cond then $$yes else $$no ||]
  _eq lhs rhs = [|| $$lhs == $$rhs ||]
  _int = TH.liftTyped
  _maybe m onJust onNothing =
    [|| case $$m of
          Just v -> $$onJust v
          Nothing -> $$onNothing
     ||]
  _list l onCons onNil =
    [|| case $$l of
          (hd : tl) -> $$onCons hd tl
          [] -> $$onNil
     ||]
  _tup l r = [|| ($$l, $$r) ||]
  _cons hd tl = [|| $$hd : $$tl ||]
  _nil = [|| [] ||]
  _just v = [|| Just $$v ||]
  _nothing = [|| Nothing ||]
  _app f v = [|| $$f $$v ||]
  -- The host function receives a quotation of the freshly bound variable.
  _lam body = [|| \v -> $$(body [|| v ||]) ||]
  _error msg = [|| error $$(TH.liftTyped msg) ||]
-- | The pattern @x : xs@, binding both components, for any carrier @l@.
v1L :: Pat l [a] (a -> [a] -> r) r
v1L = ConsPat VarPat VarPat
-- | Alternative @x : xs -> Just (x, xs)@ written against 'Lang'.
defn1L :: Lang l => Case l [a] (Maybe (a, [a]))
defn1L =
  Case v1L $
    _lam $ \hd ->
      _lam $ \tl ->
        _just (_tup hd tl)
-- | Alternative @[] -> Nothing@ written against 'Lang'.
defn2L :: Lang l => Case l [a] (Maybe x)
defn2L = NilPat `Case` _nothing
-- | Alternative @1 : xs -> Just (1, xs)@ written against 'Lang'.
defn3L :: Lang l => Case l [Int] (Maybe (Int, [Int]))
defn3L =
  Case
    (ConsPat (LitPat (_int 1)) VarPat)
    (_lam $ \tl -> _just (_tup (_int 1) tl))
-- | Match a scrutinee against alternatives in list order, for any 'Lang'.
--
-- The first alternative whose pattern succeeds supplies the result; if the
-- list runs out the result is @_error "no matches"@.
matchL :: Lang l => l a -> [Case l a b] -> l b
matchL scrut alts =
  case alts of
    [] -> _error "no matches"
    Case p rhs : rest ->
      _maybe (patL scrut p rhs) (_lam id) (matchL scrut rest)
-- | Match one pattern against a scrutinee, for any 'Lang'.
--
-- The third argument is the right-hand side. The result is @_just@ of it
-- applied to every bound variable on success and @_nothing@ on failure.
patL :: Lang l => l a -> Pat l a f res -> l f -> l (Maybe res)
patL scrut (LitPat lit) k = _if (_eq scrut lit) (_just k) _nothing
patL scrut VarPat k = _just (_app k scrut)
patL scrut NilPat k =
  -- Any cons cell fails; only the empty list yields the right-hand side.
  _list scrut (_lam $ \_ -> _lam $ \_ -> _nothing) (_just k)
patL scrut (ConsPat hdPat tlPat) k =
  _list
    scrut
    ( _lam $ \hd ->
        _lam $ \tl ->
          -- The head's result is the right-hand side for the tail pattern.
          _maybe
            (patL hd hdPat k)
            (_lam $ \k' -> patL tl tlPat k')
            _nothing
    )
    _nothing
-- Defining interpreters | |
-- | Run the 'Lang' matcher directly on a value via the 'Identity' instance.
unstaged :: a -> [Case Identity a b] -> b
unstaged scrut alts = runIdentity $ matchL (Identity scrut) alts
-- | Run the 'Lang' matcher as a code generator via the 'Code' instance.
staged :: Code a -> [Case Code a b] -> Code b
staged scrut alts = matchL scrut alts
-- Exercise, define a three-stage program where the first stage specialises the interpreter | |
-- based on the language (not using type classes) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE TemplateHaskellQuotes #-} | |
module Pat where | |
import qualified Language.Haskell.TH as TH (Code, Q,) | |
import qualified Language.Haskell.TH.Syntax as TH (Lift(..)) | |
-- Step 1: Unstaged interpreter | |
-- | Unstaged patterns over a scrutinee of type @input@.
--
-- @f@ is the type of the right-hand side before matching and @res@ is what
-- is left once every variable bound by the pattern has been supplied.
data Pat input f res where
  -- | A literal pattern on an 'Int' scrutinee; binds nothing.
  LitPat :: Int -> Pat Int k k
  -- | A variable pattern; the right-hand side receives the scrutinee.
  VarPat :: Pat a (a -> k) k
  -- | The empty-list pattern; binds nothing.
  NilPat :: Pat [a] k k
  -- | A cons pattern built from a head pattern and a tail pattern.
  ConsPat :: Pat a f g -> Pat [a] g k -> Pat [a] f k
-- | One alternative: a pattern and a right-hand side of the type it expects.
data Case input res where
  Case
    :: Pat input f r
    -> f
    -> Case input r
-- | The pattern @x : xs@, binding both the head and the tail.
v1 :: Pat [a] (a -> [a] -> r) r
v1 = ConsPat VarPat VarPat
-- | Alternative @x : xs -> Just (x, xs)@.
defn1 :: Case [a] (Maybe (a, [a]))
defn1 = Case v1 $ \hd tl -> Just (hd, tl)
-- | Alternative @[] -> Nothing@.
defn2 :: Case [a] (Maybe x)
defn2 = NilPat `Case` Nothing
-- | Alternative @1 : xs -> Just (1, xs)@.
defn3 :: Case [Int] (Maybe (Int, [Int]))
defn3 =
  Case
    (ConsPat (LitPat 1) VarPat)
    (\tl -> Just (1, tl))
-- | Match a value against alternatives in list order.
--
-- Returns the result of the first alternative whose pattern succeeds and
-- calls 'error' when none does.
match :: a -> [Case a b] -> b
match scrut alts =
  case alts of
    [] -> error "no matches"
    Case p rhs : rest ->
      maybe (match scrut rest) id (pat scrut p rhs)
-- | Match one pattern against a value.
--
-- The third argument is the right-hand side. The result is @Just@ of it
-- applied to every bound variable on success and @Nothing@ on failure.
pat :: a -> Pat a f r -> f -> Maybe r
pat scrut p k =
  case p of
    LitPat n -> if n == scrut then Just k else Nothing
    VarPat -> Just (k scrut)
    NilPat ->
      case scrut of
        [] -> Just k
        _ : _ -> Nothing
    ConsPat hdPat tlPat ->
      case scrut of
        [] -> Nothing
        -- Match the head first, then feed its result to the tail pattern.
        hd : tl -> pat hd hdPat k >>= pat tl tlPat
-- Challenge, write a staged interpreter for `Case` | |
-- This attempt follows the description in the paper:
-- Safe Pattern Generation for Multi-Stage Programming | |
-- https://www.cl.cam.ac.uk/~jdy22/papers/safe-pattern-generation-for-multi-stage-programming.pdf | |
-- | |
-- Clue: Using this definition it is impossible to write a well-typed staged interpreter. | |
-- | Typed Template Haskell code whose generation runs in the 'TH.Q' monad.
type Code = TH.Code TH.Q
-- | Patterns for the staged attempt.
--
-- Same shape as 'Pat', except that a variable pattern hands the right-hand
-- side a 'Code' fragment for the scrutinee, not the scrutinee itself.
data PatS input f res where
  -- | A literal pattern on an 'Int' scrutinee; binds nothing.
  LitPatS :: Int -> PatS Int k k
  -- | A variable pattern; the right-hand side receives quoted code.
  VarPatS :: PatS a (Code a -> k) k
  -- | The empty-list pattern; binds nothing.
  NilPatS :: PatS [a] k k
  -- | A cons pattern built from a head pattern and a tail pattern.
  ConsPatS :: PatS a f g -> PatS [a] g k -> PatS [a] f k
-- | One staged alternative: a pattern whose final result is 'Code', and a
-- host-level right-hand side of the type that pattern expects.
data CaseS input res where
  CaseS
    :: PatS input f (Code r)
    -> f
    -> CaseS input r
-- | The pattern @x : xs@; both components reach the right-hand side as code.
v1s :: PatS [a] (Code a -> Code [a] -> r) r
v1s = ConsPatS VarPatS VarPatS
-- | Staged alternative @x : xs -> Just (x, xs)@.
defn1s :: CaseS [a] (Maybe (a, [a]))
defn1s = CaseS v1s $ \hd tl -> [|| Just ($$hd, $$tl) ||]
-- | Staged alternative @[] -> Nothing@.
defn2s :: CaseS [a] (Maybe x)
defn2s = NilPatS `CaseS` [|| Nothing ||]
-- | Staged alternative @1 : xs -> Just (1, xs)@.
defn3s :: CaseS [Int] (Maybe (Int, [Int]))
defn3s =
  CaseS
    (ConsPatS (LitPatS 1) VarPatS)
    (\tl -> [|| Just (1, $$tl) ||])
-- Challenge, write a staged interpreter for `Case` | |
-- | Generate code that tries each staged alternative in list order.
--
-- The code produced by 'patS' for an alternative yields @Just@ on success.
-- On @Nothing@ the generated program continues with the code for the
-- remaining alternatives; with none left it calls 'error'.
matchS :: Code a -> [CaseS a b] -> Code b
matchS _ [] = [|| error "no matches" ||]
matchS c (CaseS p l : xs) =
  [|| case $$(patS c p l) of
        Just r -> r
        -- Without this branch the generated case is non-exhaustive and the
        -- remaining alternatives are never tried.
        Nothing -> $$(matchS c xs)
   ||]
-- | Attempted staged matcher for a single 'PatS'.
--
-- The third argument is the host-level right-hand side; the intended result
-- is code evaluating to @Just@ on a match and @Nothing@ otherwise. As the
-- comment in the cons case explains, that case cannot be given a well-typed
-- definition with this encoding.
patS :: Code a -> PatS a f (Code res) -> f -> Code (Maybe res)
patS scrut (LitPatS i) k =
  [|| if $$scrut == i then Just $$k else Nothing ||]
patS scrut VarPatS k =
  [|| Just $$(k scrut) ||]
patS scrut NilPatS k =
  [|| case $$scrut of
        [] -> Just $$k
        (x:xs) -> Nothing
   ||]
patS scrut (ConsPatS c1 c2) k =
  [|| case $$scrut of
        [] -> Nothing
        (x:xs) ->
          -- This is impossible because c1 ~ PatS a f g
          -- so we can't recursively call `patS` because `g` can't
          -- unify with `Code res`.
          $$(patS [|| x ||] c1 k)
   ||]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TemplateHaskell #-} | |
module Pat_Test where | |
import Pat3 | |
-- | Matcher generated at compile time by splicing the result of 'staged'.
--
-- The alternatives are tried in the order listed: 'defn3L', 'defn1L',
-- 'defn2L', and then 'defn3L' once more.
pt :: [Int] -> Maybe (Int, [Int])
pt x = $$(staged [|| x ||] [defn3L, defn1L, defn2L, defn3L])
--ptU x = unstaged x [defn3L, defn1L, defn2L, defn3L] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment