Created
April 19, 2018 16:54
-
-
Save mpickering/b6935a63717129cc33b77774c126168a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE DeriveGeneric #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE RankNTypes #-} | |
module LMS where | |
import Data.Generics.Product | |
import GHC.Generics | |
import Control.Lens (view) | |
import Debug.Trace | |
import Language.Haskell.TH | |
import Language.Haskell.TH.Syntax | |
import Unsafe.Coerce | |
{-
The essence of LMS (Lightweight Modular Staging) is
1) virtualised syntax, and
2) implicit lifting.
Here is how we might embed it in Haskell.
-}
-- | Tagless-final signature of the object language.
--
-- Each interpretation @r@ gives meaning to every construct; this module
-- provides direct evaluation ('Identity'), a syntax tree ('Syn'), Template
-- Haskell code generation ('TTExp') and a rewriting layer ('CollapseIf').
class Ops r where
  -- Literals and quotation of host values.
  _int :: Int -> r Int
  _string :: String -> r String
  _quote :: Lift a => a -> r a

  -- Integer arithmetic.
  _plus, _minus, _mul, _div, _mod :: r Int -> r Int -> r Int

  -- Comparison and conditionals.
  _eq :: Eq a => r a -> r a -> r Bool
  _if :: r Bool -> r a -> r a -> r a

  -- Functions, sharing and recursion.
  _app :: r (a -> b) -> r a -> r b
  _lam :: (r a -> r b) -> r (a -> b)
  -- | Bind a recursive @Int -> Int@ function. The first argument is its body,
  -- which receives the function itself and its parameter. The second argument
  -- is the scope in which the function is used.
  _lam2
    :: (r (Int -> Int) -> r Int -> r Int)
    -> (r (Int -> Int) -> r Int)
    -> r Int
  _let :: r a -> (r a -> r b) -> r b
  _fix :: (r a -> r a) -> r a
-- | Re-embed an already computed 'Identity' integer as a literal in any
-- interpretation.
liftInt :: Ops r => Identity Int -> r Int
liftInt boxed = _int (runIdentity boxed)
-- | Deep embedding: one constructor per 'Ops' method. Binders are kept in
-- higher-order form, as host functions over 'Syn'.
data Syn a where
  -- Literals and quotation.
  Int_ :: Int -> Syn Int
  String_ :: String -> Syn String
  Quote_ :: a -> Syn a
  -- Integer arithmetic.
  Mul_, Plus_, Minus_, Div_, Mod_ :: Syn Int -> Syn Int -> Syn Int
  -- Comparison and conditionals.
  Eq_ :: Eq a => Syn a -> Syn a -> Syn Bool
  If_ :: Syn Bool -> Syn a -> Syn a -> Syn a
  -- Functions, sharing and recursion.
  App_ :: Syn (a -> b) -> Syn a -> Syn b
  Lam_ :: (Syn a -> Syn b) -> Syn (a -> b)
  Lam2_
    :: (Syn (Int -> Int) -> Syn Int -> Syn Int)
    -> (Syn (Int -> Int) -> Syn Int)
    -> Syn Int
  Let_ :: Syn a -> (Syn a -> Syn b) -> Syn b
  Fix_ :: (Syn a -> Syn a) -> Syn a
-- | Minimal identity functor, used as the direct (unstaged) interpretation.
newtype Identity a = Identity a
  deriving (Generic, Show)

instance Functor Identity where
  fmap f = Identity . f . runIdentity

instance Applicative Identity where
  pure = Identity
  Identity f <*> boxed = Identity (f (runIdentity boxed))

-- | Project the wrapped value out of an 'Identity'.
runIdentity :: Identity a -> a
runIdentity (Identity payload) = payload
-- | Extract the Template Haskell computation wrapped by a 'TTExp'.
unwrap :: TTExp a -> Q (TExp a)
unwrap (TTExp code) = code
-- | Direct evaluation: every construct is interpreted by the corresponding
-- Haskell operation.
instance Ops Identity where
  _int = Identity
  _string = Identity
  _quote = Identity

  _plus a b = (+) <$> a <*> b
  _minus a b = (-) <$> a <*> b
  _mul a b = (*) <$> a <*> b
  _div a b = div <$> a <*> b
  _mod a b = mod <$> a <*> b

  _eq a b = (==) <$> a <*> b
  -- Only the selected branch is ever unwrapped.
  _if c th el = Identity (if runIdentity c then runIdentity th else runIdentity el)

  _app = (<*>)
  _lam f = Identity (runIdentity . f . Identity)
  -- Tie the knot: 'self' is the host-level recursive function. The body
  -- receives it, re-wrapped, as its own recursive reference.
  _lam2 body scope = scope (Identity self)
    where
      self :: Int -> Int
      self x = runIdentity (body (Identity self) (Identity x))
  _let x k = k x
  _fix = fix
-- | The deep embedding: each operation is its 'Syn' constructor.
instance Ops Syn where
  -- literals and quotation
  _int = Int_
  _string = String_
  _quote = Quote_
  -- arithmetic
  _plus = Plus_
  _minus = Minus_
  _mul = Mul_
  _div = Div_
  _mod = Mod_
  -- comparison and conditionals
  _eq = Eq_
  _if = If_
  -- functions, sharing and recursion
  _app = App_
  _lam = Lam_
  _lam2 = Lam2_
  _let = Let_
  _fix = Fix_
-- | Code-generating interpretation: a Template Haskell computation producing
-- a typed expression of type @a@.
newtype TTExp a
  = TTExp (Q (TExp a))
  deriving (Generic)

-- NOTE(review): presumably left commented out because the quote would have
-- to lift the arbitrary host function @f@ into generated code.
--instance Functor TTExp where
--  fmap f (TTExp te) = TTExp $ [|| f $$(te) ||]
-- | Lift a value into a typed expression. The untyped result of 'lift' is
-- coerced, trusting that it really has type @a@.
liftT :: Lift a => a -> TExpQ a
liftT x = unsafeTExpCoerce (lift x)
-- | Fixed point by knot-tying: the result is shared with the argument that
-- is passed to @f@.
fix :: (a -> a) -> a
fix f = result
  where
    result = f result
-- | Code generation: every operation builds a typed Template Haskell quote.
--
-- The '~' patterns are kept from the original. Since 'TTExp' is a newtype,
-- matching its constructor never forces anything anyway.
instance Ops TTExp where
  -- Condition and both branches are spliced into one object-level @if@.
  -- Generating code for both branches means recursion guarded only by an
  -- object-level '_if' does not terminate here.
  _if (TTExp te1) ~(TTExp te2) ~(TTExp te3) =
    TTExp [|| if $$te1 then $$te2 else $$te3 ||]
  _app ~(TTExp te1) ~(TTExp te2) = TTExp [|| $$te1 $$te2 ||]
  _lam f = TTExp [|| \a -> $$(unwrap $ f (TTExp [|| a ||])) ||]
  -- A recursive function becomes an object-level @let@. Its body and its
  -- scope both refer to it by the quoted name @f@.
  _lam2 f1 f2 =
    TTExp [|| let f x = $$(unwrap $ f1 (TTExp [|| f ||]) (TTExp [|| x ||]))
              in $$(unwrap $ f2 (TTExp [|| f ||])) ||]
  _int n = TTExp (liftT n)
  _string s = TTExp (liftT s)
  _mul ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 * $$m2 ||]
  _plus ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 + $$m2 ||]
  _minus ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 - $$m2 ||]
  _div ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 `div` $$m2 ||]
  _mod ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 `mod` $$m2 ||]
  _eq ~(TTExp m1) ~(TTExp m2) = TTExp [|| $$m1 == $$m2 ||]
  -- The bound code is passed to the continuation as-is, so each use splices
  -- (duplicates) it. No object-level sharing is introduced.
  _let m1 m2 = m2 m1
  -- NOTE(review): host-level knot-tying. This only terminates if @f@ does not
  -- need to splice its argument.
  _fix f = fix f
  _quote x = TTExp [|| x ||]
-- | Rewriting layer over another interpretation @p@ that performs
--
-- > if ((if cond then 0 else 1) == 0) then a else b   ==>   if cond then a else b
data CollapseIf p a where
  -- | A statically known integer literal.
  Num :: Int -> CollapseIf p Int
  -- | @CIf c@ stands for @if c then 0 else 1@.
  CIf :: p Bool -> CollapseIf p Int
  -- | @CEq c@ stands for @(if c then 0 else 1) == 0@, which is equivalent to @c@.
  CEq :: p Bool -> CollapseIf p Bool
  -- | Any other term, already in the underlying interpretation.
  Dyn :: p a -> CollapseIf p a
-- | Translate a 'CollapseIf' term back into the underlying interpretation,
-- re-expanding the recognised @if@/@==@ shapes.
lowerCIf :: Ops p => CollapseIf p a -> p a
lowerCIf term = case term of
  Dyn inner -> inner
  Num n -> _int n
  CIf cond -> asZeroOne cond
  CEq cond -> asZeroOne cond `_eq` _int 0
  where
    -- @if c then 0 else 1@
    asZeroOne c = _if c (_int 0) (_int 1)

-- | Embed an underlying term without any rewriting information.
liftCIf :: p a -> CollapseIf p a
liftCIf inner = Dyn inner
-- | Rewriting interpretation layered over any 'Ops' interpretation @p@.
--
-- Integer literals stay visible as 'Num'. This lets the shape
-- @if c then 0 else 1@ (how 'eval' encodes comparison results) be recognised
-- as 'CIf', and comparing that with 0 be recognised as 'CEq'. An @if@ whose
-- condition is a 'CEq' branches directly on the original condition. Every
-- other operation lowers its operands and wraps the result in 'Dyn'.
--
-- Matching on 'Num' forces the branch arguments of '_if' to WHNF.
instance Ops p => Ops (CollapseIf p) where
  _if pa (Num 0) (Num 1) = CIf (lowerCIf pa)
  _if (CEq pa) p1 p2 = Dyn (_if pa (lowerCIf p1) (lowerCIf p2))
  _if pa c1 c2 = Dyn (_if (lowerCIf pa) (lowerCIf c1) (lowerCIf c2))
  _int k = Num k
  -- (if c then 0 else 1) == 0 holds exactly when c does. Equality is
  -- symmetric, so the literal 0 may appear on either side.
  _eq (CIf pa) (Num 0) = CEq pa
  _eq (Num 0) (CIf pa) = CEq pa
  _eq pa pb = Dyn (_eq (lowerCIf pa) (lowerCIf pb))
  _app pa pb = Dyn (_app (lowerCIf pa) (lowerCIf pb))
  _lam pa = Dyn (_lam (lowerCIf . pa . liftCIf))
  _lam2 cpii p =
    Dyn (_lam2 (\a1 a2 -> lowerCIf (cpii (liftCIf a1) (liftCIf a2)))
               (lowerCIf . p . liftCIf))
  _string s = Dyn (_string s)
  _mul pa pb = Dyn (_mul (lowerCIf pa) (lowerCIf pb))
  _mod pa pb = Dyn (_mod (lowerCIf pa) (lowerCIf pb))
  _div pa pb = Dyn (_div (lowerCIf pa) (lowerCIf pb))
  _minus pa pb = Dyn (_minus (lowerCIf pa) (lowerCIf pb))
  _plus pa pb = Dyn (_plus (lowerCIf pa) (lowerCIf pb))
  _let pa pab = pab pa
  _fix = fix
  _quote = Dyn . _quote
-- | @power n k@ multiplies @k@ by itself @n@ times (ending in the literal 1),
-- in any 'Ops' interpretation.
--
-- The recursion on @n@ runs in the host language (a plain Haskell @if@), so
-- when staged the result contains only multiplications and no conditional.
-- A negative exponent would recurse forever, so it is rejected up front.
power :: Ops r => Int -> r Int -> r Int
power n k
  | n < 0 = error ("power: negative exponent " ++ show n)
  | n == 0 = _int 1
  | otherwise = _mul k (power (n - 1) k)
{- Wrong because we don't evaluate the condition statically, even though we
   can. The object-level _if below keeps the recursive call in its else
   branch, so interpretations that build both branches (such as code
   generation) never stop recursing:
_if ((_eq (_int n) (_int 0)))
    (_int 1)
    (_mul k (power (n-1) k))
-}
-- | Curried object-level equality on integers.
_lamEq :: Ops r => r (Int -> Int -> Bool)
_lamEq = _lam $ \lhs -> _lam $ \rhs -> lhs `_eq` rhs

-- | Curried object-level multiplication.
_lamMul :: Ops r => r (Int -> Int -> Int)
_lamMul = _lam $ \lhs -> _lam $ \rhs -> lhs `_mul` rhs
-- | Even more overloading, application: the zero test and the multiplication
-- are object-level function applications, and the branch is an object-level
-- '_if'. The recursion on @n@ is therefore not cut off in the host language.
power2 :: Ops r => Int -> r Int -> r Int
power2 n k = _if isZero (_int 1) step
  where
    isZero = _app (_app _lamEq (_int n)) (_int 0)
    step = _app (_app _lamMul k) (power2 (n - 1) k)
-- | Even more overloading, abstraction: the base becomes an object-level
-- lambda parameter. The recursive step is built with 'power2'.
power3 :: Ops r => Int -> r (Int -> Int)
power3 n = _lam body
  where
    body k =
      _if (_app (_app _lamEq (_int n)) (_int 0))
          (_int 1)
          (_app (_app _lamMul k) (power2 (n - 1) k))
-- | @4 ^ 5@, computed with the unstaged 'power'.
test :: Int
test = powerIden 5 4

-- | @powerIden k n@ is @n@ raised to @k@, evaluated directly.
powerIden :: Int -> Int -> Int
powerIden k n = runIdentity $ power @Identity k (Identity n)

power2Iden :: Int -> Int -> Int
power2Iden k n = runIdentity $ power2 @Identity k (Identity n)

power3Iden :: Int -> Int -> Int
power3Iden k n = runIdentity $ power3 k `_app` _int n

-- Code-generating and syntax-tree specialisations.
power3Staged :: Int -> TTExp (Int -> Int)
power3Staged n = power3 @TTExp n

power1Staged :: Int -> TTExp Int -> TTExp Int
power1Staged n k = power @TTExp n k

powerSyn :: Int -> Syn Int -> Syn Int
powerSyn = power @Syn

power2Syn :: Int -> Syn Int -> Syn Int
power2Syn = power2 @Syn
-- | Source language of the toy interpreters. Truth values are integers: 'If'
-- takes its first branch when the condition evaluates to 0, and 'Eq' yields
-- 0 for equal operands and 1 otherwise.
data Expr where
  EInt :: Int -> Expr
  Var :: String -> Expr
  -- | Call the named unary function on an argument.
  App :: String -> Expr -> Expr
  Add :: Expr -> Expr -> Expr
  Sub :: Expr -> Expr -> Expr
  Mul :: Expr -> Expr -> Expr
  If :: Expr -> Expr -> Expr -> Expr
  Eq :: Expr -> Expr -> Expr
  Mod :: Expr -> Expr -> Expr
  Div :: Expr -> Expr -> Expr
-- | An environment mapping names to interpreted values of type @r a@.
data REnv r a = REnv (String -> r a)

-- | Look a name up in an environment.
--
-- The lookup just applies the stored function, so no 'Ops' constraint is
-- needed. Callers that have one still type-check.
getRenv :: REnv r a -> String -> r a
getRenv (REnv r) a = r a

-- | Environment of integer variables.
type Env r = REnv r Int

-- | Environment of unary integer functions.
type FEnv r = REnv r (Int -> Int)
-- | Direct, unbounded interpreter for 'Expr' in the 'Identity' interpretation.
--
-- Variables are looked up in @env@ and called functions in @fenv@. Truth
-- values follow the same integer encoding as 'eval': 'If' takes its first
-- branch when the condition is 0, and 'Eq' yields 0 when equal, else 1.
evalUnstaged :: Expr -> Env Identity -> FEnv Identity -> Identity Int
evalUnstaged expr env fenv = case expr of
  EInt n -> _int n
  Var s -> getRenv env s
  App s a -> _app (getRenv fenv s) (go a)
  Add a b -> go a `_plus` go b
  Sub a b -> go a `_minus` go b
  Mul a b -> go a `_mul` go b
  If c th el -> _if (go c `_eq` _int 0) (go th) (go el)
  -- Eq, Mod and Div were previously unhandled, making this function partial.
  Eq a b -> _if (go a `_eq` go b) (_int 0) (_int 1)
  Mod a b -> go a `_mod` go b
  Div a b -> go a `_div` go b
  where
    go sub = evalUnstaged sub env fenv
-- | Hand-written staged interpreter: turns an 'Expr' into typed Template
-- Haskell code. Variables and functions are looked up as code in @env@ and
-- @fenv@.
--
-- The integer truth encoding matches 'eval': 'If' takes its first branch when
-- the condition is 0, and 'Eq' yields 0 when equal, else 1.
evalStaged :: Expr -> EnvStaged -> FEnvStaged -> Q (TExp Int)
evalStaged expr env fenv = case expr of
  EInt n -> liftT n
  Var s -> env s
  App s a -> [|| $$(fenv s) $$(go a) ||]
  Add a b -> [|| $$(go a) + $$(go b) ||]
  Sub a b -> [|| $$(go a) - $$(go b) ||]
  Mul a b -> [|| $$(go a) * $$(go b) ||]
  If c th el ->
    [|| if $$(go c) == 0
          then $$(go th)
          else $$(go el) ||]
  -- Eq, Mod and Div were previously unhandled, making this function partial.
  Eq a b -> [|| if $$(go a) == $$(go b) then 0 else 1 ||]
  Mod a b -> [|| $$(go a) `mod` $$(go b) ||]
  Div a b -> [|| $$(go a) `div` $$(go b) ||]
  where
    go sub = evalStaged sub env fenv
--test :: Ops r => (Int -> r Int) -> r ( | |
-- | Staged counterpart of 'peval', written directly with typed Template
-- Haskell. Each declaration becomes an object-level @let f xx = ...@ that is
-- visible in its own body and in everything that follows.
pevalStaged :: [Decl] -> Expr -> EnvStaged -> FEnvStaged -> Q (TExp Int)
pevalStaged decls body env fenv = case decls of
  [] -> evalStaged body env fenv
  (fname, param, fbody) : rest ->
    [|| let f xx = $$(evalStaged fbody (ext2 env param [|| xx ||]) (ext2 fenv fname [|| f ||]))
        in $$(pevalStaged rest body env (ext2 fenv fname [|| f ||])) ||]

-- | Staged variable environment: name to code of an 'Int'.
type EnvStaged = String -> Q (TExp Int)

-- | Staged function environment: name to code of an @Int -> Int@ function.
type FEnvStaged = String -> Q (TExp (Int -> Int))
-- | Extend a function-as-environment so that @key@ maps to @val@. All other
-- keys fall through to @env@.
ext2 :: Eq k => (k -> v) -> k -> v -> k -> v
ext2 env key val x
  | x == key = val
  | otherwise = env x
-- | Empty staged environments. Looking anything up is an error that names
-- the unbound identifier.
env0Staged :: EnvStaged
env0Staged = \x -> error ("undef: " ++ x)

fenv0Staged :: FEnvStaged
fenv0Staged = \x -> error ("undef f: " ++ x)
-- | The 'fact' program staged with the hand-written 'pevalStaged'.
example :: Q (TExp Int)
example = pevalStaged fact e env0Staged fenv0Staged

-- | The same program staged through the generic 'peval' at 'TTExp'.
example2 :: Q (TExp Int)
example2 = unwrap (evalFact env0 fenv0)

-- | As 'example2', with the 'CollapseIf' rewrite applied before code generation.
example3 :: Q (TExp Int)
example3 = unwrap (lowerCIf (evalFact env0 fenv0))
-- | Interpret an 'Expr' in any 'Ops' interpretation.
--
-- @fuel@ bounds the nesting depth explored. When it is exactly 0, the current
-- sub-expression is replaced by the literal 0. Variables are looked up in the
-- first environment and called functions in the second. Truth values are
-- integers: 'If' takes its first branch when the condition is 0, and 'Eq'
-- yields 0 for equal operands and 1 otherwise.
eval :: forall r . Ops r => Int -> Expr -> Env r -> FEnv r -> r Int
eval 0 _ _ _ = _int 0 -- out of fuel: truncate to 0
eval fuel expr env fenv = case expr of
  EInt n -> _int n
  Var s -> getRenv env s
  App s a -> _app (getRenv fenv s) (sub a)
  Add a b -> sub a `_plus` sub b
  Sub a b -> sub a `_minus` sub b
  Mul a b -> sub a `_mul` sub b
  If c th el -> _if (sub c `_eq` _int 0) (sub th) (sub el)
  Eq a b -> _if (sub a `_eq` sub b) (_int 0) (_int 1)
  Mod a b -> sub a `_mod` sub b
  Div a b -> sub a `_div` sub b
  where
    -- Recurse into a sub-expression with one less unit of fuel.
    sub x = eval (fuel - 1) x env fenv
-- | A unary function declaration: @(name, parameter, body)@.
type Decl = (String, String, Expr)

-- | Extend an environment so that @key@ maps to @v@. Other names are
-- resolved by the original environment.
ext :: REnv r a -> String -> r a -> REnv r a
ext (REnv old) key v = REnv lookupNew
  where
    lookupNew y
      | key == y = v
      | otherwise = old y
-- | Empty environments. Looking anything up is an error that names the
-- unbound identifier.
env0 :: REnv r a
env0 = REnv (\x -> error ("undef: " ++ x))

fenv0 :: REnv r a
fenv0 = REnv (\x -> error ("undef f: " ++ x))
-- | Partially evaluate a program in any 'Ops' interpretation.
--
-- Each declaration becomes a recursive function bound with '_lam2', whose
-- body is interpreted by 'eval' with fuel 100. It is visible in its own body
-- and in all later declarations and the main expression, which is finally
-- interpreted in that environment.
peval :: Ops r => [Decl] -> Expr -> Env r -> FEnv r -> r Int
peval decls body env fenv = case decls of
  [] -> eval 100 body env fenv
  (fname, param, fbody) : rest ->
    let bindF f = ext fenv fname f
    in _lam2 (\f x -> eval 100 fbody (ext env param x) (bindF f))
             (\f -> peval rest body env (bindF f))
-- Earlier attempt, kept for reference:
--   where
--     f :: forall r . Ops r => r (Int -> Int)
--     f = _lam (\x -> eval 2 e1 (ext env s2 (unsafeCoerce x)) (ext fenv s1 f))
-- | Body of @fact x@: if @0 == x@ then 1, else @x * fact (x - 1)@.
factdef :: Expr
factdef = If (Eq (EInt 0) x) (EInt 1) (Mul x (App "fact" (Sub x (EInt 1))))
  where
    x = Var "x"

-- | One Collatz step: recurse on @x / 2@ when @x@ is even, else on @3x + 1@.
-- NOTE(review): the recursive result is returned unchanged, so together with
-- 'collatz_wrapper' the result is 0 whenever it terminates. Was a step count
-- intended?
collatz :: Expr
collatz =
  If (Eq (Mod x (EInt 2)) (EInt 0))
     (App "collatz" (Div x (EInt 2)))
     (App "collatz" (Add (Mul (EInt 3) x) (EInt 1)))
  where
    x = Var "x"

-- | Stop with 0 once @x@ reaches 1, otherwise take a 'collatz' step.
collatz_wrapper :: Expr
collatz_wrapper = If (Eq (Var "x") (EInt 1)) (EInt 0) collatz

-- | The declarations of the example program.
fact :: [Decl]
fact =
  [ ("fact", "x", factdef)
  , ("collatz", "x", collatz_wrapper)
  ]
-- | Main expression of the example program: @fact 5@.
e :: Expr
e = App "fact" (EInt 5)

-- | The example program interpreted in any 'Ops' interpretation.
evalFact :: Ops r => Env r -> FEnv r -> r Int
evalFact env fenv = peval fact e env fenv

-- | Direct evaluation of the example program.
main :: Identity Int
main = evalFact env0 fenv0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment