Last active
June 14, 2023 20:52
-
-
Save aradarbel10/9f5259f4b52b9dae804483517a1cf868 to your computer and use it in GitHub Desktop.
Type-directed program synthesis for the simply typed lambda calculus (STLC)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} | |
{-# HLINT ignore "Use list comprehension" #-} | |
import Data.IORef ( IORef, newIORef, readIORef, writeIORef ) | |
import GHC.IO ( unsafePerformIO ) | |
import Control.Applicative ( Alternative(..) ) | |
import Control.Monad ( when ) | |
import Debug.Trace | |
--- global fresh name source --- | |
-- names are plain strings. they are used for term variables (`Var`, `Lam`),
-- for context entries, and for base types (`Base`).
type Name = String | |
-- the single global counter behind `nexti`, starting at 0.
-- it is created with `unsafePerformIO`, so it must never be inlined: inlining
-- would duplicate the `newIORef` call and yield several independent counters.
{-# NOINLINE freshi #-}
freshi :: IORef Int
freshi = unsafePerformIO (newIORef 0)
-- return the current value of the global counter `freshi` and bump it by one.
-- the read and the write are two separate `IORef` operations.
-- the unit argument is ignored; it only makes this look like a function call.
nexti :: () -> IO Int
nexti () = do
  current <- readIORef freshi
  current <$ writeIORef freshi (current + 1)
-- append the next value of the global counter to the given base name,
-- e.g. "x" becomes "x0", "x1", ... on successive (distinct) evaluations.
--
-- NOTE: this hides a side effect behind a pure type. two precautions are taken,
-- following the GHC documentation for `unsafePerformIO`:
--   * the function is marked NOINLINE, so the hidden action is not duplicated
--     by inlining;
--   * the whole result is built *inside* the action, so the action depends on
--     `str` and cannot be floated out of this function and shared by all calls.
-- a call site with a constant argument (such as `freshen "x"`) may still be
-- shared or floated by the optimiser; compile with -fno-cse and
-- -fno-full-laziness if every call must produce a distinct name.
freshen :: Name -> Name
freshen str = unsafePerformIO $ do
  i <- nexti ()
  return (str ++ show i)
{-# NOINLINE freshen #-}
--- nondeterministic computations --- | |
-- a nondeterministic computation is modelled as the list of all its results,
-- in the order they are found. callers such as `runExample` only take a prefix.
type Nondet a = [a] | |
--- language description --- | |
-- types of the object language.
data Ty
  = Base Name    -- an atomic type, identified by its name
  | Arrow Ty Ty  -- function type: argument type, then result type
  | Prod Ty Ty   -- pair type: type of the first component, then of the second
  deriving (Eq, Show)
-- values: the introduction forms, plus neutral terms embedded via `Neut`.
data Vl
  = Neut Ne      -- a neutral term used as a value
  | Lam Name Vl  -- lambda abstraction: bound variable and body
  | Pair Vl Vl   -- pair of two values
-- neutral terms: a variable with a chain of eliminators applied to it.
data Ne
  = Var Name   -- a variable
  | App Ne Vl  -- application of a neutral to an argument value
  | Fst Ne     -- first projection of a neutral
  | Snd Ne     -- second projection of a neutral
--- pretty printing --- | |
-- every value is printed inside its own pair of parentheses.
instance Show Vl where
  show :: Vl -> String
  show val = case val of
    Neut ne -> concat ["(", show ne, ")"]
    Lam binder body -> concat ["(\\", binder, ". ", show body, ")"]
    Pair l r -> concat ["(", show l, ", ", show r, ")"]
-- neutrals are printed as a left-to-right spine: the head first, followed by
-- each argument or projection (" .1" / " .2") in the order it is applied.
instance Show Ne where
  show :: Ne -> String
  show ne = case ne of
    Var name -> name
    App fn arg -> unwords [show fn, show arg]
    Fst pr -> show pr ++ " .1"
    Snd pr -> show pr ++ " .2"
--- typing contexts --- | |
-- a typing context: an association list from variable names to their types.
type Ctx = [(Name, Ty)] | |
-- extend a context with one more binding. the new binding goes to the front,
-- so it is the first entry `search` looks at.
assume :: Ctx -> Name -> Ty -> Ctx
assume bindings name ty = (name, ty) : bindings
--- eliminator shapes --- | |
-- this just keeps track of the shape, not the eliminated term itself, | |
-- but once we have the shape we can very easily apply concrete eliminators later on. | |
-- note the *head* of the shape is the eliminator applied *first*. | |
data Elim
  = EApp Ty  -- apply to an argument of the given type
  | EFst     -- take the first projection
  | ESnd     -- take the second projection
  deriving Show
-- a sequence of eliminators; the head of the list is applied first.
type Shape = [Elim] | |
--- program synthesis --- | |
-- this is the main part of the code, implementing type-directed program synthesis, or equivalently | |
-- type directed proof search. it's based on two functions, `synth` and `search`, which roughly follow | |
-- the same pattern as bidirectional typechecking but are relationally dual. | |
-- `synth` takes a context and a goal type, and tries to build an introduction form of that type.
-- it is the program-synthesis analogue of bidi's `check`.
--
-- the variable bound by a synthesized lambda is chosen purely from the context: the first name
-- of the form "x<i>", counting up from the context's length, that the context does not bind yet.
-- this keeps binders distinct from everything in scope without relying on a hidden global
-- counter, whose results the optimiser is free to share between binders (leading to capture).
synth :: Ctx -> Ty -> Nondet Vl
synth ctx (Arrow t1 t2) = Lam binder <$> search (assume ctx binder t1) t2
  where
    -- names already bound in the context
    taken :: [Name]
    taken = map fst ctx

    binder :: Name
    binder = pick (length ctx)

    -- terminates because `taken` is finite and every index gives a different candidate
    pick :: Int -> Name
    pick i
      | candidate `elem` taken = pick (i + 1)
      | otherwise = candidate
      where
        candidate = "x" ++ show i
-- a pair is built from every combination of solutions for its two components,
-- varying the second component fastest
synth ctx (Prod t1 t2) = Pair <$> search ctx t1 <*> search ctx t2
-- in the last case, this type has no introduction forms and thus cannot be synthesized
synth _ _ = []
-- `search` takes a context and a goal type, and looks for variables in the context that fit the
-- goal, possibly after applying eliminators to them. all results found this way come first;
-- the results of `synth` for the same goal are appended after them.
-- it is the program-synthesis analogue of bidi's `infer`, and similarly falls back to `synth`.
search :: Ctx -> Ty -> Nondet Vl
search ctx goal = concatMap fromEntry ctx ++ synth ctx goal
  where
    -- the context is searched in order. each entry is tried by either using it directly as a
    -- variable, or by applying some eliminators to it first, so its results are always neutrals.
    fromEntry :: (Name, Ty) -> Nondet Vl
    fromEntry (v, ty) = reachable goal ty >>= fmap Neut . eliminate (Var v)

    -- turn a shape into concrete eliminators around the given head, first eliminator first.
    -- arguments of applications are themselves searched for, in the whole context `ctx`.
    eliminate :: Ne -> Shape -> Nondet Ne
    eliminate hd spine = case spine of
      [] -> [hd]
      EApp dom : more -> search ctx dom >>= \arg -> eliminate (App hd arg) more
      EFst : more -> eliminate (Fst hd) more
      ESnd : more -> eliminate (Snd hd) more
-- before starting to apply eliminators to a context entry, we check whether the goal is
-- reachable from its type at all. this matters because, e.g., mindlessly eliminating function
-- types would require repeatedly synthesizing the argument of an application, which loops forever.
--
-- the result lists every shape that leads from the type `t` to `goal`:
-- the empty shape first (when the types already coincide), then the longer ones.
reachable :: Ty -> Ty -> Nondet Shape
reachable goal t = immediate ++ viaEliminator
  where
    -- the goal might be immediately reachable
    immediate = [[] | goal == t]
    -- otherwise maybe we can get closer to the goal by applying one more eliminator
    viaEliminator = case t of
      Arrow dom cod -> map (EApp dom :) (reachable goal cod)
      Prod l r -> map (EFst :) (reachable goal l) ++ map (ESnd :) (reachable goal r)
      _ -> []
--- aesthetic helpers --- | |
-- four base types used by the examples.
w, x, y, z :: Ty
w = Base "W"
x = Base "X"
y = Base "Y"
z = Base "Z"
-- infix notation for function types; right associative,
-- so `a ~> b ~> c` reads as `a ~> (b ~> c)`.
infixr 5 ~>
(~>) :: Ty -> Ty -> Ty
dom ~> cod = Arrow dom cod
--- examples --- | |
-- search for terms of type `t` in context `ctx` and print the solutions found.
-- at most 7 solutions are printed (the default cap).
runExample :: Ctx -> Ty -> IO ()
runExample ctx t = print . take 7 $ search ctx t
-- print a greeting, then run every example below in order.
main :: IO ()
main = do
  putStrLn "welcome to proof search!"
  mapM_ (uncurry runExample) examples
  where
    -- each example is a context together with the goal type to search for
    examples :: [(Ctx, Ty)]
    examples =
      [ -- finding appropriate variables in scope
        ([("a", x)], x)
      , ([("a", x), ("f", y ~> z)], y ~> z)
        -- synthesizing under a binder
      , ([], x ~> x)
      , ([], x ~> y ~> x)
        -- eliminating assumptions from the context
      , ([("a", Prod y z)], z)
      , ([("a", x), ("f", x ~> y)], y)
      , ([("a", x), ("b", x), ("f", x ~> x ~> y)], y)
      , ([("h", y ~> z), ("g", x ~> y), ("f", w ~> x)], w ~> z)
        -- repeatedly composing a function with itself
      , ([("a", x ~> x)], x ~> x)
        -- applying a higher order function
      , ([("f", x ~> y), ("g", (x ~> y) ~> y ~> z)], x ~> z)
        -- this example came out really interesting! exercise for the reader: try to come up
        -- with as many of your own solutions (distinct up to beta-eta equivalence) before
        -- looking at the synthesis results.
        -- misc
      , ([], Prod x y ~> Prod y x)
      , ([("b", Prod (Prod x (y ~> z)) (Prod (Prod y z) x))], z)
      ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment