@VictorTaelin
Last active November 29, 2024 15:51
A Superposed λ-Calculus Enumerator for Program Search
// This file is a mirror of:
// https://github.com/HigherOrderCO/HVM3/blob/main/book/lambda_enumerator_optimal.hvml
// An Optimal λ-Calculus Enumerator for Program Search
// ---------------------------------------------------
// This file shows a template for enumerating superposed terms in a
// higher-order language (here, the affine λ-Calculus) for proof search.
// Instead of generating the source syntax (e.g., superposing all binary strings
// and parsing them as binary λ-calculus), we create a superposition of all
// λ-terms *directly*, such that head constructors are emitted as soon as
// possible. This allows HVM to prune branches and backtrack efficiently as it
// computes the application of all λ-terms to some expression. The result is a
// reduction in the interactions needed to solve the equation:
//
// > (?X λt(t 1 2)) == λt(t 2 1)
//
// From ~32 million to just 76k, a 420x speedup*, which grows with the
// difficulty of the problem.
//
// This generator works by synthesizing a λ-term in layers. On each layer, we
// either generate a lambda and extend the context, or select one of the
// variables in the context to return. When we select a variable, we apply it
// to 0, 1 or 2 arguments, each of which is itself enumerated from the
// remaining context (we don't generate apps with arity > 2 here). This is not
// ideal; in a typed version, we would know the arity of the context variable
// and generate the right number of applications, without guessing.
//
// * NOTE: we're actually able to bring the naive approach down to 1.7 million
// interactions. So, the numbers are:
// - Enum binary λC with loops (Haskell): 0.992s
// - Enum binary λC with sups (HVM): 0.026s (38x speedup)
// - Enum λC directly with sups (HVM): 0.0011s (862x speedup)
//
// The main contribution of this file is the shape of the superposer. There
// are a few things one must get right to achieve the desired effect.
//
// First, how do we split a linear context? It depends. When generating an
// application like `(f ?A ?B)`, we pass the context to `?A`, take the
// leftover, and pass it to `?B`, so that `?A` and `?B` never use the same
// variable twice. This happens within the same "universe". Yet, when making
// a choice, like "do we return a lambda or a variable here?", we need to clone
// the linear context with the same label that forked the universe itself,
// allowing a variable to be used more than once, as long as its occurrences
// are in different universes. Handling this correctly is very subtle, which
// is why this file can be useful for study.
//
// Second, how do we handle labels? As discussed recently on Discord:
// https://discord.com/channels/912426566838013994/915345481675186197/1311434500911403109
// A single label is enough to fully enumerate all natural numbers, yet that
// doesn't work for binary trees. My conclusion is that we need to "fork" the
// label whenever we enumerate a constructor that branches, i.e., one with
// more than one field. Nats and Bits are safe because their constructors have
// only one field, but a binary Tree needs forking. To fork a label, we just
// multiply it by 2 and add 0 or 1; the seed has to be 1, so that forked
// branches never use the same label. We apply this here to the arity-2 app
// case.
//
// Third, how do we emit constructors as soon as possible, while still passing
// a context down? It is easy to get this wrong by accidentally making the
// enumerator monadic. That sequentializes its execution, meaning no ctor is
// emitted until the entire enumerator returns. This is a big problem, since
// we need head ctors to be available as soon as possible: that's how HVM is
// able to invalidate universes and backtrack. While this is a silly issue, it
// can spoil the whole thing, so I've elaborated on it here:
// https://gist.github.com/VictorTaelin/fb798a5bd182f8c57dd302380f69777a
//
// The enumerator in this file is the simplest "template" enumerator that has
// everything a higher-order language needs, and it is structured so that it
// can be studied and extended with more sophisticated approaches, like types.
//
// EDIT: the dependently typed version has been pushed. It reduces the rewrite
// count to 3k, and greatly improves the enumerator shape.

data List {
  #Nil
  #Cons{head tail}
}

data Bits {
  #O{pred}
  #I{pred}
  #E
}

data Term {
  #Lam{bod}
  #App{fun arg}
  #Var{idx}
  #Sub{val}
}

data pair {
  #Pair{fst snd}
}

data Result {
  #Result{src val}
}

data Maybe {
  #None
  #Some{value}
}

// Prelude
// -------

@if(c t f) = ~ c {
  0: f
  p: t
}

@when(c t) = ~ c {
  0: *
  p: t
}

@tail(xs) = ~ xs {
  #Nil: *
  #Cons{h t}: t
}

@and(a b) = ~ a !b {
  0: 0
  p: b
}

@unwrap(mb) = ~ mb {
  #None: *
  #Some{x}: x
}

@tm0(x) = !&0{a b}=x a
@tm1(x) = !&0{a b}=x b

// Stringification
// ---------------

@show_nat(nat) = ~nat {
  0: λk #Cons{'Z' k}
  p: λk #Cons{'S' (@show_nat(p) k)}
}

@show_dec(n r) =
  ! &{n n0} = n
  ! &{n n1} = n
  ! chr = (+ (% n 10) '0')
  ~ (< n0 10) !chr !r {
    0: @show_dec((/ n1 10) #Cons{chr r})
    t: #Cons{chr r}
  }

@do_show_dec(n) = @show_dec(n #Nil)

@show_bits(bits) = ~bits {
  #O{pred}: λk #Cons{'#' #Cons{'O' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #I{pred}: λk #Cons{'#' #Cons{'I' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #E: λk #Cons{'#' #Cons{'E' k}}
}

@do_show_bits(bits) = (@show_bits(bits) #Nil)

@show_term(term dep) = ~term !dep {
  #Var{idx}: λk
    @show_dec((- (- dep idx) 1) k)
  #Lam{bod}: λk
    !&{d0 d1}=dep
    #Cons{'λ' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
  #App{fun arg}: λk
    !&{d0 d1}=dep
    #Cons{'(' (@show_term(fun d0)
    #Cons{' ' (@show_term(arg d1)
    #Cons{')' k})})}
  #Sub{val}: *
}

@do_show_term(term) = (@show_term(term 0) #Nil)

// Equality
// --------

@eq(a b dep) = ~ a !b !dep {
  #Lam{a_bod}: ~ b !dep {
    #Lam{b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      !&{dep d2}=dep
      @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ 1 d2))
    #App{b_fun b_arg}: 0
    #Var{b_idx}: 0
    #Sub{b_val}: *
  }
  #App{a_fun a_arg}: ~ b !dep {
    #Lam{b_bod}: 0
    #App{b_fun b_arg}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @and(@eq(a_fun b_fun d0) @eq(a_arg b_arg d1))
    #Var{b_idx}: 0
    #Sub{b_val}: *
  }
  #Var{a_idx}: ~ b !dep {
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #Var{b_idx}: (== a_idx b_idx)
    #Sub{b_val}: *
  }
  #Sub{a_val}: *
}

// Evaluation
// ----------

@wnf(term) = ~ term {
  #Lam{bod}: #Lam{bod}
  #App{fun arg}: @wnf_app(@wnf(fun) arg)
  #Var{idx}: #Var{idx}
  #Sub{val}: #Sub{val}
}

@wnf_app(f x) = ~ f !x {
  #Lam{bod}: @wnf((bod @wnf(x)))
  #App{fun arg}: #App{#App{fun arg} x}
  #Var{idx}: #App{#Var{idx} x}
  #Sub{val}: #App{#Sub{val} x}
}

// Normalization
// -------------

@nf(term) = ~ @wnf(term) {
  #Lam{bod}: #Lam{λx @nf((bod #Sub{x}))}
  #App{fun arg}: #App{@nf(fun) @nf(arg)}
  #Var{idx}: #Var{idx}
  #Sub{val}: val
}

// Enumeration
// -----------

// Enumerates affine λ-terms.
// - L: superposition label. Should be 1 initially.
// - lim: max context length (i.e., nested lambdas).
// - ctx: the current scope. Should be [] initially.
// If the binder limit has been reached, destroy this universe.
// Otherwise, make a choice:
// - A. We generate a fresh lambda.
// - B. We select a variable from the context.
// Note that every time we make a choice, we "fork" the current context by
// using DUP nodes with the same label that we used in the choice SUP node.
@all(&L &lim ctx) = ~&lim {
  0: *
  &lim:
    !&L{ctxL ctxR} = ctx
    &L{
      @lam(&L &lim ctxL)
      @ret(&L (+ &lim 1) ctxR λk(k))
    }
}

// Generate a fresh lambda and extend the context with its variable.
@lam(&L &lim ctx) =
  !&0{ctx bod} = @all(&L &lim #Cons{#Some{$x} ctx})
  &0{@tail(ctx) #Lam{λ$x(bod)}}

// Return a variable from the context.
// If the context is empty, destroy this universe.
// Otherwise, make a choice:
// - A. We emit the head of the context, and apply it to things.
// - B. We keep the head of the context, and go to the next element.
@ret(&L &lim ctx rem) = ~ctx {
  #Nil: *
  #Cons{val ctx}:
    !&L{remL remR} = rem
    !&L{valL valR} = val
    !&L{ctxL ctxR} = ctx
    &L{
      @app(&L &lim (remL #Cons{#None ctxL}) valL)
      @ret(&L &lim ctxR λk(remR #Cons{valR k}))
    }
}

// To apply a value to things, we make a triple choice:
// - A. Return it directly.
// - B. Apply it to 1 argument.
// - C. Apply it to 2 arguments.
// When we apply it to 2 arguments, as in `(App ?A ?B)`, we need to fork the
// label, so that DUPs/SUPs in `?A` and `?B` never use the same label.
@app(&L &lim ctx val) = ~ val {
  #None: *
  #Some{val}:
    !&L{val val0} = val
    !&L{val val1} = val
    !&L{val val2} = val
    !&L{ctx ctx0} = ctx
    !&L{ctx ctx1} = ctx
    !&L{ctx ctx2} = ctx
    ! arity_0 =
      &0{ctx0 val0}
    ! arity_1 =
      !&0{ctx1 argX} = @all(&L &lim ctx1)
      &0{ctx1 #App{val1 argX}}
    ! arity_2 =
      !&0{ctx2 argX} = @all((+(*&L 2) 0) &lim ctx2)
      !&0{ctx2 argY} = @all((+(*&L 2) 1) &lim ctx2)
      &0{ctx2 #App{#App{val2 argX} argY}}
    &L{arity_0 &L{arity_1 arity_2}}
}

// Tests
// -----

//A= λt (t ^1 ^2)
//A= λ((Z SZ) SSZ)
@A = #Lam{λt #App{#App{t #Var{1}} #Var{2}}}

//B= λt (t ^2 ^1)
//B= λ((Z SSZ) SZ)
@B = #Lam{λt #App{#App{t #Var{2}} #Var{1}}}

//R= λx (x λa λb λt (t b a))
//R= λ(Z λλλ((SSSZ SSZ) SZ))
@R = #Lam{λx #App{x #Lam{λa #Lam{λb #Lam{λt #App{#App{t b} a}}}}}}

//X= (all terms)
@X = @tm1(@all(1 5 #Nil))

// Solves for `?X` in `(?X λt(t A B)) == λt(t B A)`.
// It finds `?X = λλ(1 λλ((2 0) 1))` in 76k interactions.
@main =
  ! solved = @eq(@nf(#App{@X @A}) @B 0)
  @when(solved @do_show_term(@X))
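For readers who want to see *what* the superposed enumerator above generates, here is a sequential Haskell sketch that mirrors `@all`, `@lam`, `@ret` and `@app` with ordinary lists. It is a simplified model and not the author's code. The function names (`allT`, `lamT`, `retT`, `appT`, `terms`, `forkLabels`, `labelsDistinct`) are hypothetical, and terms use de Bruijn indices instead of HVM's `$x` binders. Because it is list-based, it is exactly the kind of sequential enumerator the header warns about: it shows the search space, but none of HVM's sharing or early pruning. The last two functions model the label-forking rule (`2*L + 0` / `2*L + 1` from seed 1) as heap numbering of a binary tree of forks.

```haskell
import Data.List (nub)

-- De Bruijn terms: Var 0 is the innermost binder.
data Term = Lam Term | App Term Term | Var Int
  deriving (Show, Eq)

-- Linear context, newest binder first. A slot holds the binder's de Bruijn
-- level while it is unused, and Nothing once it has been consumed.
type Ctx = [Maybe Int]

-- Mirrors @all: with lim = 0 the branch dies; otherwise either open a
-- lambda (its body gets lim - 1) or return a variable (lim unchanged).
allT :: Int -> Int -> Ctx -> [(Ctx, Term)]
allT 0 _ _ = []
allT lim dep ctx = lamT (lim - 1) dep ctx ++ retT lim dep ctx

-- Mirrors @lam: extend the context, enumerate the body, then drop the slot.
lamT :: Int -> Int -> Ctx -> [(Ctx, Term)]
lamT lim dep ctx =
  [ (drop 1 ctx', Lam bod) | (ctx', bod) <- allT lim (dep + 1) (Just dep : ctx) ]

-- Mirrors @ret: pick any still-unused slot, mark it consumed, and apply it.
retT :: Int -> Int -> Ctx -> [(Ctx, Term)]
retT lim dep ctx =
  [ r
  | i <- [0 .. length ctx - 1]
  , (before, Just lvl : after) <- [splitAt i ctx]
  , r <- appT lim dep (before ++ Nothing : after) (Var (dep - lvl - 1))
  ]

-- Mirrors @app: arity 0, 1 or 2. The context is threaded left to right
-- through the arguments, so ?A and ?B never share a variable.
appT :: Int -> Int -> Ctx -> Term -> [(Ctx, Term)]
appT lim dep ctx f =
  [(ctx, f)]
    ++ [ (c1, App f x) | (c1, x) <- allT lim dep ctx ]
    ++ [ (c2, App (App f x) y) | (c1, x) <- allT lim dep ctx, (c2, y) <- allT lim dep c1 ]

-- Closed terms, as in @all(1 lim #Nil).
terms :: Int -> [Term]
terms lim = map snd (allT lim 0 [])

-- Label forking: a fork at label l gives its children 2*l and 2*l + 1.
-- Lists every label used by a full binary tree of forks of depth d.
forkLabels :: Int -> Int -> [Int]
forkLabels l 0 = [l]
forkLabels l d = l : forkLabels (2 * l) (d - 1) ++ forkLabels (2 * l + 1) (d - 1)

-- True when no two forks in such a tree share a label.
labelsDistinct :: Int -> Int -> Bool
labelsDistinct seed d = let ls = forkLabels seed d in length ls == length (nub ls)
```

With this model, `terms 2` is just `[Lam (Var 0)]`, `terms 3` has 7 affine terms, and `terms 5` contains `λλ(1 λλ((2 0) 1))`, the solution reported above. `labelsDistinct 0 1` is `False`, which illustrates why the seed must be 1: with seed 0, the first fork reuses label 0.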
@VictorTaelin (author) commented:

Below are the slower versions, which parse and test all binary λ-Calculus strings.

Haskell version (enumerating all BLC terms):

-- This is the Haskell version of the naive λ-Calculus enumerator, which just
-- generates all BLC strings and tries them one by one in a loop.

{-# LANGUAGE PatternSynonyms #-}

import Control.Monad (forM_, when)
import Data.Bits (testBit)
import System.Exit (exitSuccess)

data Bits = O Bits | I Bits | E deriving Show
data Term = Lam Term | App Term Term | Var Int deriving Show
data HTerm = HLam (HTerm -> HTerm) | HApp HTerm HTerm | HVar Int | HSub HTerm
data Pair a b = Pair a b deriving Show
data Result a r = Result a r deriving Show

-- Prelude
-- -------

bits :: Int -> Int -> Bits
bits 0 _ = E
bits n i
  | testBit i (n-1) = I (bits (n-1) i)
  | otherwise       = O (bits (n-1) i)

-- Parser
-- ------

parseTerm :: Bits -> Maybe (Result Bits Term)
parseTerm (O src) = do
  Result src nat <- parseInt src
  return $ Result src (Var nat)
parseTerm (I src) = case src of
  O src -> do
    Result src bod <- parseTerm src
    return $ Result src (Lam bod)
  I src -> do
    Result src fun <- parseTerm src
    Result src arg <- parseTerm src
    return $ Result src (App fun arg)
  E -> Nothing
parseTerm E = Nothing

parseInt :: Bits -> Maybe (Result Bits Int)
parseInt (O src) = Just $ Result src 0
parseInt (I src) = do
  Result src nat <- parseInt src
  return $ Result src (1 + nat)
parseInt E = Just $ Result E 0

doParseTerm :: Bits -> Maybe Term
doParseTerm src = do
  Result _ term <- parseTerm src
  return term

doParseHTerm :: Bits -> Maybe HTerm
doParseHTerm src = do
  Result _ term <- parseTerm src
  doBindTerm term

-- Binding
-- -------
-- NOTE: since Haskell doesn't have global variables ($x), we bind in two passes:
-- the first pass (`affine`) rejects non-affine terms, and the second
-- (`bindTerm`) binds the variables of the remaining terms to HOAS lambdas.

uses :: Term -> Int -> Int
uses (Lam bod)     idx = uses bod (idx + 1)
uses (App fun arg) idx = uses fun idx + uses arg idx
uses (Var n)       idx = if n == idx then 1 else 0

affine :: Term -> Bool
affine term = go term 0 where
  go (Lam bod)     dep = uses bod 0 <= 1 && go bod (dep + 1)
  go (App fun arg) dep = go fun dep && go arg dep
  go (Var n)       dep = n < dep

doBindTerm :: Term -> Maybe HTerm
doBindTerm term | affine term = Just (bindTerm term [])
doBindTerm term | otherwise   = Nothing

bindTerm :: Term -> [HTerm] -> HTerm
bindTerm (Lam bod)     ctx = HLam $ \x -> bindTerm bod (x : ctx)
bindTerm (App fun arg) ctx = HApp (bindTerm fun ctx) (bindTerm arg ctx)
bindTerm (Var idx)     ctx = get idx ctx

get :: Int -> [HTerm] -> HTerm
get 0 (x:_) = x
get n (_:t) = get (n-1) t
get _ []    = error "*"

-- Stringification
-- ---------------

showBits :: Bits -> String -> String
showBits (O pred) k = '#':'O':'{': showBits pred ('}':k)
showBits (I pred) k = '#':'I':'{': showBits pred ('}':k)
showBits E        k = '#':'E':k

doShowBits :: Bits -> String
doShowBits bits = showBits bits []

showTerm :: HTerm -> Int -> String -> String
showTerm (HVar idx)     dep k = show (dep - idx - 1) ++ k
showTerm (HLam bod)     dep k = 'λ' : showTerm (bod (HVar dep)) (dep+1) k
showTerm (HApp fun arg) dep k = '(' : showTerm fun dep (' ' : showTerm arg dep (')':k))
showTerm (HSub _)       _   _ = error "*"

doShowTerm :: HTerm -> String
doShowTerm term = showTerm term 0 []

-- Equality
-- --------

eq :: HTerm -> HTerm -> Int -> Bool
eq (HLam aBod)      (HLam bBod)      dep = eq (aBod (HVar dep)) (bBod (HVar dep)) (dep+1)
eq (HApp aFun aArg) (HApp bFun bArg) dep = eq aFun bFun dep && eq aArg bArg dep
eq (HVar aIdx)      (HVar bIdx)      _   = aIdx == bIdx
eq _                _                _   = False

-- Evaluation
-- ----------

wnf :: HTerm -> HTerm
wnf (HLam bod)     = HLam bod
wnf (HApp fun arg) = app (wnf fun) arg
wnf (HVar idx)     = HVar idx
wnf (HSub val)     = HSub val

app :: HTerm -> HTerm -> HTerm
app (HLam bod)     x = wnf (bod (wnf x))
app (HApp fun arg) x = HApp (HApp fun arg) x
app (HVar idx)     x = HApp (HVar idx) x
app (HSub val)     x = HApp (HSub val) x

-- Normalization
-- -------------

nf :: HTerm -> HTerm
nf term = case wnf term of
  HLam bod     -> HLam $ \x -> nf (bod (HSub x))
  HApp fun arg -> HApp (nf fun) (nf arg)
  HVar idx     -> HVar idx
  HSub val     -> val

-- Main
-- ----

a :: HTerm
a = HLam $ \t -> HApp (HApp t (HVar 1)) (HVar 2)

b :: HTerm
b = HLam $ \t -> HApp (HApp t (HVar 2)) (HVar 1)

r :: HTerm
r = HLam $ \x -> HApp x (HLam $ \a -> HLam $ \b -> HLam $ \t -> HApp (HApp t b) a)

-- Solve `?x` in `λaλb(?x λt(t a b)) == λaλbλt(t b a)`
main :: IO ()
main = forM_ [0..2^25-1] $ \i -> do
  let bs = bits 25 i
  case doParseHTerm bs of
    Nothing -> do
      return ()
    Just x -> do
      let solved = eq (nf (HApp x a)) b 0
      -- putStrLn $ show bs
      -- putStrLn $ doShowTerm x
      -- putStrLn $ doShowTerm (nf x)
      -- putStrLn $ show solved
      -- putStrLn $ "----------"
      when solved $ do
        putStrLn (doShowTerm x)
        exitSuccess
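The `main` loop above scans every 25-bit string. To relate that bound to concrete terms, here is a small sketch that inverts the bit scheme implied by `parseTerm`/`parseInt`: `10` for λ, `11` for application, and `0`, then n `1`s, then `0` for variable n. `encode`, `decode` and `unary` are hypothetical helpers on plain `String`s, not part of the original file.

```haskell
-- De Bruijn terms, as in the Haskell version above.
data Term = Lam Term | App Term Term | Var Int
  deriving (Show, Eq)

-- Encoding that matches parseTerm/parseInt:
--   Lam b   => "10" ++ b
--   App f a => "11" ++ f ++ a
--   Var n   => "0" ++ n ones ++ "0"
encode :: Term -> String
encode (Lam b) = "10" ++ encode b
encode (App f a) = "11" ++ encode f ++ encode a
encode (Var n) = "0" ++ replicate n '1' ++ "0"

-- String version of parseTerm: returns the term and the unconsumed bits.
decode :: String -> Maybe (Term, String)
decode ('0' : s) = let (n, rest) = unary s in Just (Var n, rest)
decode ('1' : '0' : s) = do
  (b, rest) <- decode s
  Just (Lam b, rest)
decode ('1' : '1' : s) = do
  (f, s1) <- decode s
  (a, s2) <- decode s1
  Just (App f a, s2)
decode _ = Nothing

-- Like parseInt: count '1's up to a '0'; running out of input also ends
-- the number (parseInt E = 0).
unary :: String -> (Int, String)
unary ('0' : s) = (0, s)
unary ('1' : s) = let (n, rest) = unary s in (n + 1, rest)
unary s = (0, s)
```

Under this scheme the test term `R` is exactly 25 bits long (`1011001010101111000100110`), while `λλ(1 λλ((2 0) 1))` needs 26. Because `parseInt` also ends a number at the end of input, a trailing variable's final `0` can be dropped, so that second term still has a 25-bit spelling: `decode (init (encode x))` recovers it.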

HVM version (superposing all BLC terms):

// This is the HVM version of the naive λ-Calculus enumerator. It superposes all
// binary λ-calculus strings, parses them, and applies the results to the
// equation we want to solve. Despite the use of superpositions, this performs
// about the same as the Haskell version, since HVM is forced to enumerate all
// terms anyway, and little sharing is possible. This takes about 32 million
// interactions. A better approach is provided in the
// lambda_enumerator_optimal.hvml file, which brings this number down to just
// 72k interactions.
// UPDATE: by simply avoiding the issue described in:
// https://gist.github.com/VictorTaelin/fb798a5bd182f8c57dd302380f69777a
// we can bring this naive BLC enumerator down to 1.7m interactions. Not quite
// as fast as 72k, but this makes it ~37x faster than the Haskell version.

data List {
 #Nil
 #Cons{head tail}
}

data Bits {
 #O{pred}
 #I{pred}
 #E
}

data Term {
 #Lam{bod}
 #App{fun arg}
 #Var{idx}
 #Sub{val}
}

data pair {
 #Pair{fst snd}
}

data Maybe {
 #None
 #Some{value}
}

// Prelude
// -------

@if(c t f) = ~ c {
 0: f
 p: t
}

@when(c t) = ~ c {
 0: *
 p: t
}

@tail(xs) = ~ xs {
 #Nil: *
 #Cons{h t}: t
}

// Parsing
// -------

@do_parse_term(src) =
 ! &0{src term} = @parse_term(src)
 @do_bind_term(term)

@parse_term(src) = ~src {
 #O{src}:
   ! &0{src nat} = @parse_nat(src)
   &0{src #Var{nat}}
 #I{src}: ~src {
   #O{src}:
     ! &0{src bod} = @parse_term(src)
     &0{src #Lam{bod}}
   #I{src}:
     ! &0{src fun} = @parse_term(src)
     ! &0{src arg} = @parse_term(src)
     &0{src #App{fun arg}}
   #E: * 
 }
 #E: *
}

@parse_nat(src) = ~src {
 #O{src}: &0{src 0}
 #I{src}:
   ! &0{src nat} = @parse_nat(src)
   &0{src (+ 1 nat)}
 #E: &0{#E 0}
}

// Binding
// -------

@do_bind_term(term) =
 ! &0{ctx term} = @bind_term(term #Nil)
 term

@bind_term(term ctx) = ~term !ctx {
 #Lam{bod}:
   ! &0{ctx bod} = @bind_term(bod #Cons{#Some{$x} ctx})
   &0{@tail(ctx) #Lam{λ$x bod}}
 #App{fun arg}:
   ! &0{ctx fun} = @bind_term(fun ctx)
   ! &0{ctx arg} = @bind_term(arg ctx)
   &0{ctx #App{fun arg}}
 #Var{idx}: @get(idx ctx)
 #Sub{val}: *
}

@get(idx ctx) = ~ idx !ctx {
 0: ~ ctx {
   #Nil: *
   #Cons{h t}: ~ h {
     #None: *
     #Some{x}: &0{#Cons{#None t} x}
   }
 }
 p: ~ ctx {
   #Nil: *
   #Cons{h t}:
     ! &0{t x} = @get(p t)
     &0{#Cons{h t} x}
 }
}

// Stringification
// ---------------

@show_nat(nat) = ~nat { 
 0: λk #Cons{'Z' k}
 p: λk #Cons{'S' (@show_nat(p) k)}
}

@show_dec(n r) =
 ! &10000002{n n0} = n
 ! &10000002{n n1} = n
 ! chr = (+ (% n 10) '0')
 ~ (< n0 10) !chr !r {
   0: @show_dec((/ n1 10) #Cons{chr r})
   t: #Cons{chr r}
 }

@do_show_dec(n) = @show_dec(n #Nil)

@show_bits(bits) = ~bits {
 #O{pred}: λk #Cons{'#' #Cons{'O' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
 #I{pred}: λk #Cons{'#' #Cons{'I' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
 #E: λk #Cons{'#' #Cons{'E' k}}
}

@do_show_bits(bits) = (@show_bits(bits) #Nil)

@show_term(term dep) = ~term !dep {
 #Var{idx}: λk
   @show_dec((- (- dep idx) 1) k)
 #Lam{bod}: λk
   !&0{d0 d1}=dep
   #Cons{'λ' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
 #App{fun arg}: λk
   !&0{d0 d1}=dep
   #Cons{'(' (@show_term(fun d0)
   #Cons{' ' (@show_term(arg d1)
   #Cons{')' k})})}
 #Sub{val}: *
}

@do_show_term(term) = (@show_term(term 0) #Nil)

// Equality
// --------

@eq(a b dep) = ~ a !b !dep {
 #Lam{a_bod}: ~ b !dep {
   #Lam{b_bod}:
     !&1{dep d0}=dep
     !&1{dep d1}=dep
     !&1{dep d2}=dep
     @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ 1 d2))
   #App{b_fun b_arg}: 0
   #Var{b_idx}: 0
   #Sub{b_val}: *
 }
 #App{a_fun a_arg}: ~ b !dep {
   #Lam{b_bod}: 0
   #App{b_fun b_arg}:
     !&1{dep d0}=dep
     !&1{dep d1}=dep
     (& @eq(a_fun b_fun d0) @eq(a_arg b_arg d1))
   #Var{b_idx}: 0
   #Sub{b_val}: *
 }
 #Var{a_idx}: ~ b !dep {
   #Lam{b_bod}: 0
   #App{b_fun b_arg}: 0
   #Var{b_idx}: (== a_idx b_idx)
   #Sub{b_val}: *
 }
 #Sub{a_val}: *
}

// Evaluation
// ----------

@wnf(term) = ~ term { 
 #Lam{bod}: #Lam{bod}
 #App{fun arg}: @app(@wnf(fun) arg)
 #Var{idx}: #Var{idx}
 #Sub{val}: #Sub{val}
}

@app(f x) = ~ f !x {
 #Lam{bod}: @wnf((bod @wnf(x)))
 #App{fun arg}: #App{#App{fun arg} x}
 #Var{idx}: #App{#Var{idx} x}
 #Sub{val}: #App{#Sub{val} x}
}

// Normalization
// -------------

@nf(term) = ~ @wnf(term) {
 #Lam{bod}: #Lam{λx @nf((bod #Sub{x}))}
 #App{fun arg}: #App{@nf(fun) @nf(arg)}
 #Var{idx}: #Var{idx}
 #Sub{val}: val
}

// Enumeration
// -----------

// Enumerates all Bits of the given size (label 2)
@all1(s) = ~s{
 0: #E
 p: !&2{p0 p1}=p &2{
   #O{@all1(p0)}
   #I{@all1(p1)}
 }
}

// Tests
// -----

//A= λt (t ^1 ^2)
//A= λ((Z SZ) SSZ)
@A = #Lam{λt #App{#App{t #Var{1}} #Var{2}}}

//B= λt (t ^2 ^1)
//B= λ((Z SSZ) SZ)
@B = #Lam{λt #App{#App{t #Var{2}} #Var{1}}}

//R= λx (x λa λb λt (t b a))
//R= λ(Z λλλ((SSSZ SSZ) SZ))
@R = #Lam{λx #App{x #Lam{λa #Lam{λb #Lam{λt #App{#App{t b} a}}}}}}

@X = @all1(25)

// Solve `?x` in `λaλb(?x λt(t a b)) == λaλbλt(t b a)`
@main =
 ! &5{x0 x1} = @do_parse_term(@X)
 ! solved    = @eq(@nf(#App{x0 @A}) @B 0) // (?x A) == B
 @when(solved @do_show_term(x1))
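Unlike the two-pass Haskell version, `@bind_term` checks affinity in a single pass. It threads a context of `#Some`/`#None` slots, and `@get` fails as soon as a slot is missing or already consumed. The following is a minimal Haskell sketch of that idea, assuming de Bruijn `Term`s as in the Haskell version. The helpers `getVar`, `bindCheck` and `isAffine` are hypothetical and only answer yes/no instead of building a bound term.

```haskell
import Data.Maybe (isJust)

data Term = Lam Term | App Term Term | Var Int
  deriving (Show, Eq)

-- Linear context, innermost binder first: True = still available,
-- False = already consumed (the #Some / #None slots of @bind_term).
type Ctx = [Bool]

-- Like @get: consume slot idx, failing if it is missing or already used.
getVar :: Int -> Ctx -> Maybe Ctx
getVar 0 (True : t) = Just (False : t)
getVar n (h : t) | n > 0 = (h :) <$> getVar (n - 1) t
getVar _ _ = Nothing

-- Like @bind_term: the context left over by the function side of an App
-- is what the argument side receives, so they never consume the same slot.
bindCheck :: Term -> Ctx -> Maybe Ctx
bindCheck (Lam b) ctx = drop 1 <$> bindCheck b (True : ctx)
bindCheck (App f a) ctx = bindCheck f ctx >>= bindCheck a
bindCheck (Var i) ctx = getVar i ctx

-- A closed term is affine if binding succeeds from the empty context.
isAffine :: Term -> Bool
isAffine t = isJust (bindCheck t [])
```

With this threading, `λx(x x)` fails at the second `x`, a free variable fails because its slot does not exist, and `R` passes.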

VictorTaelin commented Nov 29, 2024

The dependently typed superposer, down to 3k interactions:

// Superposes dependently typed λ-terms. With it, solving:
//   (?X λt(t A B)) == λt(t B A)
// where
//   ?X : ∀A. (∀P. (A -> A -> P) -> P) -> (∀P. (A -> A -> P) -> P)
// is down to 3k interactions. Of course, that's not too surprising given that
// there are only two functions of that type, but the real win is that now we
// only need to make a choice when selecting an element from the context. Intros
// and elims follow directly from the types; no choices / superpositions needed.

data List {
  #Nil
  #Cons{head tail}
}

data Bits {
  #O{pred}
  #I{pred}
  #E
}

data Term {
  #Var{idx}
  #Pol{bod}
  #All{inp bod}
  #Lam{bod}
  #App{fun arg}
  #U32
  #Num{val}
}

data Pair {
  #Pair{fst snd}
}

data Maybe {
  #None
  #Some{value}
}

// Prelude
// -------

@if(c t f) = ~ c {
  0: f
  p: t
}

@when(c t) = ~ c {
  0: *
  p: t
}

@tail(xs) = ~ xs {
  #Nil: *
  #Cons{h t}: t
}

@and(a b) = ~ a !b {
  0: 0
  p: b
}

@unwrap(mb) = ~mb {
  #None: *
  #Some{x}: x
}

@seq(str) = ~ str {
  #Nil: #Nil
  #Cons{h t}:
    !! h = h
    !! t = @seq(t)
    #Cons{h t}
}

@tm0(x) = !&0{a b}=x a
@tm1(x) = !&0{a b}=x b

// Stringification
// ---------------

@show_nat(nat) = ~nat { 
  0: λk #Cons{'Z' k}
  p: λk #Cons{'S' (@show_nat(p) k)}
}

@show_dec(n r) =
  ! &{n n0} = n
  ! &{n n1} = n
  ! chr = (+ (% n 10) '0')
  ~ (< n0 10) !chr !r {
    0: @show_dec((/ n1 10) #Cons{chr r})
    t: #Cons{chr r}
  }

@do_show_dec(n) = @show_dec(n #Nil)

@show_bits(bits) = ~bits {
  #O{pred}: λk #Cons{'#' #Cons{'O' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #I{pred}: λk #Cons{'#' #Cons{'I' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #E: λk #Cons{'#' #Cons{'E' k}}
}

@do_show_bits(bits) = (@show_bits(bits) #Nil)

@show_term(term dep) = ~term !dep {
  #Var{idx}: λk
    @show_dec(idx k)
  #Pol{bod}: λk
    !&{dep d0}=dep
    !&{dep d1}=dep
    #Cons{'∀' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
  #All{inp bod}: λk
    !&{dep d0}=dep
    !&{dep d1}=dep
    !&{dep d2}=dep
    #Cons{'Π'
    #Cons{'('
    (@show_term(inp d0)
    #Cons{')'
    (@show_term((bod #Var{d1}) (+ d2 1))
    k)})}}
  #Lam{bod}: λk
    !&{d0 d1}=dep
    #Cons{'λ' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
  #App{fun arg}: λk
    !&{d0 d1}=dep
    #Cons{'(' (@show_term(fun d0)
    #Cons{' ' (@show_term(arg d1)
    #Cons{')' k})})}
  #U32: λk
    #Cons{'U' k}
  #Num{val}: λk
    #Cons{'#' @show_dec(val k)}
}

@do_show_term(term) = (@show_term(term 0) #Nil)

// Equality
// --------

@eq(a b dep) = ~ @wnf(a) !b !dep {
  #Var{a_idx}: ~ @wnf(b) !dep {
    #Var{b_idx}: (== a_idx b_idx)
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #Pol{a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ dep 1))
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #All{a_inp a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      !&{dep d2}=dep
      @and(@eq(a_inp b_inp d0) @eq((a_bod #Var{d1}) (b_bod #Var{d2}) (+ dep 1)))
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #Lam{a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ dep 1))
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #App{a_fun a_arg}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @and(@eq(a_fun b_fun d0) @eq(a_arg b_arg d1))
    #U32: 0
    #Num{b_val}: 0
  }
  #U32: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 1
    #Num{b_val}: 0
  }
  #Num{a_val}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: (== a_val b_val)
  }
}

// Evaluation
// ----------

@wnf(term) = ~ term { 
  #Var{idx}: #Var{idx}
  #Pol{bod}: #Pol{bod}
  #All{inp bod}: #All{inp bod}
  #Lam{bod}: #Lam{bod}
  #App{fun arg}: @wnf_app(@wnf(fun) arg)
  #U32: #U32
  #Num{val}: #Num{val}
}

@wnf_app(f x) = ~ f !x {
  #Var{idx}: #App{#Var{idx} x}
  #Pol{bod}: #App{#Pol{bod} x}
  #All{inp bod}: #App{#All{inp bod} x}
  #Lam{bod}: @wnf((bod @wnf(x)))
  #App{fun arg}: #App{#App{fun arg} x}
  #U32: #App{#U32 x}
  #Num{val}: #App{#Num{val} x}
}

// Enumeration
// -----------

@all(&L typ &dep ctx) =
  @intr(&L typ &dep ctx)

@intr(&L typ &dep ctx) =
  ~ typ !ctx {
    #All{t_inp t_bod}: 
      !&0{ctx bod} = @all(&L (t_bod #Var{&dep}) (+ &dep 1) #Cons{#Some{&0{$x t_inp}} ctx})
      &0{@tail(ctx) #Lam{λ$x(bod)}}
    #Pol{t_bod}:
      @intr(&L (t_bod #Var{&dep}) (+ &dep 1) ctx)
    #U32:
      @pick(&L #U32 &dep ctx λk(k))
    #Var{idx}:
      @pick(&L #Var{idx} &dep ctx λk(k))
    #App{fun arg}:
      @pick(&L #App{fun arg} &dep ctx λk(k))
    #Lam{bod}: *
    #Num{val}: *
  }

@pick(&L typ &dep ctx rem) = 
  ~ctx {
    #Nil: *
    #Cons{ann ctx}:
      !&L{typL typR} = typ
      !&L{remL remR} = rem
      !&L{annL annR} = ann
      !&L{ctxL ctxR} = ctx
      &L{
        @elim(&L typL &dep (remL #Cons{#None ctxL}) annL)
        @pick(&L typR &dep ctxR λk(remR #Cons{annR k}))
      }
  }

@elim(&L typ &dep ctx ann) = ~ann {
  #None: *
  #Some{ann}:
    ! &0{v t} = ann
    ~ t !typ !ctx !v {
      #Pol{t_bod}:
        ! &{typ0 typ1} = typ
        @elim(&L typ0 &dep ctx #Some{&0{v (t_bod typ1)}})
      #All{t_inp t_bod}:
        ! &0{ctx arg}  = @all((+(*&L 2)1) t_inp &dep ctx)
        ! &{arg0 arg1} = arg
        @elim((+(*&L 2)0) typ &dep ctx #Some{&0{#App{v arg0} (t_bod arg1)}})
      #U32: 
        @when(@eq(typ #U32 &dep) &0{ctx v})
      #Var{idx}:
        @when(@eq(typ #Var{idx} &dep) &0{ctx v})
      #App{fun arg}:
        @when(@eq(typ #App{fun arg} &dep) &0{ctx v})
      #Lam{bod}: *
      #Num{val}: *
    }
}

// Tests
// -----

// T0 = Π(x:U32) Π(y:U32) Π(z:U32) U32
@T0 =
  #All{#U32 λx
  #All{#U32 λy
  #All{#U32 λz
  #U32}}}

// CBool = ∀P Π(x:P) Π(y:P) P
@CBool =
  #Pol{λp
  !&{p p0}=p
  !&{p p1}=p
  !&{p p2}=p
  #All{p0 λx
  #All{p1 λy
  p2}}}

// Tup(A B) = ∀P Π(p: Π(x:A) Π(y:B) P) P
@Tup(A B) =
  #Pol{λp
  !&{p p0}=p
  !&{p p1}=p
  #All{
    #All{A λx
    #All{B λy
    p0}} λp
  p1}}

// TupF = ∀A ∀B Π(x: (Tup A A)) (Tup A A)   (B is unused)
@TupF =
  #Pol{λA
  #Pol{λB
  !&{A A0}=A
  !&{A A1}=A
  !&{A A2}=A
  !&{A A3}=A
  #All{@Tup(A0 A1) λx
  @Tup(A2 A3)}}}

//A= λt(t 1 2)
@A = #Lam{λt #App{#App{t #Num{1}} #Num{2}}}

//B= λt(t 2 1)
@B = #Lam{λt #App{#App{t #Num{2}} #Num{1}}}

//R= λx(x λaλbλt(t b a))
@R = #Lam{λx #App{x #Lam{λa #Lam{λb #Lam{λt #App{#App{t b} a}}}}}}

// X : ∀A. (Tup A A) -> (Tup A A) = <all terms>
@T = @TupF
@X = @tm1(@all(1 @T 0 #Nil))

// Solves for `?X` in `(?X λt(t A B)) == λt(t B A)`.
// It finds `?X = λλ(0 λλ((1 3) 2))` in 3k interactions. This printer shows
// de Bruijn levels, so this is the same term as `λλ(1 λλ((2 0) 1))` above.
@main = @when(@eq(#App{@X @A} @B 0) @do_show_term(@X))
//@main = @X
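To illustrate the claim that only picking from the context needs a choice, here is a simply typed Haskell sketch of the same intro/pick/elim structure. It is a simplification under stated assumptions. There are no quantifiers (`#Pol`), the answer type `P` of `Tup` is a fixed atom, terms use de Bruijn levels like `@show_term` in this file, and the names (`intr`, `pick`, `elim`, `inhabitants`, `showTm`, `tup`) are hypothetical. Arrow goals always become lambdas. An atomic goal is met by choosing an unused context variable and applying it to one argument per premise.

```haskell
-- Simple types: atoms and arrows only (no quantifiers).
data Ty = Atom String | Arr Ty Ty
  deriving (Show, Eq)

-- Terms with de Bruijn *levels*, as printed by @show_term in this file.
data Tm = Lam Tm | App Tm Tm | Var Int
  deriving (Show, Eq)

-- Linear context, newest first: Just (level, type) while unused.
type Ctx = [Maybe (Int, Ty)]

-- Like @intr: an arrow goal forces a lambda; an atomic goal forces a pick.
intr :: Int -> Ctx -> Ty -> [(Ctx, Tm)]
intr dep ctx (Arr a b) =
  [ (drop 1 ctx', Lam bod) | (ctx', bod) <- intr (dep + 1) (Just (dep, a) : ctx) b ]
intr dep ctx goal = pick dep ctx goal

-- Like @pick: the only real choice is which unused variable to eliminate.
pick :: Int -> Ctx -> Ty -> [(Ctx, Tm)]
pick dep ctx goal =
  [ r
  | i <- [0 .. length ctx - 1]
  , (before, Just (lvl, ty) : after) <- [splitAt i ctx]
  , r <- elim dep (before ++ Nothing : after) (Var lvl) ty goal
  ]

-- Like @elim: apply v to one argument per premise (threading the context),
-- then keep the result only if the remaining atomic type matches the goal.
elim :: Int -> Ctx -> Tm -> Ty -> Ty -> [(Ctx, Tm)]
elim dep ctx v (Arr a b) goal =
  [ r | (ctx', arg) <- intr dep ctx a, r <- elim dep ctx' (App v arg) b goal ]
elim _ ctx v ty goal = [ (ctx, v) | ty == goal ]

-- All closed affine inhabitants (in long normal form) of a type.
inhabitants :: Ty -> [Tm]
inhabitants ty = map snd (intr 0 [] ty)

-- Printer in the style of @show_term (levels are printed as-is).
showTm :: Tm -> String
showTm (Var i) = show i
showTm (Lam b) = 'λ' : showTm b
showTm (App f a) = "(" ++ showTm f ++ " " ++ showTm a ++ ")"

-- Tup A A with a fixed answer type: (A -> A -> P) -> P
tup :: Ty
tup = Arr (Arr (Atom "A") (Arr (Atom "A") (Atom "P"))) (Atom "P")
```

For `Arr tup tup`, that is, `((A -> A -> P) -> P) -> ((A -> A -> P) -> P)`, the sketch finds exactly two affine inhabitants, `λλ(0 λλ((1 3) 2))` and `λλ(0 λλ((1 2) 3))`, and the first is the swap reported above. Since each elimination consumes a context variable, the search is finite.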
