Created
July 16, 2014 15:52
-
-
Save christiaanb/231f18ad7daddb521f37 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-} | |
{-# OPTIONS_GHC -fno-full-laziness #-} | |
module Main where | |
import Bound | |
import Bound.Name | |
import Bound.Var | |
import Control.Applicative | |
import Control.Comonad | |
import Control.Monad | |
import Data.Foldable | |
import Data.Functor.Yoneda | |
import Data.Traversable | |
import Prelude.Extras | |
import Data.IORef | |
import System.IO.Unsafe | |
import Control.DeepSeq | |
import Debug.Trace | |
import Criterion.Main | |
import Criterion.Config | |
-- | Global mutable counter, starting at 0.
--
-- Created with 'unsafePerformIO'; the NOINLINE pragma keeps GHC from
-- duplicating the 'newIORef' call, so every use site sees the same reference.
{-# NOINLINE tcounter #-}
tcounter :: IORef Int
tcounter = unsafePerformIO $ newIORef 0
-- | Bump 'tcounter' by one from pure code and return the value it held
-- /before/ the increment.
--
-- Callers force the result (e.g. with 'seq') purely for the side effect of
-- counting a traversal step.
{-# NOINLINE incT #-}
incT :: () -> Int
incT () = unsafePerformIO $
  readIORef tcounter >>= \old ->
    writeIORef tcounter (old + 1) >> return old
-- | Reset 'tcounter' to 0 from pure code.
--
-- The write only happens when the returned @()@ is actually forced.
-- NOINLINE (matching 'incT') stops GHC from inlining the 'unsafePerformIO'
-- call into use sites, where the write could be duplicated or dropped.
{-# NOINLINE resetT #-}
resetT :: () -> ()
resetT () = unsafePerformIO $ writeIORef tcounter 0
-- | Read the current value of 'tcounter' from pure code.
--
-- NOINLINE (matching 'incT') stops GHC from inlining the 'unsafePerformIO'
-- call into use sites. NOTE(review): because this is a pure function of
-- @()@, the optimiser may still share the result of separate calls; prefer
-- 'readIORef' directly when already in 'IO'.
{-# NOINLINE readT #-}
readT :: () -> Int
readT () = unsafePerformIO $ readIORef tcounter
-- | The two kinds of binder in the term language. Both wrap a 'Name' that
-- pairs the bound variable's name @n@ with a payload @a@ (in 'Term' the
-- payload is the binder's type annotation); 'binder' retrieves that 'Name'
-- from either constructor.
data Binder n a
  = Lam { binder :: Name n a }  -- ^ Lambda abstraction binder
  | Pi  { binder :: Name n a }  -- ^ Dependent function type binder
  deriving (Functor, Show, Eq)
-- | Fully evaluate both the name and the payload of a 'Name'.
instance (NFData n, NFData a) => NFData (Name n a) where
  rnf (Name nm payload) = rnf nm `seq` rnf payload
-- | Fully evaluate the 'Name' carried by either kind of binder.
instance (NFData n, NFData a) => NFData (Binder n a) where
  rnf = rnf . binder
-- | 'extract' yields the payload stored in the binder's 'Name'; 'extend'
-- replaces that payload with the result of running the function on the
-- whole binder, keeping the constructor and the name.
instance Comonad (Binder n) where
  extract w = extract (binder w)
  extend f w = f w <$ w
-- | Terms of the language, with binder names of type @n@ and free variables
-- of type @a@. All fields except the universe level are strict.
data Term n a
  = Var !a
    -- ^ A variable
  | Universe Integer
    -- ^ A universe at the given level
  | App !(Term n a) !(Term n a)
    -- ^ Application of the first term to the second
  | Bind !(Binder n (Term n a)) !(Scope (Name n ()) (Term n) a)
    -- ^ A binder (annotated with a term) together with the scope it binds in
  deriving (Eq, Show)
-- Lifted equality and show for 'Term'. Both instances are empty and rely
-- on the default method implementations of the 'Eq1' / 'Show1' classes.
instance Eq1 (Term n)

instance Show n => Show1 (Term n)
-- | Fully evaluate whichever side of the 'Var' is present: the bound
-- variable ('B') or the free variable ('F').
instance (NFData b, NFData a) => NFData (Var b a) where
  rnf = unvar rnf rnf
-- | Fully evaluate a term, including binder annotations and the body
-- underneath each 'Scope'.
instance (NFData n, NFData a) => NFData (Term n a) where
  rnf (Var v)                  = rnf v
  rnf (Universe lvl)           = rnf lvl
  rnf (App fun arg)            = rnf fun `seq` rnf arg
  rnf (Bind bndr (Scope body)) = rnf bndr `seq` rnf body
-- | Map over the free variables of a term.
--
-- Every call first forces @incT ()@, so 'tcounter' records how many term
-- nodes were visited by 'fmap'.
instance Functor (Term n) where
  fmap f tm = incT () `seq` step tm
    where
      step (Var v)          = Var (f v)
      step (Universe lvl)   = Universe lvl
      step (App fun arg)    = App (f <$> fun) (f <$> arg)
      step (Bind bndr body) = Bind (fmap f <$> bndr) (f <$> body)
-- | 'pure' injects a variable; '<*>' is derived from the 'Monad' instance
-- (sequencing the function term before the argument term).
instance Applicative (Term n) where
  pure v = Var v
  tf <*> tx = do
    f <- tf
    x <- tx
    return (f x)
-- | Monadic bind is substitution for free variables, see 'bindTerm'.
instance Monad (Term n) where
  return v = Var v
  tm >>= k = bindTerm tm k
-- | Substitute every free variable of a term using the given function.
--
-- Every call first forces @incT ()@, so 'tcounter' records how many term
-- nodes were visited by substitution.
bindTerm :: Term n a -> (a -> Term n b) -> Term n b
bindTerm tm k = incT () `seq` step tm
  where
    -- Substitute inside a sub-term with the same substitution.
    subst t = bindTerm t k

    step (Var v)          = k v
    step (Universe lvl)   = Universe lvl
    step (App fun arg)    = App (subst fun) (subst arg)
    step (Bind bndr body) = Bind (subst <$> bndr) (body >>>= k)
-- | Take apart a Pi type into the binder name, the binder's annotation and
-- the scope it binds in. Any other term is rejected with an 'error'
-- that shows the offending term.
inferPi :: Eq a => Show n => Show a
        => Term n a
        -> (n, Term n a, Scope (Name n ()) (Term n) a)
inferPi (Bind (Pi bndr) body) = (name bndr, extract bndr, body)
inferPi other                 = error ("Function expected: " ++ show other)
-- | Return the level of a 'Universe' term. Any other term is rejected with
-- an 'error' that shows the offending term.
inferUniverse :: Eq a => Show n => Show a
              => Term n a
              -> Integer
inferUniverse (Universe lvl) = lvl
inferUniverse other          = error ("Type expected: " ++ show other)
-- | inferType1: many traversals using 'toScope' and 'fromScope'.
--
-- When going under a lambda the body is converted with 'fromScope', the
-- context is extended by mapping 'F' over every type it returns, and the
-- result is converted back with 'toScope'.
inferType1 :: Eq a => Show n => Show a
           => (a -> Term n a) -- ^ Context
           -> Term n a        -- ^ Term
           -> Term n a        -- ^ Inferred type
inferType1 ctx term = case term of
  Var v -> ctx v
  Universe lvl -> Universe (lvl + 1)
  App fun arg ->
    let argTy = inferType1 ctx arg
        (_, domTy, codomain) = inferPi (inferType1 ctx fun)
    in if domTy == argTy
         then instantiate1Name arg codomain
         else error "Mismatch"
  Bind (Pi bndr) body ->
    let domTy  = extract bndr
        domLvl = inferUniverse (inferType1 ctx domTy)
        codLvl = inferUniverse (inferType1 ctx (instantiate1Name domTy body))
    in Universe (max domLvl codLvl)
  Bind (Lam bndr) body ->
    let -- The bound variable has the binder's annotation as its type.
        boundCtx _ = F <$> extract bndr
        -- Free variables keep their outer type, weakened by one binder.
        freeCtx    = fmap F . ctx
        innerCtx   = unvar boundCtx freeCtx
    in Bind (Pi bndr) (toScope (inferType1 innerCtx (fromScope body)))
-- | inferType2: Only traversals in new context.
--
-- When going under a lambda the body is used as-is ('unscope' / 'Scope'),
-- so its free variables are whole outer terms; the extended context infers
-- the type of such a term in the outer context and then maps @F . Var@
-- over the result.
inferType2 :: Eq a => Show n => Show a
           => (a -> Term n a) -- ^ Context
           -> Term n a        -- ^ Term
           -> Term n a        -- ^ Inferred type
inferType2 ctx term = case term of
  Var v -> ctx v
  Universe lvl -> Universe (lvl + 1)
  App fun arg ->
    let argTy = inferType2 ctx arg
        (_, domTy, codomain) = inferPi (inferType2 ctx fun)
    in if domTy == argTy
         then instantiate1Name arg codomain
         else error "Mismatch"
  Bind (Pi bndr) body ->
    let domTy  = extract bndr
        domLvl = inferUniverse (inferType2 ctx domTy)
        codLvl = inferUniverse (inferType2 ctx (instantiate1Name domTy body))
    in Universe (max domLvl codLvl)
  Bind (Lam bndr) body ->
    let -- The bound variable has the binder's annotation as its type.
        boundCtx _ = (F . Var) <$> extract bndr
        -- A free "variable" is an outer term: infer its type out there.
        freeCtx    = fmap (F . Var) . inferType2 ctx
        innerCtx   = unvar boundCtx freeCtx
    in Bind (Pi bndr) (Scope (inferType2 innerCtx (unscope body)))
-- | A 'Term' wrapped in 'Yoneda', so that 'fmap's applied to it are
-- accumulated and only run over the term when 'lowerYoneda' is called.
type YTerm n a = Yoneda (Term n) a

-- | inferType3: delay traversal of the context, trying to get Yoneda to merge
-- the 'fmap (F . Var)' created by dist'. Hopefully better than 'inferType2'.
--
-- Under a lambda the context does not traverse a variable's type at all:
-- it returns the type wrapped as @Var (F ty)@. Whenever an inferred type
-- has to be inspected (compared, or taken apart with 'inferPi' /
-- 'inferUniverse') it is first passed through @lowerYoneda . dist@, which
-- unwraps such a delayed type and applies the pending @fmap (F . Var)@s.
--
-- At the top level, pass 'liftYoneda' as the distribution function.
inferType3 :: Eq a => Show n => Show a
           => (a -> Term n a)          -- ^ Context
           -> (Term n a -> YTerm n a)  -- ^ Quotient and distribute (Var . F)
           -> Term n a                 -- ^ Term
           -> Term n a                 -- ^ Inferred type
inferType3 ctx dist term = case term of
  Var v -> ctx v
  Universe lvl -> Universe (lvl + 1)
  App fun arg ->
    let argTy = normalType arg
        (_, domTy, codomain) = inferPi (normalType fun)
    in if domTy == argTy
         then instantiate1Name arg codomain
         else error "Mismatch"
  Bind (Pi bndr) body ->
    -- The inferred types are inspected by 'inferUniverse', so they must be
    -- normalised first (as in the 'App' case); otherwise a delayed
    -- @Var (F ty)@ coming from the context is rejected as "Type expected".
    let domTy  = extract bndr
        domLvl = inferUniverse (normalType domTy)
        codLvl = inferUniverse (normalType (instantiate1Name domTy body))
    in Universe (max domLvl codLvl)
  Bind (Lam bndr) body ->
    let -- The bound variable's type is the annotation, left untraversed.
        boundCtx _ = Var (F (extract bndr))
        -- A free "variable" is an outer term: infer its type out there and
        -- wrap it without traversing it.
        freeCtx    = Var . F . inferType3 ctx dist
        innerCtx   = unvar boundCtx freeCtx
        -- Unwrap a delayed type and schedule the weakening on the Yoneda.
        innerDist (Var (F outer)) = fmap (F . Var) (dist outer)
        innerDist other           = liftYoneda other
    in Bind (Pi bndr) (Scope (inferType3 innerCtx innerDist (unscope body)))
  where
    -- Infer a type and force it into inspectable form.
    normalType e = lowerYoneda (dist (inferType3 ctx dist e))
-- Example:

-- | Variable names used by the example terms.
type LVar = String

-- | Example terms: binder names and free variables are both 'LVar's.
type CoreTerm = Term LVar LVar
-- | The universe at level 0, used as the base type in the example terms.
uni :: CoreTerm
uni = Universe 0
-- | Build a lambda: the pair gives the bound variable's name and its type
-- annotation; occurrences of that name in the body are abstracted.
lam :: (LVar,CoreTerm) -> CoreTerm -> CoreTerm
lam (var, ty) body = Bind (Lam annotated) scoped
  where
    annotated = Name var ty
    scoped    = abstract1Name var body
-- | Build a Pi type: the pair gives the bound variable's name and its type
-- annotation; occurrences of that name in the body are abstracted.
pi_ :: (LVar,CoreTerm) -> CoreTerm -> CoreTerm
pi_ (var, ty) body = Bind (Pi annotated) scoped
  where
    annotated = Name var ty
    scoped    = abstract1Name var body
-- | The benchmark term: 23 nested lambdas around @y (k (f x))@.
--
-- The binders are listed outermost first and wrapped around the body with
-- 'lam'. The name @"g"@ is bound twice; the inner binder shadows the outer
-- one. NOTE(review): the duplicate looks like a copy/paste leftover, it is
-- kept so the term stays exactly the same.
dollar :: CoreTerm
dollar = foldr lam body binders
  where
    -- @g a c@, used as the domain of the @f@ family and as the type of @x@.
    gac :: CoreTerm
    gac = App (App (Var "g") (Var "a")) (Var "c")

    body :: CoreTerm
    body = App (Var "y") (App (Var "k") (App (Var "f") (Var "x")))

    binders :: [(LVar, CoreTerm)]
    binders =
      [ ("p", uni)
      , ("q", uni)
      , ("m", pi_ ("_", uni) uni)
      , ("a", uni)
      , ("b", uni)
      , ("c", pi_ ("_", uni) uni)
      , ("d", pi_ ("_", uni) (pi_ ("_", uni) uni))
      , ("e", pi_ ("_", uni) uni)
      , ("g", pi_ ("_", uni) (pi_ ("_", uni) (pi_ ("_", uni) uni)))
      , ("h", pi_ ("_", uni) uni)
      , ("g", pi_ ("_", uni) (pi_ ("_", uni) (pi_ ("_", uni) uni)))
      , ("f", pi_ ("_", gac) (Var "b"))
      , ("f1", pi_ ("_", gac) (Var "b"))
      , ("f2", pi_ ("_", gac) (Var "b"))
      , ("f3", pi_ ("_", gac) (Var "b"))
      , ("f4", pi_ ("_", gac) (Var "b"))
      , ("f5", pi_ ("_", gac) (Var "b"))
      , ("f6", pi_ ("_", gac) (Var "b"))
      , ("f7", pi_ ("_", gac) (Var "b"))
      , ("f8", pi_ ("_", gac) (Var "b"))
      , ("k", pi_ ("_", Var "b") (Var "e"))
      , ("y", pi_ ("_", Var "e") (Var "h"))
      , ("x", gac)
      ]
-- | Benchmark the three inference functions on 'dollar', fully evaluating
-- each inferred type. The context is never consulted with a meaningful
-- answer: it maps every free variable to 'undefined'.
main :: IO ()
main = defaultMainWith config (return ()) [bcompare benchmarks]
  where
    -- Run a GC between samples and take 500 samples per benchmark.
    config = defaultConfig
      { cfgPerformGC = ljust True
      , cfgSamples   = ljust 500
      }

    benchmarks =
      [ bench "inferType1" (nf (inferType1 (const undefined)) dollar)
      , bench "inferType2" (nf (inferType2 (const undefined)) dollar)
      , bench "inferType3" (nf (inferType3 (const undefined) liftYoneda) dollar)
      ]
-- main :: IO () | |
-- main = do | |
-- writeIORef tcounter 0 | |
-- let z = dollar | |
-- d1 <- z `deepseq` readIORef tcounter | |
-- putStrLn ("dollar(" ++ show d1 ++ "): " ++ show z) | |
-- writeIORef tcounter 0 | |
-- let ty1 = inferType1 (const undefined) z | |
-- d2 <- ty1 `deepseq` readIORef tcounter | |
-- putStrLn ("inferType1(" ++ show d2 ++ "): " ++ show ty1) | |
-- writeIORef tcounter 0 | |
-- let ty2 = inferType2 (const undefined) z | |
-- d3 <- ty2 `deepseq` readIORef tcounter | |
-- putStrLn ("inferType2(" ++ show d3 ++ "): " ++ show ty2) | |
-- writeIORef tcounter 0 | |
-- let ty3 = inferType3 (const undefined) liftYoneda z | |
-- d4 <- ty3 `deepseq` readIORef tcounter | |
-- putStrLn ("inferType3(" ++ show d4 ++ "): " ++ show ty3) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment