-- Gist jdmonty/32232a3e1d948a331782f803c591c1f1 by @jdmonty
-- (forked from LFY/summers.hs; created October 4, 2022).
-- Partial implementation of LISP program synthesis
-- Analytic inductive programming, according to survey paper:
-- Kitzelmann, Emanuel. Inductive Programming: A Survey of Program Synthesis Techniques
-- Other related work:
-- D.R. Smith. The synthesis of LISP programs from examples: A survey.
-- P.D. Summers. A Methodology for LISP program construction from examples.
-- Author: Lingfeng Yang
import Data.List
import Control.Monad
import Control.Monad.Writer
-- Our language: LISP
-- Goal: Derive recursive programs that transform lists using car, cdr, cons, nil
-- We are given several input/output pairs. Our goal is to learn a function of the form
-- (define (f x)
-- (cond
-- (p1 (f1 x))
-- (p2 (f2 x))
-- ...
-- (else (f (d x)))))
-- i.e., a conditional with several predicates p1,p2 ... and corresponding
-- traces f1 f2... plus a recursive call to f, with some reduction d of the
-- input (to ensure termination).
-- The algorithm is divided into 3 stages:
-- 1. Learning traces from each input/output pair
-- 2. Learning predicates that distinguish each input to the proper trace
-- 3. Learning recursive calls
-- We start by defining S-expressions as an algebraic data type
-- H is for expressing one (or multiple)-hole contexts.
-- The other terms correspond to their counterparts in Lisp.
data SExpr
  -- Atoms
  = Nil
  | T
  | F
  | Atom String
  -- Lists
  | Cons SExpr SExpr
  -- List _traversals_
  | Car SExpr
  | Cdr SExpr
  -- Function abstraction, variables, and applications
  | Lam String SExpr
  | Var String
  | App SExpr SExpr
  -- Built-in predicates
  | IsAtom SExpr
  -- Conditional expression: (predicate, body) clauses
  | Cond [(SExpr, SExpr)]
  -- Hole (for manipulating SExprs as one- or multiple-hole contexts)
  | H
  deriving Eq
-- Our first interpreter: Printing terms
-- | Render terms in a Lisp-like concrete syntax. Cons cells are always shown
-- as dotted pairs, variables as _name_, and holes as ___.
instance Show SExpr where
  show (Cons x y) = "(" ++ show x ++ " . " ++ show y ++ ")"
  show T = "T"
  show F = "F"
  show Nil = "()"
  show (Atom s) = s
  show (Car x) = "(car " ++ show x ++ ")"
  show (Cdr x) = "(cdr " ++ show x ++ ")"
  show (Var s) = "_" ++ s ++ "_"
  show (Lam s e) = "(lambda (_" ++ s ++ "_) " ++ show e ++ ")"
  show (App fun arg) = "(" ++ show fun ++ " " ++ show arg ++ ")"
  show (IsAtom e) = "(atom? " ++ show e ++ ")"
  -- One "(predicate fragment)" line per clause. Plain concatenation is
  -- enough here; no Writer monad is needed to accumulate the text.
  show (Cond clauses) = "(cond\n" ++ concatMap showClause clauses ++ ")"
    where
      showClause (p, frag) = "(" ++ show p ++ " " ++ show frag ++ ")\n"
  show H = "___"
-- The first step: building program traces (aka fragments)
-- The paper outlines the following strategy for obtaining a single program
-- that works on a single input/output pair:
-- 1. For each 'largest' subexpression of the output that can be found in the
-- input, take the associated basic function (composition of car/cdr), then
-- that serves as a part of the final program, because we can apply that basic
-- function to the input to obtain that part of the output.
-- 2. The final S-expression is then obtained by replacing sub-S-expressions
-- of the output with the basic functions that produce them. The final
-- S-expression is a _function_ which, given the input, produces the output.
-- This operates on a single input/output pair. The paper mentions that we
-- first take each subexpression of the input and pair it up with a
-- composition of car/cdr; each subexpression of the output that occurs in the
-- input may then be replaced by that traversal applied to the input.
-- Function to obtain subexpressions
all_subexprs :: SExpr -> [SExpr]
all_subexprs expr = loop [] expr where
loop xs (Atom x) = (Atom x):xs
loop xs Nil = (Nil:xs)
loop xs T = T:xs
loop xs F = F:xs
loop xs cc@(Cons x y) = let
l1 = loop [] x
l2 = loop [] y
partial = l1 ++ l2 ++ xs in
(cc:partial)
-- | The input variable @x@ in which every synthesized trace is expressed.
varx :: SExpr
varx = Var "x"

-- | Abstract an expression over the input variable @x@.
lamx :: SExpr -> SExpr
lamx e = Lam "x" e
-- | Every subexpression of a data S-expression, paired with the car/cdr
-- traversal of 'varx' that reaches it. The order is pre-order, the same as
-- 'all_subexprs', so the whole expression comes first, paired with 'varx'.
-- Only atoms (Atom, Nil, T, F) and cons cells are supported; anything else
-- raises an explicit error (previously an opaque pattern-match failure).
all_subexpr_func :: SExpr -> [(SExpr, SExpr)]
all_subexpr_func expr = go varx expr
  where
    -- path: the traversal (built from varx) that reaches the current node
    go path node = case node of
      Cons x y -> (node, path) : (go (Car path) x ++ go (Cdr path) y)
      Atom _ -> [(node, path)]
      Nil -> [(node, path)]
      T -> [(node, path)]
      F -> [(node, path)]
      _ -> error "all_subexpr_func: only atoms and cons cells are supported"
-- Remark:
-- This is exactly building up _zippers_ of a data structure (here,
-- S-expressions)
-- car/cdr would be the traversal functions
-- | Express @output@ in terms of the input variable.
--
-- Any subexpression of @output@ that also occurs in @input@ is replaced by
-- the car/cdr traversal that reaches it. Cons cells are looked up before
-- their parts, so the largest shared subexpression wins; among equal
-- subexpressions the first in pre-order wins, via 'lookup'. A cons cell not
-- found in @input@ is rebuilt from its converted parts.
--
-- An atom of @output@ that does not occur in @input@ raises an error.
-- So does any non-data constructor in @output@.
make_fragment :: SExpr -> SExpr -> SExpr
make_fragment input output = build output
  where
    paths = all_subexpr_func input
    build e = case lookup e paths of
      Just path -> path
      Nothing -> case e of
        Cons x y -> Cons (build x) (build y)
        Atom _ -> error "Atom of output not found in input"
        Nil -> error "Nil of output not found in input"
        T -> error "T of output not found in input"
        F -> error "F of output not found in input"
        _ -> error "make_fragment: output must contain only atoms and cons cells"
-- The next phase: building predicates.
-- 1. Define a partial order on forms: any atom <= others, and (a . b) <= (c .
-- d) iff a <= c and b <= d.
-- | The partial order on forms. Any atom is <= everything, and a cons cell is
-- <= another cons cell iff it is componentwise <=.
--
-- A cons cell is never <= an atom. That clause was previously missing, so
-- 'compare' (and hence 'sortBy') crashed whenever it asked whether a longer
-- list is <= a shorter one.
--
-- NOTE(review): this is not a lawful total order. Two distinct atoms are <=
-- each other both ways, so the default 'compare' gives LT for both argument
-- orders. Incomparable cons cells give GT both ways.
instance Ord SExpr where
  (Atom _) <= _ = True
  Nil <= _ = True
  T <= _ = True
  F <= _ = True
  (Cons a b) <= (Cons c d) = (a <= c) && (b <= d)
  (Cons _ _) <= _ = False
  _ <= _ = error "Ord SExpr: only atoms and cons cells can be compared"
-- 2. Obtain a total order on the input examples.
-- 3. Then, derive a classifier for each input example, so that we associate the
-- correct output function with the correct input.
-- I could not find the actual algorithm to derive classifiers, but the Summers
-- paper provides enough of a sketch that I think this is what it does:
-- Derive the classifier as follows: Given x <= y, we know that y is 'more
-- complicated' than x. Traverse x and y using the same traversal moves, in
-- parallel. At some sub-expression of x, x will be an atom and y will be
-- a cons cell. We then record atom?(t_1 \circ t_2 \circ \cdots \circ t_n (x))
-- as the classifier, where each t_i is car or cdr.
-- | Erase the atoms of a data S-expression so only its shape remains. Every
-- atom (Atom, Nil, T, F) becomes T and cons cells are kept. Any other
-- constructor raises an error.
form :: SExpr -> SExpr
form e = case e of
  Cons l r -> Cons (form l) (form r)
  Atom _ -> T
  Nil -> T
  T -> T
  F -> T
  _ -> error "Forms other than cons cells not supported"
-- | Build a predicate @(atom? path)@ that is true for inputs shaped like @e1@
-- and false for inputs shaped like @e2@.
--
-- The forms of both expressions are walked in parallel with the same car/cdr
-- moves, starting from 'varx'. The first position (car before cdr) where
-- @e1@ has an atom but @e2@ has a cons cell gives the path.
--
-- Precondition: @e1 <= e2@ and their forms differ. Violations now raise an
-- explicit error instead of a bare 'head' / pattern-match failure.
mk_classifier :: SExpr -> SExpr -> SExpr
mk_classifier e1 e2 = case distinguishing varx (form e1) (form e2) of
    (path : _) -> IsAtom path
    [] -> error "mk_classifier: the two expressions have the same form"
  where
    -- All paths where the first form is an atom and the second a cons cell.
    distinguishing :: SExpr -> SExpr -> SExpr -> [SExpr]
    distinguishing v (Cons x1 y1) (Cons x2 y2) =
      distinguishing (Car v) x1 x2 ++ distinguishing (Cdr v) y1 y2
    distinguishing v T (Cons _ _) = [v]
    -- A tie at this position: nothing distinguishes the two here.
    distinguishing _ T T = []
    -- Exercise: prove (or disprove) that this can _never_ happen if e1 <= e2.
    distinguishing _ (Cons _ _) T =
      error "mk_classifier: precondition e1 <= e2 violated"
    -- 'form' only produces Cons and T, so this is unreachable.
    distinguishing _ _ _ = error "mk_classifier: unexpected form"
-- The next step:
-- Finding recursions by finding recurrence relations.
-- Steps: From a properly ordered list of [(predicate, fragment)] pairs:
-- 1. Calculate the _difference_ between each pair
-- The difference function is the most complicated part. This can be broken
-- down into several stages. I'm sure there's probably a cleaner way to do
-- this, but long story short, by implementing this one has implemented a good
-- chunk of a Prolog interpreter.
-- First: A function to find one-hole contexts
-- A zipper over S-expressions, as a pair (context, focus). The first
-- component is the enclosing expression with H marking where the focus sits.
-- The second is the subterm currently in focus.
type SExprZip = (SExpr, SExpr)
-- | Move a zipper's focus to the left (first) child of the focused term:
-- the argument of IsAtom/Car/Cdr, or the car of a cons cell.
--
-- The result is a list because the move can fail. Atoms, variables, holes,
-- and forms this zipper does not descend into (Lam, App, Cond) give [].
-- These cases previously crashed with a pattern-match failure.
downleft :: SExprZip -> [SExprZip]
downleft (above, below) = do
    (hole, child) <- step below
    return (context_sub above hole, child)
  where
    -- The focused term with its left child replaced by H, plus that child.
    step (IsAtom x) = return (IsAtom H, x)
    step (Car x) = return (Car H, x)
    step (Cdr x) = return (Cdr H, x)
    step (Cons x y) = return (Cons H y, x)
    step _ = mzero
-- | Move a zipper's focus to the right child of the focused term. Only cons
-- cells have one (their cdr). Every other term gives []; Lam, App, Cond and
-- H previously crashed with a pattern-match failure.
downright :: SExprZip -> [SExprZip]
downright (above, below) = do
    (hole, child) <- step below
    return (context_sub above hole, child)
  where
    -- The focused cons cell with its cdr replaced by H, plus that cdr.
    step (Cons x y) = return (Cons x H, y)
    step _ = mzero
-- | Going back up the tree: plug the focused term into the hole(s) of its
-- context, recovering the whole expression.
unzipSExpr :: SExprZip -> SExpr
unzipSExpr = uncurry context_sub
-- Now we have our function to obtain all 1-hole contexts
-- The 1-hole contexts of a term are all revisions of that term with a subterm
-- removed and replaced by H
-- Example: The contexts of ((a . b) . (c . d)) are
-- H, (H . (c . d)), ((H . b) . (c . d)), ((a . b) . H), etc.
-- | The context (expression with a hole) of a zipper, ignoring its focus.
curr_context :: SExprZip -> SExpr
curr_context (context, _focus) = context
-- | All one-hole contexts of an expression reachable through 'downleft' and
-- 'downright'.
--
-- Each position is visited after its left subtree and then its right
-- subtree, so the trivial context H (the whole term as the hole) comes last.
contexts :: SExpr -> [SExpr]
contexts e = map curr_context (walk (H, e))
  where
    walk :: SExprZip -> [SExprZip]
    walk z = concatMap walk (downleft z) ++ concatMap walk (downright z) ++ [z]
-- | Plug an expression into a context: every H in the first argument is
-- replaced by the second argument, and all other constructors are rebuilt
-- unchanged.
--
-- Cond was previously unhandled and crashed; its clauses are now
-- substituted into like any other subterm.
context_sub :: SExpr -> SExpr -> SExpr
context_sub H e2 = e2
context_sub T _ = T
context_sub F _ = F
context_sub (Atom x) _ = Atom x
context_sub (Var x) _ = Var x
context_sub Nil _ = Nil
context_sub (IsAtom x) e2 = IsAtom (context_sub x e2)
context_sub (Lam x e) e2 = Lam x (context_sub e e2)
context_sub (App f x) e2 = App (context_sub f e2) (context_sub x e2)
context_sub (Car x) e2 = Car (context_sub x e2)
context_sub (Cdr x) e2 = Cdr (context_sub x e2)
context_sub (Cons x y) e2 = Cons (context_sub x e2) (context_sub y e2)
context_sub (Cond clauses) e2 =
  Cond [(context_sub p e2, context_sub f e2) | (p, f) <- clauses]
-- | Turn every free occurrence of the given variable (second argument) into
-- a hole H. This is the inverse direction of 'context_sub'.
--
-- * Existing holes stay holes. The old clause turned them into the variable,
--   apparently copied from 'context_sub'.
-- * A lambda that rebinds the same name shadows it, so its body is left
--   untouched.
-- * Cond clauses are traversed; they were previously unhandled and crashed.
var_to_holes :: SExpr -> SExpr -> SExpr
var_to_holes H _ = H
var_to_holes T _ = T
var_to_holes F _ = F
var_to_holes (Atom x) _ = Atom x
var_to_holes e1@(Var _) v
  | e1 == v = H
  | otherwise = e1
var_to_holes Nil _ = Nil
var_to_holes (IsAtom x) v = IsAtom (var_to_holes x v)
var_to_holes e1@(Lam x e) v
  | v == Var x = e1
  | otherwise = Lam x (var_to_holes e v)
var_to_holes (App f x) v = App (var_to_holes f v) (var_to_holes x v)
var_to_holes (Car x) v = Car (var_to_holes x v)
var_to_holes (Cdr x) v = Cdr (var_to_holes x v)
var_to_holes (Cons x y) v = Cons (var_to_holes x v) (var_to_holes y v)
var_to_holes (Cond clauses) v =
  Cond [(var_to_holes p v, var_to_holes f v) | (p, f) <- clauses]
-- | One step of beta reduction at the root. @(App (Lam v body) arg)@ becomes
-- @body@ with the free occurrences of @v@ replaced by @arg@.
--
-- Any other expression is not a redex at the root and is returned unchanged.
-- Previously that was a pattern-match failure.
--
-- NOTE(review): the substitution is not capture-avoiding; 'context_sub'
-- plugs @arg@ under lambdas without renaming.
beta_reduce :: SExpr -> SExpr
beta_reduce (App (Lam v body) arg) = context_sub (var_to_holes body (Var v)) arg
beta_reduce e = e
-- | Align a context @ce@ against an expression @e@ and return the subterm of
-- @e@ that sits under the first hole of @ce@, searching car before cdr.
--
-- Only Cons, IsAtom, Car and Cdr are descended into in lockstep. Any other
-- pair of constructors yields Nothing, including two equal atoms or two
-- equal variables. So Nothing means "no hole reached along matching
-- structure"; it does not mean the two terms are the same.
--
-- Parts of @ce@ off the path to the hole are NOT checked against @e@.
-- Callers must verify the match themselves ('diff' does so by substituting
-- the result back).
match_context :: SExpr -> SExpr -> Maybe SExpr
match_context H e = Just e
match_context (Cons x1 y1) (Cons x2 y2) =
  match_context x1 x2 `mplus` match_context y1 y2
match_context (IsAtom x) (IsAtom y) = match_context x y
match_context (Car x) (Car y) = match_context x y
match_context (Cdr x) (Cdr y) = match_context x y
match_context _ _ = Nothing
-- Finally:
-- The difference between two terms t1, t2:
-- Take holed version of t1, t1H:
-- For all 1-hole contexts ct2 in t2:
-- temp = context_sub ct2 t1h
-- case match_context temp t2 of
-- Just e -> check if substituting e in our current difference gets us t2, if so, return it (Just ct2 . ct1 . x -> e)
-- Nothing -> Nothing
-- | Term difference between two traces.
--
-- Looks for a one-hole context C of @e2@ and a subterm s such that @e2@ is
-- C with @e1@ plugged in, where @e1@'s variable x is replaced by s. The
-- first success, in the order of 'contexts', is returned as the step
-- function (lambda (f) (lambda (x) C[(f s)])). If no context works, the
-- result is Nil.
diff :: SExpr -> SExpr -> SExpr
diff e1 e2 = case steps of
    (step : _) -> step
    [] -> Nil
  where
    -- e1 with its variable x abstracted into holes
    e1h = var_to_holes e1 (Var "x")
    steps =
      [ Lam "f" (Lam "x" (context_sub ctx (App (Var "f") arg)))
      | ctx <- contexts e2
      , let plugged = context_sub ctx e1h
      , Just arg <- [match_context plugged e2]
      , context_sub plugged arg == e2
      ]
-- Next step of finding recursions:
-- 2. See whether the term difference is constant between terms
-- We'll need a function that compares each diff in a list with the diff n
-- positions later; it returns the base cases and the inductive step.
-- 3. If it is, define new higher-order functions dp, df which execute the difference
-- | Drop the first @n@ elements of a list.
--
-- This is total. A list shorter than @n@ gives [], and a non-positive @n@
-- gives the list unchanged. The former hand-rolled countdown crashed with
-- @tail []@ in both cases.
startFrom :: Int -> [a] -> [a]
startFrom n xs = drop n xs
-- | Pair each element with the element @n@ positions after it. There are as
-- many pairs as the shifted list allows.
skipn :: Int -> [a] -> [(a, a)]
skipn n xs = let shifted = startFrom n xs in zip xs shifted
-- | Split a list of term differences into base cases and the recurring
-- difference, comparing each element with the one @n@ positions later.
--
-- * Base cases: every element that differs from the element @n@ later.
-- * Recurring difference: the first element equal to its partner @n@ later
--   that occurs at least @2 * n@ times in the whole list (the condition for
--   inductive inference in the paper).
--
-- If no recurring difference exists, forcing it raises an explicit error
-- rather than a bare 'head' exception. The base cases remain usable.
find_base_recur :: (Eq a) => Int -> [a] -> ([a], a)
find_base_recur n xs = (base_cases, recurring)
  where
    pairs = skipn n xs
    base_cases = [x | (x, y) <- pairs, x /= y]
    occurrences y = length (elemIndices y xs)
    recurring =
      case [y | (x, y) <- pairs, x == y, occurrences y >= 2 * n] of
        (r : _) -> r
        [] -> error "find_base_recur: no recurring difference found"
-- | Term differences ('diff') between each expression in a list (e.g. all
-- predicates or all fragments) and the expression @n@ positions later.
skipn_diff :: Int -> [SExpr] -> [SExpr]
skipn_diff n es = map (uncurry diff) (skipn n es)
-- For now, we just stick with n=1
-- 4. Synthesize the recursive function
-- | Pad a list on the right with copies of @e@ until it has length @n@.
-- Lists already at least @n@ long are returned unchanged. This is linear;
-- the previous version recomputed 'length' and appended one element per
-- step, which is quadratic.
extend_with :: a -> Int -> [a] -> [a]
extend_with e n l = l ++ replicate (n - length l) e
-- | Fold an ordered list of (predicate, fragment) pairs into an open-recursive
-- program (lambda (self) (lambda (x) (cond ...))).
--
-- Term differences between consecutive predicates and between consecutive
-- fragments are computed with n = 1. Let num_base be the larger number of
-- base-case differences found by 'find_base_recur'. The first
-- (num_base + 1) pairs are kept verbatim as clauses. The recurring fragment
-- difference then becomes the final (T ...) clause: a call through @self@.
find_recursions :: [(SExpr, SExpr)] -> SExpr
find_recursions pfs = Lam "self" (Lam "x" (Cond (base_clauses ++ [rec_clause])))
  where
    (basepreds, _) = find_base_recur 1 (skipn_diff 1 (map fst pfs))
    (basefrags, recfrag) = find_base_recur 1 (skipn_diff 1 (map snd pfs))
    num_base = max (length basepreds) (length basefrags)
    base_clauses = take (num_base + 1) pfs
    -- recfrag has the shape (lambda (f) (lambda (x) body)), as built by diff.
    -- Applying it to self gives (lambda (x) body[f := self]). Applying that
    -- to x yields an expression in x, like the other clauses' fragments and
    -- matching the (else (f (d x))) target form. Previously the clause held
    -- the lambda itself.
    self_step = beta_reduce (App recfrag (Var "self"))
    rec_clause = (T, beta_reduce (App self_step (Var "x")))
-- Some test cases
-- | Sample input ((a . b) . (c . d)).
in1 :: SExpr
in1 = Cons (pair "a" "b") (pair "c" "d")
  where pair l r = Cons (Atom l) (Atom r)

-- | Sample output ((d . c) . (a . b)), built from the atoms of 'in1'.
out1 :: SExpr
out1 = Cons (pair "d" "c") (pair "a" "b")
  where pair l r = Cons (Atom l) (Atom r)
-- | Remove-last examples: the lists (a), (a b), ..., (a b c d e), each paired
-- with the same list minus its last element.
pairs_ex1 :: [(SExpr, SExpr)]
pairs_ex1 = [(list (take k atoms), list (take (k - 1) atoms)) | k <- [1 .. 5]]
  where
    atoms = map Atom ["a", "b", "c", "d", "e"]
    -- a proper, Nil-terminated list built from cons cells
    list = foldr Cons Nil
-- The first stage of Summers' program synthesis:
-- 1. Find traces between each input/output pair (make_fragment)
-- 2. Order inputs according to <=
-- 3. Derive the corresponding predicates (mk_classifier)
-- The result is a list of predicate/trace pairs that can be used as a
-- cond-expression that has the semantics of the function we want.
-- | Turn input/output examples into (predicate, trace) pairs.
--
-- Each example's trace comes from 'make_fragment'. Examples are sorted by
-- input form, and each input is paired with the next larger one to build
-- its classifier with 'mk_classifier'.
--
-- NOTE(review): the largest input has no successor, so its trace gets no
-- pair; n examples yield n - 1 pairs.
build_cond :: [(SExpr, SExpr)] -> [(SExpr, SExpr)]
build_cond pairs =
    [ (mk_classifier smaller larger, trace)
    | ((smaller, trace), (larger, _)) <- zip traced (drop 1 traced)
    ]
  where
    -- (input, trace) for every example, ordered by the input's form
    traced =
      sortBy (\(a, _) (b, _) -> compare a b)
        [(i, make_fragment i o) | (i, o) <- pairs]
-- Intermediate results on 'pairs_ex1', handy for inspection in GHCi.
pfs :: [(SExpr, SExpr)]
pfs = build_cond pairs_ex1

fs, ps :: [SExpr]
fs = map snd pfs
ps = map fst pfs

test1 :: SExpr
test1 = find_recursions pfs

-- Individual fragments (e*) and predicates (p*).
-- NOTE(review): build_cond yields one pair fewer than there are examples
-- (4 for the 5 in pairs_ex1), so e5 and p5 index past the end and raise an
-- error if evaluated.
e1, e2, e3, e4, e5 :: SExpr
e1 = fs !! 0
e2 = fs !! 1
e3 = fs !! 2
e4 = fs !! 3
e5 = fs !! 4

p1, p2, p3, p4, p5 :: SExpr
p1 = ps !! 0
p2 = ps !! 1
p3 = ps !! 2
p4 = ps !! 3
p5 = ps !! 4
-- | Walk through each stage of the synthesis on the sample data, printing
-- the intermediate results.
main :: IO ()
main = do
  let conds = build_cond pairs_ex1
  putStrLn "Basic idea behind Philip Summers's analytic program synthesis: using the structure of input/output data to infer transformations"
  putStrLn "Example input:"
  print in1
  putStrLn ""
  putStrLn "Example output:"
  print out1
  putStrLn ""
  putStrLn "Subexpressions of input:"
  print (all_subexprs in1)
  putStrLn ""
  putStrLn "Subexpressions of input + the function needed to travel to that part of the input (as a composition of car/cdr):"
  print (all_subexpr_func in1)
  putStrLn ""
  putStrLn "A complete example: Inferring remove-last"
  putStrLn "The output expressed in terms of the input (_x_):"
  print (make_fragment in1 out1)
  putStrLn ""
  putStrLn "Input/output examples:"
  mapM_ (\(i, o) -> putStrLn ("input: " ++ show i ++ " output: " ++ show o)) pairs_ex1
  putStrLn ""
  putStrLn "The derived predicates and fragments:"
  mapM_ (\(p, f) -> putStrLn (show p ++ " -> " ++ show f)) conds
  putStrLn ""
  putStrLn "Term-differences of predicates:"
  mapM_ print (skipn_diff 1 (map fst conds))
  putStrLn ""
  putStrLn "Term-differences of fragments:"
  mapM_ print (skipn_diff 1 (map snd conds))
  putStrLn ""
  putStrLn "The resulting recursive program (before fixpoint operator) is obtained by inductively inferring that constant term differences over a long enough period will stay that way. Here, we keep the base cases up to when the term-differences of predicates and fragments are constant, then move on to a recursive call:"
  print (find_recursions conds)
-- We will need this eventually (some way to run the programs we synthesize).
-- | Evaluate an expression built from data and car/cdr/atom? into plain data.
--
-- * car/cdr select from the normalized list they are applied to.
-- * (atom? e) becomes F when e normalizes to a cons cell, and T otherwise.
--   Atom, Nil, T and F all count as atoms.
-- * car/cdr of a non-list raises an error, as do forms not yet supported.
normalize :: SExpr -> SExpr
normalize Nil = Nil
normalize T = T
normalize F = F
normalize (Atom s) = Atom s
normalize (Cons x rest) = Cons (normalize x) (normalize rest)
-- Normalizing the argument first handles nested car/cdr chains directly.
normalize (Car xs) = case normalize xs of
  Cons first _ -> first
  _ -> error "Can't car a non-list"
normalize (Cdr xs) = case normalize xs of
  Cons _ rest -> rest
  _ -> error "Can't cdr a non-list"
normalize (IsAtom e) = case normalize e of
  Cons _ _ -> F
  _ -> T
-- TODO: App, Lam, Var, Cond
normalize e = error ("normalize: unsupported form " ++ show e)
-- Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment