Skip to content

Instantly share code, notes, and snippets.

@VictorTaelin
Last active November 29, 2024 15:51
Show Gist options
  • Save VictorTaelin/7c4c69a1f07b5c668be613f1032e7d4e to your computer and use it in GitHub Desktop.
Save VictorTaelin/7c4c69a1f07b5c668be613f1032e7d4e to your computer and use it in GitHub Desktop.
A Superposed λ-Calculus Enumerator for Program Search
// This file is a mirror of:
// https://github.com/HigherOrderCO/HVM3/blob/main/book/lambda_enumerator_optimal.hvml
// An Optimal λ-Calculus Enumerator for Program Search
// ---------------------------------------------------
// This file shows a template on how to enumerate superposed terms in a
// higher-order language (here, the affine λ-Calculus) for proof search.
// Instead of generating the source syntax (like superposing all binary strings
// and parsing it as a binary λ-calculus), we create a superposition of all
// λ-terms *directly*, in a way that head constructors are emitted as soon as
// possible. This allows HVM to prune branches and backtrack efficiently as it
// computes the application of all λ-terms to some expression. The result is a
// reduction in the interactions needed to solve the equation:
// > (?X λt(t 1 2)) == λt(t 2 1)
// From ~32 million to just 76k, a 420x speedup*, which increases the harder the
// problem is. This generator works by synthesizing a λ-term in layers. On each
// layer, we either generate a lambda and extend a context, or select one of the
// variables in the context to return. When we select a variable, we will apply
// it to either 0, 1 or 2 other variables in the context (we don't generate
// terms with >2 arity apps here). This is not ideal; in a typed version, we
// would be able to tell the arity of the context variable, and generate the
// right amount of applications, without making a guess.
// * NOTE: we're actually able to bring the naive approach down to 1.7 million
// interactions. So, the numbers are:
// - Enum binary λC with loops (Haskell): 0.992s
// - Enum binary λC with sups (HVM): 0.026s (38x speedup)
// - Enum λC directly with sups (HVM): 0.0011s (862x speedup)
// The main contribution of this file is on the shape of the superposer. There
// are a few things that one must get right to achieve the desired effect.
// First, how do we split a linear context? It depends: when generating an
// application like `(f ?A ?B)`, we need to pass the context to `?A`, get the
// leftover, and pass it to `?B`, in such a way that `?A` and `?B` won't use the
// same variable twice. This happens within the same "universe". Yet, when
// making a choice, like "do we return a lambda or a variable here", we need to
// clone the linear context with the same label forked the universe itself,
// allowing a variable to be used more than once, as long as its occurrences are
// in different universes. Handling this correctly is very subtle, which is why
// this file can be useful for study.
// Second, how do we handle labels? As discussed recently on Discord:
// https://discord.com/channels/912426566838013994/915345481675186197/1311434500911403109
// We only need one label to fully enumerate all natural numbers. Yet, that
// doesn't work for binary trees. My conclusion is that we need to "fork" the
// label whenever we enumerate a constructor that branches; i.e., that has more
// than one field. Nats and Bits are safe because their constructors only have
// one field, but a binary Tree needs forking. To fork a label, we just mul by 2
// and add 0 or 1, and the seed has to be 1, so that forked branches never use
// the same label. We apply this here to the arity-2 app case.
// Third, how do we emit constructors as soon as possible, while still passing a
// context down? It is easy to accidentally mess this up by making the enum
// monadic. This will cause it to sequentialize its execution, meaning no ctor
// is emitted until the entire enumerator returns. That's a big problem, since
// we need head ctors to be available as soon as possible. That's how HVM is
// able to invalidate universes and backtrack. While this is a silly issue, it
// can spoil the whole thing, so I've elaborated it here:
// https://gist.github.com/VictorTaelin/fb798a5bd182f8c57dd302380f69777a
// The enumerator in this file is the simplest "template" enumerator that has
// everything a higher order language needs and is structured in a way that can
// be studied and extended with more sophisticate approaches, like types.
// EDIT: the dependently typed version has been pushed. It reduces the rewrite
// count to 3k, and greatly improves the enumerator shape.
data List {
#Nil
#Cons{head tail}
}
data Bits {
#O{pred}
#I{pred}
#E
}
data Term {
#Lam{bod}
#App{fun arg}
#Var{idx}
#Sub{val}
}
data pair {
#Pair{fst snd}
}
data Result {
#Result{src val}
}
data Maybe {
#None
#Some{value}
}
// Prelude
// -------
@if(c t f) = ~ c {
0: f
p: t
}
@when(c t) = ~ c {
0: *
p: t
}
@tail(xs) = ~ xs {
#Nil: *
#Cons{h t}: t
}
@and(a b) = ~ a !b {
0: 0
p: b
}
@unwrap(mb) = ~ mb {
#None: *
#Some{x}: x
}
@tm0(x) = !&0{a b}=x a
@tm1(x) = !&0{a b}=x b
// Stringification
// ---------------
@show_nat(nat) = ~nat {
0: λk #Cons{'Z' k}
p: λk #Cons{'S' (@show_nat(p) k)}
}
@show_dec(n r) =
! &{n n0} = n
! &{n n1} = n
! chr = (+ (% n 10) '0')
~ (< n0 10) !chr !r {
0: @show_dec((/ n1 10) #Cons{chr r})
t: #Cons{chr r}
}
@do_show_dec(n) = @show_dec(n #Nil)
@show_bits(bits) = ~bits {
#O{pred}: λk #Cons{'#' #Cons{'O' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
#I{pred}: λk #Cons{'#' #Cons{'I' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
#E: λk #Cons{'#' #Cons{'E' k}}
}
@do_show_bits(bits) = (@show_bits(bits) #Nil)
@show_term(term dep) = ~term !dep {
#Var{idx}: λk
@show_dec((- (- dep idx) 1) k)
#Lam{bod}: λk
!&{d0 d1}=dep
#Cons{'λ' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
#App{fun arg}: λk
!&{d0 d1}=dep
#Cons{'(' (@show_term(fun d0)
#Cons{' ' (@show_term(arg d1)
#Cons{')' k})})}
#Sub{val}: *
}
@do_show_term(term) = (@show_term(term 0) #Nil)
// Equality
// --------
@eq(a b dep) = ~ a !b !dep {
#Lam{a_bod}: ~ b !dep {
#Lam{b_bod}:
!&{dep d0}=dep
!&{dep d1}=dep
!&{dep d2}=dep
@eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ 1 d2))
#App{b_fun b_arg}: 0
#Var{b_idx}: 0
#Sub{b_val}: *
}
#App{a_fun a_arg}: ~ b !dep {
#Lam{b_bod}: 0
#App{b_fun b_arg}:
!&{dep d0}=dep
!&{dep d1}=dep
@and(@eq(a_fun b_fun d0) @eq(a_arg b_arg d1))
#Var{b_idx}: 0
#Sub{b_val}: *
}
#Var{a_idx}: ~ b !dep {
#Lam{b_bod}: 0
#App{b_fun b_arg}: 0
#Var{b_idx}: (== a_idx b_idx)
#Sub{b_val}: *
}
#Sub{a_val}: *
}
// Evaluation
// ----------
@wnf(term) = ~ term {
#Lam{bod}: #Lam{bod}
#App{fun arg}: @wnf_app(@wnf(fun) arg)
#Var{idx}: #Var{idx}
#Sub{val}: #Sub{val}
}
@wnf_app(f x) = ~ f !x {
#Lam{bod}: @wnf((bod @wnf(x)))
#App{fun arg}: #App{#App{fun arg} x}
#Var{idx}: #App{#Var{idx} x}
#Sub{val}: #App{#Sub{val} x}
}
// Normalization
// -------------
@nf(term) = ~ @wnf(term) {
#Lam{bod}: #Lam{λx @nf((bod #Sub{x}))}
#App{fun arg}: #App{@nf(fun) @nf(arg)}
#Var{idx}: #Var{idx}
#Sub{val}: val
}
// Enumeration
// -----------
// Enumerates affine λ-terms.
// - lim: max context length (i.e., nested lambdas)
// - lab: superposition label. should be 1 initially.
// - ctx: the current scope. should be [] initially.
// If the binder limit has been reached, destroy this universe.
// Otherwise, make a choice.
// - A. We generate a fresh lambda.
// - B. We select a variable from context.
// Note that, every time we make a choice, we "fork" the current context by
// using DUP nodes with the same label that we used in the choice SUP node.
@all(&L &lim ctx) = ~&lim {
0: *
&lim:
!&L{ctxL ctxR} = ctx
&L{
@lam(&L &lim ctxL)
@ret(&L (+ &lim 1) ctxR λk(k))
}
}
// Generate a fresh lambda and extend the context with its variable.
@lam(&L &lim ctx) =
!&0{ctx bod} = @all(&L &lim #Cons{#Some{$x} ctx})
&0{@tail(ctx) #Lam{λ$x(bod)}}
// Return a variable from the context.
// If the context is empty, destroy this universe.
// Otherwise, make a choice.
// - A. We emit the head of the context, and apply it to things.
// - B. We keep the head of the context, and go to the next element.
@ret(&L &lim ctx rem) = ~ctx {
#Nil: *
#Cons{val ctx}:
!&L{remL remR} = rem
!&L{valL valR} = val
!&L{ctxL ctxR} = ctx
&L{
@app(&L &lim (remL #Cons{#None ctxL}) valL)
@ret(&L &lim ctxR λk(remR #Cons{valR k}))
}
}
// To apply a value to things, we will make a triple choice.
// - A. Just return it directly.
// - B. Apply it to 1 argument.
// - C. Apply it to 2 arguments.
// When we apply it to 2 arguments, as in `(App ?A ?B)`, we need to fork the
// label, so that DUPs/SUPs in `?A` and `?B` never use the same label.
@app(&L &lim ctx val) = ~ val {
#None: *
#Some{val}:
!&L{val val0} = val
!&L{val val1} = val
!&L{val val2} = val
!&L{ctx ctx0} = ctx
!&L{ctx ctx1} = ctx
!&L{ctx ctx2} = ctx
! arity_0 =
&0{ctx0 val0}
! arity_1 =
!&0{ctx1 argX} = @all(&L &lim ctx1)
&0{ctx1 #App{val1 argX}}
! arity_2 =
!&0{ctx2 argX} = @all((+(*&L 2) 0) &lim ctx2)
!&0{ctx2 argY} = @all((+(*&L 2) 1) &lim ctx2)
&0{ctx2 #App{#App{val2 argX} argY}}
&L{arity_0 &L{arity_1 arity_2}}
}
// Tests
// -----
//A= λt (t ^1 ^2)
//A= λ((Z SZ) SSZ)
@A = #Lam{λt #App{#App{t #Var{1}} #Var{2}}}
//B= λt (t ^2 ^1)
//B= λ((Z SSZ) SZ)
@B = #Lam{λt #App{#App{t #Var{2}} #Var{1}}}
//R= λx (x λa λb λt (t b a))
//R= λ(Z λλλ((SSSZ SSZ) SZ))
@R = #Lam{λx #App{x #Lam{λa #Lam{λb #Lam{λt #App{#App{t b} a}}}}}}
//X= (all terms)
@X = @tm1(@all(1 5 #Nil))
// Solves for `?X` in `(?X λt(t A B)) == λt(t B A)`.
// It finds `?X = λλ(1 λλ((2 0) 1))` in 76k interactions.
@main =
! solved = @eq(@nf(#App{@X @A}) @B 0)
@when(solved @do_show_term(@X))
@VictorTaelin
Copy link
Author

VictorTaelin commented Nov 29, 2024

The dependently typed superposer, down to 3k interactions:

// Superposes dependently typed λ-terms. With it, solving:
//   (?X λt(t A B)) == λt(t B A)
// Where
//   ?X : ∀A. (∀P. A -> A -> P) -> (∀P. A -> A -> P)
// Is down to 3k interactions. Of course, that's not too surprising given there
// are only two functions of that type, but the real win is that now we only
// need to make a choice when selecting an element from context. Intros and
// elims follow directly from types, no need for choices / superpositions.

data List {
  #Nil
  #Cons{head tail}
}

data Bits {
  #O{pred}
  #I{pred}
  #E
}

data Term {
  #Var{idx}
  #Pol{bod}
  #All{inp bod}
  #Lam{bod}
  #App{fun arg}
  #U32
  #Num{val}
}

data Pair {
  #Pair{fst snd}
}

data Maybe {
  #None
  #Some{value}
}

// Prelude
// -------

@if(c t f) = ~ c {
  0: f
  p: t
}

@when(c t) = ~ c {
  0: *
  p: t
}

@tail(xs) = ~ xs {
  #Nil: *
  #Cons{h t}: t
}

@and(a b) = ~ a !b {
  0: 0
  p: b
}

@unwrap(mb) = ~mb {
  #None: *
  #Some{x}: x
}

@seq(str) = ~ str {
  #Nil: #Nil
  #Cons{h t}:
    !! h = h
    !! t = @seq(t)
    #Cons{h t}
}

@tm0(x) = !&0{a b}=x a
@tm1(x) = !&0{a b}=x b

// Stringification
// ---------------

@show_nat(nat) = ~nat { 
  0: λk #Cons{'Z' k}
  p: λk #Cons{'S' (@show_nat(p) k)}
}

@show_dec(n r) =
  ! &{n n0} = n
  ! &{n n1} = n
  ! chr = (+ (% n 10) '0')
  ~ (< n0 10) !chr !r {
    0: @show_dec((/ n1 10) #Cons{chr r})
    t: #Cons{chr r}
  }

@do_show_dec(n) = @show_dec(n #Nil)

@show_bits(bits) = ~bits {
  #O{pred}: λk #Cons{'#' #Cons{'O' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #I{pred}: λk #Cons{'#' #Cons{'I' #Cons{'{' (@show_bits(pred) #Cons{'}' k})}}}
  #E: λk #Cons{'#' #Cons{'E' k}}
}

@do_show_bits(bits) = (@show_bits(bits) #Nil)

@show_term(term dep) = ~term !dep {
  #Var{idx}: λk
    @show_dec(idx k)
  #Pol{bod}: λk
    !&{dep d0}=dep
    !&{dep d1}=dep
    #Cons{'∀' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
  #All{inp bod}: λk
    !&{dep d0}=dep
    !&{dep d1}=dep
    !&{dep d2}=dep
    #Cons{'Π'
    #Cons{'('
    (@show_term(inp d0)
    #Cons{')'
    (@show_term((bod #Var{d1}) (+ d2 1))
    k)})}}
  #Lam{bod}: λk
    !&{d0 d1}=dep
    #Cons{'λ' (@show_term((bod #Var{d0}) (+ d1 1)) k)}
  #App{fun arg}: λk
    !&{d0 d1}=dep
    #Cons{'(' (@show_term(fun d0)
    #Cons{' ' (@show_term(arg d1)
    #Cons{')' k})})}
  #U32: λk
    #Cons{'U' k}
  #Num{val}: λk
    #Cons{'#' @show_dec(val k)}
}

@do_show_term(term) = (@show_term(term 0) #Nil)

// Equality
// --------

@eq(a b dep) = ~ @wnf(a) !b !dep {
  #Var{a_idx}: ~ @wnf(b) !dep {
    #Var{b_idx}: (== a_idx b_idx)
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #Pol{a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ dep 1))
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #All{a_inp a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      !&{dep d2}=dep
      @and(@eq(a_inp b_inp d0) @eq((a_bod #Var{d1}) (b_bod #Var{d2}) (+ dep 1)))
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #Lam{a_bod}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @eq((a_bod #Var{d0}) (b_bod #Var{d1}) (+ dep 1))
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: 0
  }
  #App{a_fun a_arg}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}:
      !&{dep d0}=dep
      !&{dep d1}=dep
      @and(@eq(a_fun b_fun d0) @eq(a_arg b_arg d1))
    #U32: 0
    #Num{b_val}: 0
  }
  #U32: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 1
    #Num{b_val}: 0
  }
  #Num{a_val}: ~ @wnf(b) !dep {
    #Var{b_idx}: 0
    #Pol{b_bod}: 0
    #All{b_inp b_bod}: 0
    #Lam{b_bod}: 0
    #App{b_fun b_arg}: 0
    #U32: 0
    #Num{b_val}: (== a_val b_val)
  }
}

// Evaluation
// ----------

@wnf(term) = ~ term { 
  #Var{idx}: #Var{idx}
  #Pol{bod}: #Pol{bod}
  #All{inp bod}: #All{inp bod}
  #Lam{bod}: #Lam{bod}
  #App{fun arg}: @wnf_app(@wnf(fun) arg)
  #U32: #U32
  #Num{val}: #Num{val}
}

@wnf_app(f x) = ~ f !x {
  #Var{idx}: #App{#Var{idx} x}
  #Pol{bod}: #App{#Pol{bod} x}
  #All{inp bod}: #App{#All{inp bod} x}
  #Lam{bod}: @wnf((bod @wnf(x)))
  #App{fun arg}: #App{#App{fun arg} x}
  #U32: #U32
  #Num{val}: #App{#Num{val} x}
}

// Enumeration
// -----------

@all(&L typ &dep ctx) =
  @intr(&L typ &dep ctx)

@intr(&L typ &dep ctx) =
  ~ typ !ctx {
    #All{t_inp t_bod}: 
      !&0{ctx bod} = @all(&L (t_bod #Var{&dep}) (+ &dep 1) #Cons{#Some{&0{$x t_inp}} ctx})
      &0{@tail(ctx) #Lam{λ$x(bod)}}
    #Pol{t_bod}:
      @intr(&L (t_bod #Var{&dep}) (+ &dep 1) ctx)
    #U32:
      @pick(&L #U32 &dep ctx λk(k))
    #Var{idx}:
      @pick(&L #Var{idx} &dep ctx λk(k))
    #App{fun arg}:
      @pick(&L #App{fun arg} &dep ctx λk(k))
    #Lam{bod}: *
    #Num{val}: *
  }

@pick(&L typ &dep ctx rem) = 
  ~ctx {
    #Nil: *
    #Cons{ann ctx}:
      !&L{typL typR} = typ
      !&L{remL remR} = rem
      !&L{annL annR} = ann
      !&L{ctxL ctxR} = ctx
      &L{
        @elim(&L typL &dep (remL #Cons{#None ctxL}) annL)
        @pick(&L typR &dep ctxR λk(remR #Cons{annR k}))
      }
  }

@elim(&L typ &dep ctx ann) = ~ann {
  #None: *
  #Some{ann}:
    ! &0{v t} = ann
    ~ t !typ !ctx !v {
      #Pol{t_bod}:
        ! &{typ0 typ1} = typ
        @elim(&L typ0 &dep ctx #Some{&0{v (t_bod typ1)}})
      #All{t_inp t_bod}:
        ! &0{ctx arg}  = @all((+(*&L 2)1) t_inp &dep ctx)
        ! &{arg0 arg1} = arg
        @elim((+(*&L 2)0) typ &dep ctx #Some{&0{#App{v arg0} (t_bod arg1)}})
      #U32: 
        @when(@eq(typ #U32 &dep) &0{ctx v})
      #Var{idx}:
        @when(@eq(typ #Var{idx} &dep) &0{ctx v})
      #App{fun arg}:
        @when(@eq(typ #App{fun arg} &dep) &0{ctx v})
      #Lam{bod}: *
      #Num{val}: *
    }
}

// Tests
// -----

// T0 = Π(x:U32) Π(y:U32) Π(z:U32) U32
@T0 =
  #All{#U32 λx
  #All{#U32 λy
  #All{#U32 λz
  #U32}}}

// CBool = ∀P Π(x:P) Π(y:P) P
@CBool =
  #Pol{λp
  !&{p p0}=p
  !&{p p1}=p
  !&{p p2}=p
  #All{p0 λx
  #All{p1 λy
  p2}}}

// Tup(A B) = ∀P Π(p: Π(x:A) Π(y:B) P) P
@Tup(A B) =
  #Pol{λp
  !&{p p0}=p
  !&{p p1}=p
  #All{
    #All{A λx
    #All{B λy
    p0}} λp
  p1}}

// TupF = ∀A Π(x: (Tup A A)) (Tup A A)
@TupF =
  #Pol{λA
  #Pol{λB
  !&{A A0}=A
  !&{A A1}=A
  !&{A A2}=A
  !&{A A3}=A
  #All{@Tup(A0 A1) λx
  @Tup(A2 A3)}}}

//A= λt(t 1 2)
@A = #Lam{λt #App{#App{t #Num{1}} #Num{2}}}

//B= λt(t 2 1)
@B = #Lam{λt #App{#App{t #Num{2}} #Num{1}}}

//R= λx(x λaλbλt(t b a))
@R = #Lam{λx #App{x #Lam{λa #Lam{λb #Lam{λt #App{#App{t b} a}}}}}}

// X : ∀A. (Tup A A) -> (Tup A A) = <all terms>
@T = @TupF
@X = @tm1(@all(1 @T 0 #Nil))

// Solves for `?X` in `(?X λt(t A B)) == λt(t B A)`.
// It finds `?X = λλ(0 λλ((1 3) 2))` in 3k interactions.
@main = @when(@eq(#App{@X @A} @B 0) @do_show_term(@X))
//@main = @X

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment