-
-
Save jdmonty/32232a3e1d948a331782f803c591c1f1 to your computer and use it in GitHub Desktop.
Partial implementation of LISP program synthesis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Analytic inductive programming, according to survey paper: | |
-- Kitzelmann, Emanuel. Inductive Programming: A Survey of Program Synthesis Techniques | |
-- Other related work: | |
-- D.R. Smith. The synthesis of LISP programs from examples: A survey. | |
-- P.D. Summers. A Methodology for LISP program construction from examples. | |
-- Author: Lingfeng Yang | |
import Data.List | |
import Control.Monad | |
import Control.Monad.Writer | |
-- Our language: LISP | |
-- Goal: Derive recursive programs that transform lists using car, cdr, cons, nil | |
-- We are given several input/output pairs. Our goal is to learn a function of the form | |
-- (define (f x) | |
-- (cond | |
-- (p1 (f1 x)) | |
-- (p2 (f2 x)) | |
-- ... | |
-- (else (f (d x))))) | |
-- i.e., a conditional with several predicates p1,p2 ... and corresponding | |
-- traces f1 f2... plus a recursive call to f, with some reduction d of the | |
-- input (to ensure termination). | |
-- The algorithm is divided into 3 stages: | |
-- 1. Learning traces from each input/output pair | |
-- 2. Learning predicates that distinguish each input to the proper trace | |
-- 3. Learning recursive calls | |
-- We start by defining S-expressions as an algebraic data type | |
-- H is for expressing one (or multiple)-hole contexts. | |
-- The other terms correspond to their counterparts in Lisp. | |
-- | Terms of the object language: Lisp S-expressions extended with a few
-- program constructs and a hole marker used to build contexts.
data SExpr
  = Nil                    -- ^ the empty list
  | T                      -- ^ the true atom
  | F                      -- ^ the false atom
  | Atom String            -- ^ a named atom
  | Cons SExpr SExpr       -- ^ a cons cell @(car . cdr)@
  | Car SExpr              -- ^ list traversal: take the car
  | Cdr SExpr              -- ^ list traversal: take the cdr
  | Lam String SExpr       -- ^ lambda abstraction over a named variable
  | Var String             -- ^ variable reference
  | App SExpr SExpr        -- ^ function application
  | IsAtom SExpr           -- ^ built-in predicate @atom?@
  | Cond [(SExpr, SExpr)]  -- ^ conditional: list of (predicate, branch) clauses
  | H                      -- ^ hole, for one- (or multi-) hole contexts
  deriving (Eq)
-- Our first interpreter: Printing terms | |
-- | Our first interpreter: render terms in a Lisp-like surface syntax.
-- Variables are wrapped in underscores (@_x_@), holes print as @___@, and each
-- 'Cond' clause is printed on its own line.
instance Show SExpr where
  show expr = case expr of
      Nil -> "()"
      T -> "T"
      F -> "F"
      Atom s -> s
      Var s -> "_" ++ s ++ "_"
      H -> "___"
      Cons x y -> "(" ++ show x ++ " . " ++ show y ++ ")"
      Car x -> "(car " ++ show x ++ ")"
      Cdr x -> "(cdr " ++ show x ++ ")"
      IsAtom e -> "(atom? " ++ show e ++ ")"
      Lam s e -> "(lambda (_" ++ s ++ "_) " ++ show e ++ ")"
      App e1 e2 -> "(" ++ show e1 ++ " " ++ show e2 ++ ")"
      Cond clauses -> "(cond\n" ++ concatMap showClause clauses ++ ")"
    where
      -- One "(predicate branch)" line per clause.
      showClause (p, f) = "(" ++ show p ++ " " ++ show f ++ ")\n"
-- The first step: building program traces (aka fragments) | |
-- The paper outlines the following strategy for obtaining a single program | |
-- that works on a single input/output pair: | |
-- 1. For each 'largest' subexpression of the output that can be found in the | |
-- input, take the associated basic function (composition of car/cdr), then | |
-- that serves as a part of the final program, because we can apply that basic | |
-- function to the input to obtain that part of the output. | |
-- 2. The final S-expression is then obtained by substituting sub-S-expressions
-- of the output with the basic functions of interest. The final S-expression | |
-- is a _function_ which, given the input, produces the output. | |
-- This operates on a single input/output pair. The paper mentions that we
-- first take each subexpression of the input and pair it up with a
-- composition of car/cdr; each subexpression of the output that occurs in the
-- input may then be replaced by that traversal applied to the input.
-- Function to obtain subexpressions | |
-- | All subexpressions of a datum (atoms, nil, T/F and cons cells), in
-- pre-order: the expression itself first, then the subexpressions of its car,
-- then those of its cdr.
--
-- Program constructs (variables, car/cdr, lambdas, holes, ...) are not data
-- and are rejected with an explicit error rather than a pattern-match failure.
all_subexprs :: SExpr -> [SExpr]
all_subexprs e@(Cons x y) = e : all_subexprs x ++ all_subexprs y
all_subexprs e@(Atom _) = [e]
all_subexprs Nil = [Nil]
all_subexprs T = [T]
all_subexprs F = [F]
all_subexprs _ = error "all_subexprs: only atoms and cons cells are supported"
-- | The input variable @x@ that synthesized traversals are applied to.
varx :: SExpr
varx = Var "x"

-- | Abstract a body over the input variable @x@.
lamx :: SExpr -> SExpr
lamx = Lam "x"
-- Function to obtain subexpressions + traces to obtain that subexpression | |
-- | Every subexpression of a datum paired with the car/cdr traversal of
-- '_x_' that reaches it, in pre-order (the whole expression, reached by
-- '_x_' itself, comes first; car-side entries precede cdr-side entries).
--
-- Because the list is in pre-order, 'lookup' on it returns the path found
-- first in that walk. Non-data constructors are rejected with an explicit
-- error, consistent with 'form'.
all_subexpr_func :: SExpr -> [(SExpr, SExpr)]
all_subexpr_func = go varx
  where
    go path e@(Cons x y) = (e, path) : go (Car path) x ++ go (Cdr path) y
    go path e@(Atom _) = [(e, path)]
    go path Nil = [(Nil, path)]
    go path T = [(T, path)]
    go path F = [(F, path)]
    go _ _ = error "all_subexpr_func: only atoms and cons cells are supported"
-- Remark: | |
-- This is exactly building up _zippers_ of a data structure (here, | |
-- S-expressions) | |
-- car/cdr would be the traversal functions | |
-- | Express @output@ as a program over the input variable '_x_'.
--
-- Each cons cell of the output that also occurs in @input@ is replaced by the
-- traversal reaching it; otherwise the cell is rebuilt from its rewritten car
-- and cdr. Atoms, nil, T and F of the output must occur in the input,
-- otherwise an error naming the kind of leaf is raised.
make_fragment :: SExpr -> SExpr -> SExpr
make_fragment input output = rebuild output
  where
    paths = all_subexpr_func input
    rebuild e@(Cons x y) = maybe (Cons (rebuild x) (rebuild y)) id (lookup e paths)
    rebuild e@(Atom _) = leaf "Atom" e
    rebuild Nil = leaf "Nil" Nil
    rebuild T = leaf "T" T
    rebuild F = leaf "F" F
    -- A leaf has no parts to fall back on: it must be found in the input.
    leaf label e =
      maybe (error (label ++ " of output not found in input")) id (lookup e paths)
-- The next phase: building predicates. | |
-- 1. Define a partial order on forms: any atom <= others, and (a . b) <= (c . | |
-- d) iff a <= c and b <= d. | |
-- | Structural ordering on data: an atom (including nil, T and F) is below
-- everything, and @(a . b) <= (c . d)@ iff @a <= c@ and @b <= d@.
--
-- A cons cell is never below an atom, and every other pairing (e.g. one
-- involving program constructs such as 'Var' or 'Car') compares as not-below.
-- Without this fallback the definition was partial, so the default 'compare'
-- (which 'build_cond' passes to 'sortBy') crashed whenever its left argument
-- had a cons cell where its right argument had an atom.
--
-- NOTE(review): this is only a partial order (e.g. two distinct atoms are each
-- @<=@ the other), so it does not satisfy the 'Ord' laws; sorting inputs that
-- are incomparable yields an algorithm-dependent order.
instance Ord SExpr where
  Atom _ <= _ = True
  Nil <= _ = True
  T <= _ = True
  F <= _ = True
  Cons a b <= Cons c d = a <= c && b <= d
  _ <= _ = False
-- 2. Obtain a total order on the input examples. | |
-- 3. Then, derive a classifier for each input example, so that we associate the
-- correct output function with the correct input.
-- I could not find the actual algorithm to derive classifiers, but the Summers
-- paper provides enough of a sketch that I think this is what it does:
-- Derive the classifier as follows: Given x <= y, we know that y is 'more | |
-- complicated' than x. Traverse x and y using the same traversal moves, in | |
-- parallel. At some sub-expression of x, x will be an atom and y will be | |
-- another cons cell. We then record atom(t1 . t2 . \ldots . tn (x)) as the | |
-- classifier. | |
-- Convenience function: _form_, to turn all atoms of an S-expression into the | |
-- same value (here, T), so we can compare just the structure | |
-- | Erase the identity of every leaf (atom, nil, T, F) to 'T' so that two
-- data can be compared by structure alone. Cons cells are kept; any other
-- constructor is rejected.
form :: SExpr -> SExpr
form (Cons x y) = Cons (form x) (form y)
form e
  | isDatumLeaf e = T
  | otherwise = error "Forms other than cons cells not supported"
  where
    isDatumLeaf (Atom _) = True
    isDatumLeaf Nil = True
    isDatumLeaf T = True
    isDatumLeaf F = True
    isDatumLeaf _ = False
-- The actual algorithm to form a predicate. | |
-- precondition: the two s-exprs are such that e1 <= e2 | |
-- | The predicate that separates input @e1@ from a structurally larger input
-- @e2@.
--
-- Both inputs are reduced to their 'form' and walked in parallel, car before
-- cdr. The first position where @e1@ has an atom but @e2@ still has a cons
-- cell gives the classifier @(atom? path)@, where @path@ is the car/cdr
-- composition of '_x_' leading there.
--
-- Precondition: @e1 <= e2@. Violations no longer surface as @head []@ or a
-- pattern-match failure. When the forms are identical (no separating position),
-- an explicit error is raised. When @e1@ has a cons cell opposite an atom of @e2@
-- (impossible if the precondition holds), an explicit error is raised as well.
-- The result is built lazily: the 'IsAtom' wrapper is returned before the
-- path is searched for.
mk_classifier :: SExpr -> SExpr -> SExpr
mk_classifier e1 e2 = IsAtom path
  where
    path = case separating varx (form e1) (form e2) of
      (p : _) -> p
      [] -> error "mk_classifier: inputs have the same form; no distinguishing predicate"
    separating :: SExpr -> SExpr -> SExpr -> [SExpr]
    separating v (Cons x1 y1) (Cons x2 y2) =
      separating (Car v) x1 x2 ++ separating (Cdr v) y1 y2
    -- e1 is an atom where e2 is still a cons cell: this path separates them.
    separating v T (Cons _ _) = [v]
    -- A tie at this position contributes nothing.
    separating _ T T = []
    -- e1 has a cons cell where e2 has an atom, contradicting e1 <= e2.
    separating _ _ _ =
      error "mk_classifier: precondition e1 <= e2 violated (cons cell opposite an atom)"
-- The next step: | |
-- Finding recursions by finding recurrence relations. | |
-- Steps: From a properly ordered list of [(predicate, fragment)] pairs: | |
-- 1. Calculate the _difference_ between each pair | |
-- The difference function is the most complicated part. This can be broken | |
-- down into several stages. I'm sure there's probably a cleaner way to do | |
-- this, but long story short, by implementing this one has implemented a good | |
-- chunk of a Prolog interpreter. | |
-- First: a function to find one-hole contexts.
-- A zipper: the first component is the context above the focus (with H
-- marking where the focused subterm sits); the second component is the
-- focused subterm itself (the tree below). 'context_sub' of the two rebuilds
-- the whole term.
type SExprZip = (SExpr, SExpr)
-- Our traversal functions | |
-- The type is [SExprZip] because it may fail | |
-- | Move the focus to the first (leftmost) child of the focused subterm,
-- extending the context above with a fresh hole.
--
-- Returns @[]@ (failure in the list monad) when there is no child to descend
-- into: leaves ('Nil', 'T', 'F', 'Atom', 'Var', 'H') and 'Cond', whose
-- clauses are not traversed. 'Lam' and 'App' are descended into (previously
-- they crashed with a pattern-match failure), so contexts of those terms can be
-- enumerated too.
downleft :: SExprZip -> [SExprZip]
downleft (above, below) = do
  (step, child) <- hdownleft below
  return (context_sub above step, child)
  where
    hdownleft (IsAtom x) = return (IsAtom H, x)
    hdownleft (Car x) = return (Car H, x)
    hdownleft (Cdr x) = return (Cdr H, x)
    hdownleft (Cons x y) = return (Cons H y, x)
    hdownleft (Lam v body) = return (Lam v H, body)
    hdownleft (App f x) = return (App H x, f)
    hdownleft _ = mzero
-- | Move the focus to the second (right) child of the focused subterm,
-- extending the context above with a fresh hole.
--
-- Only binary constructors have a right child: the cdr of a 'Cons' and the
-- argument of an 'App'. Everything else returns @[]@; previously 'Lam',
-- 'App', 'Cond' and 'H' crashed with a pattern-match failure.
downright :: SExprZip -> [SExprZip]
downright (above, below) = do
  (step, child) <- hdownright below
  return (context_sub above step, child)
  where
    hdownright (Cons x y) = return (Cons x H, y)
    hdownright (App f x) = return (App f H, x)
    hdownright _ = mzero
-- Going backwards (up the tree) | |
-- | Going backwards (up the tree): plug the focused subterm back into its
-- context, yielding the whole term.
unzipSExpr :: SExprZip -> SExpr
unzipSExpr = uncurry context_sub
-- Now we have our function to obtain all 1-hole contexts | |
-- The 1-hole contexts of a term are all revisions of that term with a subterm | |
-- removed and replaced by H | |
-- Example: The contexts of ((a . b) . (c . d)) include
-- H, (H . (c . d)), ((H . b) . (c . d)), ((a . b) . H), etc.
-- | The context (with its hole) of a zipper, discarding the focused subterm.
curr_context :: SExprZip -> SExpr
curr_context (ctx, _) = ctx
-- | All one-hole contexts of a term, in post-order. For each focus, the contexts
-- reached through its left child come first, then those through its right child,
-- then the context at that focus itself. The last element is therefore 'H'.
contexts :: SExpr -> [SExpr]
contexts e = map curr_context (visit (H, e))
  where
    visit z = concatMap visit (downleft z) ++ concatMap visit (downright z) ++ [z]
-- A function to substitute a s-expr in a s-expr with H's (replace H with that expression): | |
-- | Fill every hole 'H' in the first term with the second term.
--
-- Compound terms are traversed structurally. 'Cond' clauses were previously
-- missing, so substituting into a conditional crashed with a pattern-match
-- failure; they are now traversed too. Leaves are returned unchanged.
context_sub :: SExpr -> SExpr -> SExpr
context_sub H e2 = e2
context_sub (IsAtom x) e2 = IsAtom (context_sub x e2)
context_sub (Lam x e) e2 = Lam x (context_sub e e2)
context_sub (App f x) e2 = App (context_sub f e2) (context_sub x e2)
context_sub (Car x) e2 = Car (context_sub x e2)
context_sub (Cdr x) e2 = Cdr (context_sub x e2)
context_sub (Cons x y) e2 = Cons (context_sub x e2) (context_sub y e2)
context_sub (Cond pfs) e2 =
  Cond [(context_sub p e2, context_sub f e2) | (p, f) <- pfs]
-- Leaves (Nil, T, F, Atom, Var) contain no holes.
context_sub leaf _ = leaf
-- A function to convert Var's to H's: | |
-- The second argument is some variable we're interested in | |
-- | Replace every free occurrence of the variable @e2@ (expected to be a
-- 'Var') in @e1@ with a hole 'H'.
--
-- Three fixes relative to the earlier version:
--
-- * Existing holes stay holes. Previously an 'H' was turned into the variable
--   @e2@ itself, copied from 'context_sub'.
-- * A 'Lam' that rebinds the same name shadows it, so its body is left
--   untouched. This keeps 'beta_reduce' from substituting under an inner binder
--   of the same name.
-- * 'Cond' clauses are traversed instead of crashing.
var_to_holes :: SExpr -> SExpr -> SExpr
var_to_holes e1@(Var _) e2
  | e2 == e1 = H
  | otherwise = e1
var_to_holes (Lam x e) e2
  -- The binder shadows the target: occurrences in the body are not free.
  | e2 == Var x = Lam x e
  | otherwise = Lam x (var_to_holes e e2)
var_to_holes (IsAtom x) e2 = IsAtom (var_to_holes x e2)
var_to_holes (App f x) e2 = App (var_to_holes f e2) (var_to_holes x e2)
var_to_holes (Car x) e2 = Car (var_to_holes x e2)
var_to_holes (Cdr x) e2 = Cdr (var_to_holes x e2)
var_to_holes (Cons x y) e2 = Cons (var_to_holes x e2) (var_to_holes y e2)
var_to_holes (Cond pfs) e2 =
  Cond [(var_to_holes p e2, var_to_holes f e2) | (p, f) <- pfs]
-- Leaves (Nil, T, F, Atom) and existing holes contain no variables.
var_to_holes leaf _ = leaf
-- | Contract a top-level beta-redex @((lambda (v) body) arg)@ by substituting
-- @arg@ for the occurrences of @v@ in @body@.
--
-- Any other term is returned unchanged; previously it crashed with a
-- pattern-match failure.
--
-- NOTE(review): the substitution is not capture-avoiding. Free variables of
-- @arg@ can be captured by binders inside @body@, because 'context_sub' does
-- no renaming.
beta_reduce :: SExpr -> SExpr
beta_reduce (App (Lam v body) arg) = context_sub (var_to_holes body (Var v)) arg
beta_reduce e = e
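-- A function to pattern-match a context expression with a complete expression,
-- returning Just the subterm aligned with the first hole found (car before cdr),
-- or Nothing when no hole is matched. Nothing covers both "identical terms" and
-- structural mismatches; callers must check the candidate themselves.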
-- | Walk a context and a term in parallel through matching cons cells,
-- @atom?@, @car@ and @cdr@ nodes. Return the subterm sitting under the first
-- hole encountered; a cons cell's car side is tried before its cdr side. Any
-- other pairing, including equal leaves and mismatches, yields 'Nothing'.
match_context :: SExpr -> SExpr -> Maybe SExpr
match_context ctx term = case (ctx, term) of
  (H, _) -> Just term
  (Cons a1 d1, Cons a2 d2) -> match_context a1 a2 `mplus` match_context d1 d2
  (IsAtom a, IsAtom b) -> match_context a b
  (Car a, Car b) -> match_context a b
  (Cdr a, Cdr b) -> match_context a b
  _ -> Nothing
-- Finally: | |
-- The difference between two terms t1, t2: | |
-- Take holed version of t1, t1H: | |
-- For all 1-hole contexts ct2 in t2: | |
-- temp = context_sub ct2 t1h | |
-- case match_context temp t2 of | |
-- Just e -> check if substituting e in our current difference gets us t2, if so, return it (Just ct2 . ct1 . x -> e) | |
-- Nothing -> Nothing | |
-- | The term difference from @e1@ to @e2@: a function
-- @(lambda (f) (lambda (x) ...))@ that rebuilds @e2@ from a recursive call on a
-- reduced input.
--
-- The procedure runs in three steps:
--
-- 1. Occurrences of '_x_' in @e1@ become holes.
-- 2. For each one-hole context of @e2@ (in 'contexts' order), that holed copy of
--    @e1@ is plugged in, and the subterm of @e2@ aligned with the first hole is
--    extracted.
-- 3. The candidate is kept only if filling all holes with that subterm
--    reproduces @e2@ exactly.
--
-- The first surviving candidate is returned, or 'Nil' when none survives.
diff :: SExpr -> SExpr -> SExpr
diff e1 e2 = case candidates of
    (d : _) -> d
    [] -> Nil
  where
    holed = var_to_holes e1 (Var "x")
    candidates =
      [ Lam "f" (Lam "x" (context_sub ctx (App (Var "f") sub)))
      | ctx <- contexts e2
      , let current = context_sub ctx holed
      , Just sub <- [match_context current e2]
      , context_sub current sub == e2
      ]
-- Next step of finding recursions: | |
-- 2. See whether the term difference is constant between terms | |
-- We'll need a function to determine whether each diff is equal to each other, in a list, skipping by n | |
-- It returns base cases, inductive steps | |
-- 3. If it is, define new higher-order functions dp, df which execute the difference | |
-- | Drop the first @n@ elements of a list.
--
-- The old hand-rolled 'tail' loop crashed when @n@ exceeded the list length
-- and never stopped for negative @n@. 'drop' is total: it returns @[]@ for a
-- short list and the list unchanged for @n <= 0@.
startFrom :: Int -> [a] -> [a]
startFrom = drop
-- | Pair each element with the element @n@ positions later; the result is as
-- long as the shorter of the two aligned lists.
skipn :: Int -> [a] -> [(a, a)]
skipn n xs = zipWith (,) xs (startFrom n xs)
-- | Split a list of differences into base cases and a recurring difference.
--
-- Elements are compared with the element @n@ positions later. Every element
-- that differs from its partner is a base case. The recurrence is the first
-- element that equals its partner and also occurs at least @2n@ times in the
-- whole list (the condition for inductive inference in the paper).
--
-- If no such element exists, forcing the recurrence raises a descriptive
-- error instead of @head []@. The base cases stay available either way,
-- because the tuple is built lazily.
find_base_recur :: (Eq a) => Int -> [a] -> ([a], a)
find_base_recur n xs = (base_cases, recurrence)
  where
    pairs = skipn n xs
    base_cases = [x | (x, y) <- pairs, x /= y]
    inductive_candidates =
      [ y
      | (x, y) <- pairs
      , x == y
      , length (elemIndices y xs) >= 2 * n -- the condition for inductive inference in the paper
      ]
    recurrence = case inductive_candidates of
      (c : _) -> c
      [] -> error "find_base_recur: no difference repeats often enough to infer a recurrence"
-- And a function that maps a list of terms (predicates or fragments) to the differences between terms n positions apart
-- | The 'diff' from each term to the term @n@ positions later.
skipn_diff :: Int -> [SExpr] -> [SExpr]
skipn_diff n = map (uncurry diff) . skipn n
-- For now, we just stick with n=1 | |
-- 4. Synthesize the recursive function | |
-- | Pad list @l@ with copies of @e@ until it has at least @n@ elements; a list
-- that is already long enough is returned unchanged.
--
-- Appends all padding at once. The previous version appended one element per
-- step and recomputed 'length' each time, which is quadratic.
extend_with :: a -> Int -> [a] -> [a]
extend_with e n l = l ++ replicate (n - length l) e
-- | Synthesize an open-recursive program from an ordered list of
-- (predicate, fragment) pairs, such as those produced by 'build_cond'.
--
-- The construction proceeds as follows:
--
-- 1. Compute the term differences between consecutive predicates and between
--    consecutive fragments ('skipn_diff' with step 1).
-- 2. 'find_base_recur' splits each difference list into base cases and a
--    repeated difference.
-- 3. Keep the first @num_base + 1@ pairs verbatim as explicit @cond@ clauses.
-- 4. Add a final @T@ clause holding the repeated fragment difference applied to
--    @_self_@.
--
-- The result is @(lambda (_self_) (lambda (_x_) (cond ...)))@; no fixpoint
-- operator is applied here.
--
-- NOTE(review): the recurrence found for the predicates, and any padding or
-- pairing of the base-case differences, are not yet used when assembling the
-- program. The previous version computed them into unused bindings, which
-- have been removed.
find_recursions :: [(SExpr, SExpr)] -> SExpr
find_recursions pfs = Lam "self" (Lam "x" (Cond cond_body))
  where
    preds = map fst pfs
    frags = map snd pfs
    (basepreds, _) = find_base_recur 1 (skipn_diff 1 preds)
    (basefrags, recfrag) = find_base_recur 1 (skipn_diff 1 frags)
    num_base = max (length basepreds) (length basefrags)
    -- Base cases: explicit clauses up to where the differences become constant.
    base_clauses = take (num_base + 1) pfs
    -- Recursive case: the fragment difference with its recursive call bound to self.
    rec_clause = (T, beta_reduce (App recfrag (Var "self")))
    cond_body = base_clauses ++ [rec_clause]
-- Some test cases | |
-- | Sample input @((a . b) . (c . d))@.
in1 :: SExpr
in1 = Cons (dotted "a" "b") (dotted "c" "d")
  where
    dotted x y = Cons (Atom x) (Atom y)

-- | Sample output @((d . c) . (a . b))@, built from the atoms of 'in1'.
out1 :: SExpr
out1 = Cons (dotted "d" "c") (dotted "a" "b")
  where
    dotted x y = Cons (Atom x) (Atom y)
-- | Input/output examples for "remove the last element". The inputs are the
-- lists @(a)@, @(a b)@, ..., @(a b c d e)@, and each output is its input without
-- the final atom.
pairs_ex1 :: [(SExpr, SExpr)]
pairs_ex1 = [(mkList (take (k + 1) atoms), mkList (take k atoms)) | k <- [0 .. 4]]
  where
    atoms = map Atom ["a", "b", "c", "d", "e"]
    mkList = foldr Cons Nil
-- The first stage of Summers' program synthesis: | |
-- 1. Find traces between each input/output pair (make_fragment) | |
-- 2. Order inputs according to <= | |
-- 3. Derive the corresponding predicates (mk_classifier) | |
-- The result is a list of predicate/trace pairs that can be used as a | |
-- cond-expression that has the semantics of the function we want. | |
-- | Stage one of Summers' synthesis. Trace each example with 'make_fragment',
-- sort the examples by input using the structural order, then pair each example
-- with a predicate ('mk_classifier') that separates its input from the next
-- larger one.
--
-- The largest example only serves as the comparison partner for its
-- predecessor and gets no clause of its own, so @k@ examples yield @k - 1@
-- (predicate, trace) pairs.
build_cond :: [(SExpr, SExpr)] -> [(SExpr, SExpr)]
build_cond pairs = zipWith clause ordered (drop 1 ordered)
  where
    ordered =
      sortBy (\(i1, _) (i2, _) -> i1 `compare` i2)
        [(i, make_fragment i o) | (i, o) <- pairs]
    clause (i1, t1) (i2, _) = (mk_classifier i1 i2, t1)
-- | Predicate/fragment pairs for 'pairs_ex1', handy for inspection in GHCi.
pfs :: [(SExpr, SExpr)]
pfs = build_cond pairs_ex1

fs, ps :: [SExpr]
fs = snd <$> pfs
ps = fst <$> pfs

test1 :: SExpr
test1 = find_recursions pfs

-- Individual fragments and predicates.
-- NOTE(review): build_cond yields one pair fewer than it has examples (4 for
-- the 5 in pairs_ex1), so forcing e5 or p5 raises an index-too-large error.
e1, e2, e3, e4, e5 :: SExpr
e1 = fs !! 0
e2 = fs !! 1
e3 = fs !! 2
e4 = fs !! 3
e5 = fs !! 4

p1, p2, p3, p4, p5 :: SExpr
p1 = ps !! 0
p2 = ps !! 1
p3 = ps !! 2
p4 = ps !! 3
p5 = ps !! 4
-- | Walk through the synthesis pipeline on the sample data, printing every
-- intermediate result. Each titled section is followed by a blank line.
main :: IO ()
main = do
  putStrLn "Basic idea behind Philip Summers's analytic program synthesis: using the structure of input/output data to infer transformations"
  section "Example input:" (print in1)
  section "Example output:" (print out1)
  section "Subexpressions of input:" (print (all_subexprs in1))
  section "Subexpressions of input + the function needed to travel to that part of the input (as a composition of car/cdr):" (print (all_subexpr_func in1))
  putStrLn "A complete example: Inferring remove-last"
  section "The output expressed in terms of the input (_x_):" (print (make_fragment in1 out1))
  section "Input/output examples:" $
    mapM_ (\(i, o) -> putStrLn ("input: " ++ show i ++ " output: " ++ show o)) pairs_ex1
  section "The derived predicates and fragments:" $
    mapM_ (\(p, f) -> putStrLn (show p ++ " -> " ++ show f)) conds
  section "Term-differences of predicates:" $
    mapM_ print (skipn_diff 1 (map fst conds))
  section "Term-differences of fragments:" $
    mapM_ print (skipn_diff 1 (map snd conds))
  putStrLn "The resulting recursive program (before fixpoint operator) is obtained by inductively inferring that constant term differences over a long enough period will stay that way. Here, we keep the base cases up to when the term-differences of predicates and fragments are constant, then move on to a recursive call:"
  print (find_recursions conds)
  where
    -- The predicate/fragment pairs are computed once and shared by every section.
    conds = build_cond pairs_ex1
    section title body = putStrLn title >> body >> putStrLn ""
-- We will need this eventually (some way to run the programs we synthesize). | |
-- normalize: Turns a composition of list functions into a list | |
-- | Evaluate a composition of list functions (car/cdr over cons cells) down to
-- plain data.
--
-- Data constructors are rebuilt with their parts normalized. @car@ and @cdr@
-- first normalize a nested @car@/@cdr@ argument, then project the matching half
-- of the resulting cons cell. A non-list raises "Can't car a non-list" or
-- "Can't cdr a non-list".
normalize :: SExpr -> SExpr
normalize expr = case expr of
    Nil -> Nil
    T -> T
    F -> F
    Atom s -> Atom s
    Cons x rest -> Cons (normalize x) (normalize rest)
    Car xs -> project const "Can't car a non-list" xs
    Cdr xs -> project (\_ tl -> tl) "Can't cdr a non-list" xs
  where
    -- Shared car/cdr logic: reduce a nested traversal first, then select a half.
    project part msg xs = case reduced of
        Cons hd tl -> normalize (part hd tl)
        _ -> error msg
      where
        reduced = case xs of
          Car ys -> normalize (Car (normalize ys))
          Cdr ys -> normalize (Cdr (normalize ys))
          _ -> xs
-- TODO: App, Lam, Var
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment