Skip to content

Instantly share code, notes, and snippets.

@mpickering
Created April 19, 2018 16:54
Show Gist options
  • Save mpickering/b6935a63717129cc33b77774c126168a to your computer and use it in GitHub Desktop.
Save mpickering/b6935a63717129cc33b77774c126168a to your computer and use it in GitHub Desktop.
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
module LMS where
import Data.Generics.Product
import GHC.Generics
import Control.Lens (view)
import Debug.Trace
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Unsafe.Coerce
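{-
The essence of LMS (Lightweight Modular Staging) is:
  1) Virtualised syntax: built-in constructs such as conditionals become
     overloadable operations.
  2) Implicit lifting: ordinary values are promoted to staged values.
Here is how we might embed these two ideas in Haskell.
-}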
-- | Signature of the object language.
--
-- @r@ is the representation: @r a@ stands for an object-language term of
-- type @a@. Each interpretation in this module is an instance of this class.
class Ops r where
  -- Control flow and binding ------------------------------------------------

  -- | Object-level conditional.
  _if :: r Bool -> r a -> r a -> r a
  -- | Bind a term and continue with it.
  _let :: r a -> (r a -> r b) -> r b
  -- | Fixed point of a function over representations.
  _fix :: (r a -> r a) -> r a

  -- Functions ---------------------------------------------------------------

  -- | Object-level application.
  _app :: r (a -> b) -> r a -> r b
  -- | Object-level abstraction.
  _lam :: (r a -> r b) -> r (a -> b)
  -- | Recursive definition of an @Int -> Int@ function. The first argument
  -- receives the function being defined together with its parameter; the
  -- second argument is the continuation that uses the defined function.
  _lam2 :: (r (Int -> Int) -> r Int -> r Int) -> (r (Int -> Int) -> r Int) -> r Int

  -- Literals ----------------------------------------------------------------

  _int :: Int -> r Int
  _string :: String -> r String
  -- | Embed any liftable Haskell value.
  _quote :: Lift a => a -> r a

  -- Arithmetic and comparison -----------------------------------------------

  _plus :: r Int -> r Int -> r Int
  _minus :: r Int -> r Int -> r Int
  _mul :: r Int -> r Int -> r Int
  _div :: r Int -> r Int -> r Int
  _mod :: r Int -> r Int -> r Int
  _eq :: Eq a => r a -> r a -> r Bool
-- | Promote an already-evaluated integer into any representation.
liftInt :: Ops r => Identity Int -> r Int
liftInt (Identity n) = _int n
-- | Deep embedding of the object language: one constructor per 'Ops' method,
-- with binders represented by Haskell functions.
data Syn a where
  -- Control flow and binding
  If_ :: Syn Bool -> Syn a -> Syn a -> Syn a
  Let_ :: Syn a -> (Syn a -> Syn b) -> Syn b
  Fix_ :: (Syn a -> Syn a) -> Syn a
  -- Functions
  App_ :: Syn (a -> b) -> Syn a -> Syn b
  Lam_ :: (Syn a -> Syn b) -> Syn (a -> b)
  Lam2_ :: (Syn (Int -> Int) -> Syn Int -> Syn Int) -> (Syn (Int -> Int) -> Syn Int) -> Syn Int
  -- Literals
  Int_ :: Int -> Syn Int
  String_ :: String -> Syn String
  -- | Note: unlike '_quote', this constructor carries no @Lift@ constraint.
  Quote_ :: a -> Syn a
  -- Arithmetic and comparison
  Plus_ :: Syn Int -> Syn Int -> Syn Int
  Minus_ :: Syn Int -> Syn Int -> Syn Int
  Mul_ :: Syn Int -> Syn Int -> Syn Int
  Div_ :: Syn Int -> Syn Int -> Syn Int
  Mod_ :: Syn Int -> Syn Int -> Syn Int
  Eq_ :: Eq a => Syn a -> Syn a -> Syn Bool
-- | Trivial wrapper used as the "evaluate immediately" representation.
newtype Identity a = Identity a
  deriving (Generic, Show)

instance Functor Identity where
  fmap f wrapped = case wrapped of
    Identity a -> Identity (f a)

instance Applicative Identity where
  pure = Identity
  -- Applying a wrapped function is just mapping it over the wrapped argument.
  Identity g <*> wrapped = fmap g wrapped
-- | Remove the 'Identity' wrapper.
runIdentity :: Identity a -> a
runIdentity (Identity a) = a
-- | Remove the 'TTExp' wrapper, exposing the typed code generator.
unwrap :: TTExp a -> Q (TExp a)
unwrap (TTExp code) = code
-- | Direct interpretation: every object-language operation is carried out by
-- the corresponding Haskell operation.
instance Ops Identity where
  -- Only the selected branch is ever demanded.
  _if cond onTrue onFalse = if runIdentity cond then onTrue else onFalse
  _app = (<*>)
  _lam body = Identity (runIdentity . body . Identity)
  _lam2 defn body = body (Identity self)
    where
      -- The function being defined is handed to its own definition so that
      -- it can call itself.
      self :: Int -> Int
      self arg = runIdentity (defn (Identity self) (Identity arg))
  _int = Identity
  _string = Identity
  _mul lhs rhs = (*) <$> lhs <*> rhs
  _plus lhs rhs = (+) <$> lhs <*> rhs
  _minus lhs rhs = (-) <$> lhs <*> rhs
  _eq lhs rhs = (==) <$> lhs <*> rhs
  _mod (Identity a) (Identity b) = Identity (a `mod` b)
  _div (Identity a) (Identity b) = Identity (a `div` b)
  _let bound body = body bound
  _fix = fix
  _quote = Identity
-- | Syntactic interpretation: each operation builds the matching 'Syn'
-- constructor.
instance Ops Syn where
  -- Control flow and binding
  _if = If_
  _let = Let_
  _fix = Fix_
  -- Functions
  _app = App_
  _lam = Lam_
  _lam2 = Lam2_
  -- Literals
  _int = Int_
  _string = String_
  _quote = Quote_
  -- Arithmetic and comparison
  _plus = Plus_
  _minus = Minus_
  _mul = Mul_
  _div = Div_
  _mod = Mod_
  _eq = Eq_
-- | Staged representation: a term of type @a@ is a typed Template Haskell
-- code generator for an expression of type @a@.
newtype TTExp a = TTExp (Q (TExp a)) deriving Generic
-- Disabled 'Functor' instance, kept for reference:
--instance Functor TTExp where
-- fmap f (TTExp te) = TTExp $ [|| f $$(te) ||]
-- | Lift a value to typed code by lifting it untyped and then asserting the
-- type with 'unsafeTExpCoerce'.
liftT :: Lift a => a -> TExpQ a
liftT value = unsafeTExpCoerce (lift value)
-- | Least fixed point of @f@, tied through a single shared binding.
fix :: (a -> a) -> a
fix f = knot
  where
    knot = f knot
-- | Staged interpretation: every operation emits the typed Template Haskell
-- code for the corresponding Haskell expression.
instance Ops TTExp where
  _if cond onTrue onFalse = TTExp $ do
    -- Print "IF" whenever the generator for a conditional is run.
    qRunIO (putStrLn "IF")
    [|| if $$(unwrap cond) then $$(unwrap onTrue) else $$(unwrap onFalse) ||]
  _app fun arg = TTExp [|| $$(unwrap fun) $$(unwrap arg) ||]
  _lam body = TTExp [|| \a -> $$(unwrap (body (TTExp [|| a ||]))) ||]
  -- Emits a local recursive @let@; both the definition and the continuation
  -- see the bound function @f@.
  _lam2 defn body =
    TTExp
      [||
        let f x = $$(unwrap (defn (TTExp [|| f ||]) (TTExp [|| x ||])))
         in $$(unwrap (body (TTExp [|| f ||])))
      ||]
  _int n = TTExp (liftT n)
  _string s = TTExp (liftT s)
  _mul lhs rhs = TTExp [|| $$(unwrap lhs) * $$(unwrap rhs) ||]
  _plus lhs rhs = TTExp [|| $$(unwrap lhs) + $$(unwrap rhs) ||]
  _minus lhs rhs = TTExp [|| $$(unwrap lhs) - $$(unwrap rhs) ||]
  _div lhs rhs = TTExp [|| $$(unwrap lhs) `div` $$(unwrap rhs) ||]
  _mod lhs rhs = TTExp [|| $$(unwrap lhs) `mod` $$(unwrap rhs) ||]
  _eq lhs rhs = TTExp [|| $$(unwrap lhs) == $$(unwrap rhs) ||]
  -- Binding and fixed points are resolved at generation time; no code is
  -- emitted for them.
  _let bound body = body bound
  _fix = fix
  _quote value = TTExp [|| value ||]
-- | Representation transformer that performs the rewrite
--
-- > if (if cond then 0 else 1) == 0   ==>   if cond
--
-- on top of any underlying representation @p@.
data CollapseIf p a where
  -- | An opaque term of the underlying representation.
  Dyn :: p a -> CollapseIf p a
  -- | A statically known integer literal.
  Num :: Int -> CollapseIf p Int
  -- | Stands for @if c then 0 else 1@.
  CIf :: p Bool -> CollapseIf p Int
  -- | Stands for @(if c then 0 else 1) == 0@.
  CEq :: p Bool -> CollapseIf p Bool
-- | Translate a 'CollapseIf' term back into the underlying representation,
-- expanding each constructor to the term it stands for.
lowerCIf :: Ops p => CollapseIf p a -> p a
lowerCIf term = case term of
  Dyn pa -> pa
  Num n -> _int n
  CIf pb -> zeroOrOne pb
  CEq pb -> zeroOrOne pb `_eq` _int 0
  where
    -- The conditional shared by 'CIf' and 'CEq': 0 when true, 1 otherwise.
    zeroOrOne :: Ops r => r Bool -> r Int
    zeroOrOne cond = _if cond (_int 0) (_int 1)
-- | Embed a term of the underlying representation as an opaque 'Dyn' node.
liftCIf :: p a -> CollapseIf p a
liftCIf underlying = Dyn underlying
-- | Interpretation that recognises the collapsible conditional shapes and
-- delegates everything else to the underlying representation.
instance Ops p => Ops (CollapseIf p) where
  -- The order of the '_if' clauses matters: the branch shape is examined
  -- before the condition.
  _if cond (Num 0) (Num 1) = CIf (lowerCIf cond)
  -- The collapse itself: a 'CEq' condition is replaced by its inner test.
  _if (CEq inner) onTrue onFalse = Dyn $ _if inner (lowerCIf onTrue) (lowerCIf onFalse)
  _if cond onTrue onFalse = Dyn $ _if (lowerCIf cond) (lowerCIf onTrue) (lowerCIf onFalse)

  -- Literals stay static so that the patterns above can match on them.
  _int = Num

  _eq (CIf inner) (Num 0) = CEq inner
  _eq lhs rhs = Dyn $ _eq (lowerCIf lhs) (lowerCIf rhs)

  _app fun arg = Dyn $ _app (lowerCIf fun) (lowerCIf arg)
  _lam body = Dyn $ _lam (\arg -> lowerCIf (body (Dyn arg)))
  _lam2 defn body =
    Dyn $
      _lam2
        (\self arg -> lowerCIf (defn (Dyn self) (Dyn arg)))
        (\self -> lowerCIf (body (Dyn self)))
  _string = Dyn . _string
  _mul lhs rhs = Dyn $ _mul (lowerCIf lhs) (lowerCIf rhs)
  _mod lhs rhs = Dyn $ _mod (lowerCIf lhs) (lowerCIf rhs)
  _div lhs rhs = Dyn $ _div (lowerCIf lhs) (lowerCIf rhs)
  _minus lhs rhs = Dyn $ _minus (lowerCIf lhs) (lowerCIf rhs)
  _plus lhs rhs = Dyn $ _plus (lowerCIf lhs) (lowerCIf rhs)
  _let bound body = body bound
  _fix f = fix f
  _quote value = Dyn (_quote value)
-- | @power n k@ builds the product of @n@ copies of @k@, ending in the
-- literal 1.
--
-- The exponent is an ordinary Haskell 'Int', so the zero test and the
-- recursion happen in Haskell; only the multiplications are object-language
-- operations.
power :: Ops r => Int -> r Int -> r Int
power n k
  | n == 0 = _int 1
  | otherwise = k `_mul` power (n - 1) k
{- Wrong, because the condition would not be evaluated statically even though
 - it can be:
    _if (_eq (_int n) (_int 0))
        (_int 1)
        (_mul k (power (n-1) k))
-}
-- | Curried object-level equality on integers.
_lamEq :: Ops r => r (Int -> Int -> Bool)
_lamEq = _lam $ \lhs -> _lam $ \rhs -> lhs `_eq` rhs
-- | Curried object-level multiplication on integers.
_lamMul :: Ops r => r (Int -> Int -> Int)
_lamMul = _lam $ \lhs -> _lam $ \rhs -> lhs `_mul` rhs
-- | Even more overloading: application.
--
-- Like 'power', but the zero test and the multiplication go through '_if',
-- '_app', '_lamEq' and '_lamMul'. The recursive call is not guarded by any
-- Haskell-level test, so whether evaluation stops depends on how the chosen
-- representation's '_if' treats its branches.
power2 :: Ops r => Int -> r Int -> r Int
power2 n k = _if isZero (_int 1) recurse
  where
    isZero = _app (_app _lamEq (_int n)) (_int 0)
    recurse = _app (_app _lamMul k) (power2 (n - 1) k)
-- | Even more overloading: abstraction.
--
-- Wraps the base in an object-level lambda. The body has the same shape as
-- 'power2', and the recursive step calls 'power2', not 'power3'.
power3 :: Ops r => Int -> r (Int -> Int)
power3 n = _lam body
  where
    body k = _if isZero (_int 1) (_app (_app _lamMul k) (power2 (n - 1) k))
    isZero = _app (_app _lamEq (_int n)) (_int 0)
-- | Sample run: 'powerIden' applied to 5 and 4.
test :: Int
test = powerIden 5 4

-- | 'power' under the 'Identity' interpretation. The first argument is passed
-- to 'power' as the exponent, the second as the base.
powerIden :: Int -> Int -> Int
powerIden k = runIdentity . power k . Identity

-- | 'power2' under the 'Identity' interpretation.
power2Iden :: Int -> Int -> Int
power2Iden k = runIdentity . power2 k . Identity

-- | 'power3' under the 'Identity' interpretation, applied to a literal.
power3Iden :: Int -> Int -> Int
power3Iden k n = runIdentity (power3 k `_app` _int n)
-- | 'power3' specialised to the staged representation.
power3Staged :: Int -> TTExp (Int -> Int)
power3Staged = power3

-- | 'power' specialised to the staged representation.
power1Staged :: Int -> TTExp Int -> TTExp Int
power1Staged = power

-- | 'power' specialised to the syntactic representation.
powerSyn :: Int -> Syn Int -> Syn Int
powerSyn = power

-- | 'power2' specialised to the syntactic representation.
power2Syn :: Int -> Syn Int -> Syn Int
power2Syn = power2
-- | Abstract syntax of the small first-order language handled by the
-- evaluators in this module.
data Expr
  = EInt Int          -- ^ integer literal
  | Var String        -- ^ variable reference
  | App String Expr   -- ^ call of a named function with one argument
  | Add Expr Expr
  | Sub Expr Expr
  | Mul Expr Expr
  | If Expr Expr Expr -- ^ condition, then-branch, else-branch
  | Eq Expr Expr
  | Mod Expr Expr
  | Div Expr Expr
-- | An environment: a lookup function from names to terms of type @a@ in
-- representation @r@.
data REnv r a = REnv (String -> r a)

-- | Look a name up in an environment.
getRenv :: Ops r => REnv r a -> String -> r a
getRenv (REnv look) name = look name

-- | Environment of integer-valued variables.
type Env r = REnv r Int

-- | Environment of named @Int -> Int@ functions.
type FEnv r = REnv r (Int -> Int)
-- | Evaluate an expression directly under the 'Identity' interpretation.
--
-- @env@ resolves variables and @fenv@ resolves function names.
--
-- Every 'Expr' constructor is handled. 'Eq', 'Mod' and 'Div' previously had
-- no equation and caused a pattern-match failure; they now follow the same
-- conventions as 'eval'.
evalUnstaged :: Expr -> Env Identity -> FEnv Identity -> Identity Int
evalUnstaged expr env fenv = case expr of
  EInt n -> _int n
  Var s -> getRenv env s
  App s arg -> _app (getRenv fenv s) (go arg)
  Add a b -> go a `_plus` go b
  Sub a b -> go a `_minus` go b
  Mul a b -> go a `_mul` go b
  -- The then-branch is taken when the condition evaluates to 0.
  If c t f -> _if (go c `_eq` _int 0) (go t) (go f)
  -- Equality yields 0 when the operands are equal and 1 otherwise, which
  -- matches the convention used by 'If' above.
  Eq a b -> _if (go a `_eq` go b) (_int 0) (_int 1)
  Mod a b -> go a `_mod` go b
  Div a b -> go a `_div` go b
  where
    -- Evaluate a sub-expression in the same environments.
    go sub = evalUnstaged sub env fenv
-- | Generate typed Template Haskell code that computes the value of an
-- expression.
--
-- @env@ supplies the code for each variable and @fenv@ the code for each
-- named function.
--
-- Every 'Expr' constructor is handled. 'Eq', 'Mod' and 'Div' previously had
-- no equation and caused a pattern-match failure during code generation;
-- they now follow the same conventions as 'eval'.
evalStaged :: Expr -> EnvStaged -> FEnvStaged -> Q (TExp Int)
evalStaged expr env fenv = case expr of
  EInt n -> liftT n
  Var s -> env s
  App s arg -> [|| $$(fenv s) $$(go arg) ||]
  Add a b -> [|| $$(go a) + $$(go b) ||]
  Sub a b -> [|| $$(go a) - $$(go b) ||]
  Mul a b -> [|| $$(go a) * $$(go b) ||]
  -- The then-branch is taken when the condition evaluates to 0.
  If c t f -> [|| if $$(go c) == 0 then $$(go t) else $$(go f) ||]
  -- Equality yields 0 when the operands are equal and 1 otherwise, which
  -- matches the convention used by 'If' above.
  Eq a b -> [|| if $$(go a) == $$(go b) then 0 else 1 ||]
  Mod a b -> [|| $$(go a) `mod` $$(go b) ||]
  Div a b -> [|| $$(go a) `div` $$(go b) ||]
  where
    -- Generate code for a sub-expression in the same environments.
    go sub = evalStaged sub env fenv
--test :: Ops r => (Int -> r Int) -> r (
-- | Generate code for a program: a list of function declarations followed by
-- a main expression.
--
-- Each declaration becomes a local recursive @let@ in the generated code. Its
-- body sees its own parameter and itself, and the function is added to the
-- function environment for the declarations and main expression that follow.
pevalStaged :: [Decl] -> Expr -> EnvStaged -> FEnvStaged -> Q (TExp Int)
pevalStaged decls e env fenv = case decls of
  [] -> evalStaged e env fenv
  (fname, param, body) : rest ->
    [||
      let f xx = $$(evalStaged body (ext2 env param [|| xx ||]) (ext2 fenv fname [|| f ||]))
       in $$(pevalStaged rest e env (ext2 fenv fname [|| f ||]))
    ||]
-- | Staged variable environment: maps a variable name to the code for it.
type EnvStaged = String -> Q (TExp Int)
-- | Staged function environment: maps a function name to the code for it.
type FEnvStaged = String -> Q (TExp (Int -> Int))
-- | Extend a lookup function: @key@ now maps to @val@, and every other
-- argument is passed on to @env@.
ext2 :: Eq k => (k -> v) -> k -> v -> k -> v
ext2 env key val x
  | x == key = val
  | otherwise = env x
-- | Empty variable lookup: every name fails with @"undef"@.
env0Staged :: a -> b
env0Staged _ = error "undef"

-- | Empty function lookup: every name fails with @"undef f"@.
fenv0Staged :: a -> b
fenv0Staged _ = error "undef f"
-- | Code for the sample program, produced by the hand-written staged
-- evaluator 'pevalStaged'.
example :: Q (TExp Int)
example = pevalStaged fact e env0Staged fenv0Staged

-- | Code for the sample program, produced by the generic evaluator at the
-- 'TTExp' representation.
example2 :: Q (TExp Int)
example2 = unwrap (evalFact env0 fenv0)

-- | As 'example2', but the program is first interpreted in 'CollapseIf' and
-- then lowered to 'TTExp'.
example3 :: Q (TExp Int)
example3 = unwrap (lowerCIf (evalFact env0 fenv0))
-- | Representation-generic evaluator with a fuel bound.
--
-- The fuel is checked before the expression is inspected: when it is exactly
-- 0 the result is the literal 0, whatever the expression is. Every recursive
-- call on a sub-expression uses one unit of fuel.
--
-- @env@ resolves variables and @fenv@ resolves function names.
eval :: forall r . Ops r => Int -> Expr -> Env r -> FEnv r -> r Int
eval fuel expr env fenv
  | fuel == 0 = _int 0
  | otherwise = case expr of
      EInt n -> _int n
      Var s -> getRenv env s
      App s arg -> _app (getRenv fenv s) (sub arg)
      Add a b -> sub a `_plus` sub b
      Sub a b -> sub a `_minus` sub b
      Mul a b -> sub a `_mul` sub b
      -- The then-branch is taken when the condition evaluates to 0.
      If c t f -> _if (sub c `_eq` _int 0) (sub t) (sub f)
      -- Equality yields 0 when the operands are equal and 1 otherwise.
      Eq a b -> _if (sub a `_eq` sub b) (_int 0) (_int 1)
      Mod a b -> sub a `_mod` sub b
      Div a b -> sub a `_div` sub b
  where
    -- Evaluate a sub-expression with one unit less fuel.
    sub ex = eval (fuel - 1) ex env fenv
-- | A function declaration: (function name, parameter name, body).
type Decl = (String, String, Expr)
-- | Extend an environment: @key@ now maps to @v@, and every other name is
-- looked up in the original environment.
ext :: REnv r a -> String -> r a -> REnv r a
ext (REnv older) key v = REnv look
  where
    look y
      | key == y = v
      | otherwise = older y
-- | Empty variable environment: every lookup fails with @"undef"@.
env0 :: REnv r a
env0 = REnv (\_ -> error "undef")

-- | Empty function environment: every lookup fails with @"undef f"@.
fenv0 :: REnv r a
fenv0 = REnv (\_ -> error "undef f")
-- | Evaluate a program (function declarations followed by a main expression)
-- in any representation.
--
-- Each declaration is bound with '_lam2'. Its body is evaluated with fuel 100
-- and sees both its parameter and the function itself. The remaining
-- declarations and the main expression see the function through the extended
-- function environment.
peval :: Ops r => [Decl] -> Expr -> Env r -> FEnv r -> r Int
peval decls e env fenv = case decls of
  [] -> eval 100 e env fenv
  (fname, param, body) : rest ->
    let define self arg = eval 100 body (ext env param arg) (ext fenv fname self)
        continue self = peval rest e env (ext fenv fname self)
     in _lam2 define continue
-- Earlier attempt, kept for reference:
-- where
-- f :: forall r . Ops r => r (Int -> Int)
-- f = _lam (\x -> eval 2 e1 (ext env s2 (unsafeCoerce x)) (ext fenv s1 f))
-- | Body of @fact@: 1 when @x@ equals 0, otherwise @x * fact (x - 1)@.
factdef :: Expr
factdef = If (Eq (EInt 0) x) (EInt 1) (Mul x (App "fact" (Sub x (EInt 1))))
  where
    x = Var "x"

-- | One Collatz step: call @collatz@ on @x / 2@ when @x mod 2@ equals 0,
-- otherwise on @3 * x + 1@.
collatz :: Expr
collatz = If (Eq (Mod x (EInt 2)) (EInt 0)) onEven onOdd
  where
    x = Var "x"
    onEven = App "collatz" (Div x (EInt 2))
    onOdd = App "collatz" (Add (Mul (EInt 3) x) (EInt 1))

-- | Body of @collatz@: 0 when @x@ equals 1, otherwise one 'collatz' step.
collatz_wrapper :: Expr
collatz_wrapper = If (Eq (Var "x") (EInt 1)) (EInt 0) collatz
-- | Declarations of the sample program.
fact :: [Decl]
fact =
  [ ("fact", "x", factdef)
  , ("collatz", "x", collatz_wrapper)
  ]

-- | Main expression of the sample program: @fact 5@.
e :: Expr
e = App "fact" (EInt 5)
-- | The sample program ('fact' with main expression 'e') evaluated in any
-- representation, given the initial environments.
evalFact :: Ops r => Env r -> FEnv r -> r Int
evalFact env fenv = peval fact e env fenv
-- | The sample program evaluated directly, starting from the empty
-- environments.
main :: Identity Int
main = evalFact env0 fenv0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment