Skip to content

Instantly share code, notes, and snippets.

@roconnor
Last active May 14, 2024 15:32
Show Gist options
  • Save roconnor/52375c3a81c927c5a8234be858f9830b to your computer and use it in GitHub Desktop.
Save roconnor/52375c3a81c927c5a8234be858f9830b to your computer and use it in GitHub Desktop.
Memo Trie for functions over binary trees in Coq (no support for memoizing structurally recursive functions though)
(* In relation to https://cstheory.stackexchange.com/questions/40631/how-can-you-build-a-coinductive-memoization-table-for-recursive-functions-over-b *)
Require Import Vectors.Vector.
Import VectorNotations.
Set Primitive Projections.
Set Implicit Arguments.
(* Unlabelled binary trees: a tree is either a Leaf or a Branch of two
   subtrees. *)
Inductive binTree : Set :=
| Leaf
| Branch (l r : binTree).
(* filterZero A n carries no data at index 0 (FilterZero) and carries a
   single value of type A m at index S m (NonZero). It is used below for
   record fields that only make sense when n is a successor. *)
Inductive filterZero (A : nat -> Type) : nat -> Type :=
| FilterZero : filterZero A 0
| NonZero (m : nat) : A m -> filterZero A (S m).
(* getNonZero w: the payload stored in w : filterZero A (S n).
 * Only NonZero builds a value at a successor index (FilterZero lives at
 * index 0), so the FilterZero case is impossible; inversion_clear discards
 * it and leaves the NonZero payload a : A n.
 * A and n are implicit (callers write getNonZero (memoLeaf t)).
 * Closed with Defined rather than Qed so the definition stays transparent
 * and can be unfolded when the tables below are computed with. *)
Definition getNonZero A n (w : filterZero A (S n)) : A n.
(* Case on w; only the NonZero case survives, binding its payload as a. *)
inversion_clear w as [|m a].
apply a.
Defined.
(* memoVBinTree A n is intended to be isomorphic to Vector.t BinTree n -> A *)
(* It is a coinductive (possibly infinite) table holding one A for each
 * vector of n trees, indexed by the shape of the first tree.
 * Note that `if n then X else Y` on a nat selects X when n = 0 and Y when n
 * is a successor. *)
CoInductive memoVBinTree (A : Type) (n : nat) : Type :=
(* extract: the stored answer. It has type A only when no trees remain
   (n = 0); otherwise it is unit. *)
{ extract : if n then A else unit
(* memoLeaf: for n = S m, the sub-table over the remaining m trees, used when
   the first tree is Leaf (see trie). FilterZero when n = 0. *)
; memoLeaf : filterZero (memoVBinTree A) n
(* memoBranch: for n = S m, the sub-table over S (S m) trees, used when the
   first tree is Branch l r; l and r replace it at the front of the vector
   (see trie). FilterZero when n = 0. *)
; memoBranch : filterZero (fun _ => memoVBinTree A (S n)) n
}.
(* A memo table for functions binTree -> A. It is the single-tree instance
   of memoVBinTree (index 1). *)
Definition memoBinTree (A : Type) := memoVBinTree A 1.
(* branch k t: continuation-style accessor for a table t over S n trees.
   It hands k the sub-table for a leading Leaf (over n trees) and the
   sub-table for a leading Branch (over S (S n) trees).
   getNonZero cannot fail here because the index S n is a successor. *)
Definition branch {A R n} (k : memoVBinTree A n -> memoVBinTree A (S (S n)) -> R)
  (t : memoVBinTree A (S n)) : R :=
  let onLeaf : memoVBinTree A n := getNonZero (memoLeaf t) in
  let onBranch : memoVBinTree A (S (S n)) := getNonZero (memoBranch t) in
  k onLeaf onBranch.
(* map f t: apply f to every value stored in the table t, keeping its shape.
 * Each field does a dependent case on n and returns a function of t, which is
 * then applied to t (the convoy pattern). This lets t's type be refined
 * together with n inside each branch.
 * Every corecursive call to map appears under the NonZero constructor. *)
CoFixpoint map {A B n} (f : A -> B) (t : memoVBinTree A n) : memoVBinTree B n :=
(* extract: f of the stored value when n = 0, tt otherwise. *)
{| extract := (if n return (memoVBinTree A n -> if n then B else unit)
then fun t => f (extract t)
else fun _ => tt) t
(* memoLeaf: recurse into the Leaf sub-table when n is a successor. *)
; memoLeaf := (match n return (memoVBinTree A n -> filterZero (memoVBinTree B) n) with
| 0 => fun _ => FilterZero _
| S n0 => fun t => NonZero _ _ (map f (getNonZero (memoLeaf t)))
end) t
(* memoBranch: recurse into the Branch sub-table when n is a successor. *)
; memoBranch := (if n return (memoVBinTree A n -> filterZero (fun _ => memoVBinTree B (S n)) n)
then fun _ => FilterZero _
else fun t => NonZero _ _ (map f (getNonZero (memoBranch t)))) t
|}.
(* zip f ta tb: combine a table over n trees with a table over m trees into a
 * table over n + m trees. The entry for the concatenated vector u ++ w is
 * intended to be f (entry of ta at u) (entry of tb at w).
 * The recursion walks down ta. Once ta is exhausted (n = 0), its single value
 * extract ta is fixed, and the rest of the table is map (f (extract ta)) tb.
 * As in map, each field cases on n and is applied to ta (the convoy pattern).
 * Corecursive calls to zip appear only under NonZero. *)
CoFixpoint zip {A B C n m} (f : A -> B -> C)
(ta : memoVBinTree A n) (tb : memoVBinTree B m) : memoVBinTree C (n + m) :=
(* extract: f (extract ta) (extract tb) when n = 0 and m = 0; tt otherwise. *)
{| extract := (if n return (memoVBinTree A n -> if (n + m) then C else unit)
then fun ta => (if m return (memoVBinTree B m -> if (0 + m) then C else unit)
then fun tb => f (extract ta) (extract tb)
else fun _ => tt) tb
else fun _ => tt) ta
(* memoLeaf: when n = 0, take the Leaf sub-table of map (f (extract ta)) tb;
   otherwise keep zipping the Leaf sub-table of ta with tb. *)
; memoLeaf := (match n return (memoVBinTree A n -> filterZero (memoVBinTree C) (n + m)) with
| 0 => fun ta => memoLeaf (map (f (extract ta)) tb)
| S n0 => fun ta => NonZero _ _ (zip f (getNonZero (memoLeaf ta)) tb)
end) ta
(* memoBranch: same split as memoLeaf, using the Branch sub-tables. *)
; memoBranch := (if n return (memoVBinTree A n -> filterZero (fun _ => memoVBinTree C (S n + m)) (n + m))
then fun ta => memoBranch (map (f (extract ta)) tb)
else fun ta => NonZero _ _ (zip f (getNonZero (memoBranch ta)) tb)) ta
|}.
(* split n m t: curry a table over n + m trees into a table over the first n
 * trees, whose values are tables over the remaining m trees.
 * At n = 0 the stored value is t itself. A leading Leaf consumes one of the
 * first n trees, so the recursion continues at n0. A leading Branch replaces
 * one tree with two, so the recursive call is at S (S n0). *)
CoFixpoint split {A} n m (t : memoVBinTree A (n + m)) : memoVBinTree (memoVBinTree A m) n :=
(* extract: when n = 0, the remaining table t over m trees is the value. *)
{| extract := (if n return (memoVBinTree A (n + m) -> if n then (memoVBinTree A m) else unit)
then fun t => t
else fun _ => tt) t
; memoLeaf := (match n return (memoVBinTree A (n + m) -> filterZero (memoVBinTree (memoVBinTree A m)) n) with
| 0 => fun _ => FilterZero _
| S n0 => fun t => NonZero _ _ (split _ _ (getNonZero (memoLeaf t)))
end) t
(* memoBranch: the Branch sub-table covers S (S n0) + m trees; split off the
   first S (S n0) of them. *)
; memoBranch := (match n return (memoVBinTree A (n + m) -> filterZero (fun _ => memoVBinTree (memoVBinTree A m) (S n)) n) with
| 0 => fun _ => FilterZero _
| S n0 => fun t => NonZero _ _ (split (S (S n0)) m (getNonZero (memoBranch t)))
end) t
|}.
(* trie f: tabulate f : Vector.t binTree n -> A as a memoVBinTree A n.
 * Each sub-table is trie of f with the argument vector pre-shaped:
 * - memoLeaf prepends Leaf to the vector;
 * - memoBranch takes a vector of S (S n0) trees and folds its first two
 *   trees into one Branch before calling f.
 * Each field cases on n and is applied to f (the convoy pattern), so f's
 * domain is refined together with n. *)
CoFixpoint trie {A n} (f : Vector.t binTree n -> A) : memoVBinTree A n :=
(* extract: at n = 0, f applied to the empty vector. *)
{| extract := (if n return ((Vector.t binTree n -> A) -> if n then A else unit)
then fun f => f (nil binTree)
else fun _ => tt) f
; memoLeaf := (match n return ((Vector.t binTree n -> A) -> filterZero (memoVBinTree A) n) with
| 0 => fun _ => FilterZero _
| S n0 => fun f => NonZero _ _ (trie (fun v => f (cons _ Leaf _ v)))
end) f
(* memoBranch: v has S (S n0) trees; hd v and hd (tl v) become the two
   children of the leading Branch, followed by tl (tl v). *)
; memoBranch := (if n return ((Vector.t binTree n -> A) -> filterZero (fun _ => memoVBinTree A (S n)) n)
then fun _ => FilterZero _
else fun f => NonZero _ _ (trie (fun v => f (cons _ (Branch (hd v) (hd (tl v))) _ (tl (tl v)))))) f
|}.
(* untrie t m: look up the tree t in the memo table m. It is intended to
 * invert trie at n = 1 and recurses structurally on t.
 * branch hands over the Leaf sub-table l (over 0 trees) and the Branch
 * sub-table b (over 2 trees).
 * - Leaf: the answer is extract l.
 * - Branch tl tr: split 1 1 b turns b into a table over tl whose values are
 *   tables over tr. Mapping untrie tr looks up tr in each inner table, and
 *   untrie tl then looks up tl in the result.
 * Note: tl and tr here are subtrees and shadow Vector.tl inside the branch. *)
Fixpoint untrie (t : binTree) : forall {A}, (memoBinTree A) -> A :=
fun A => branch (fun l b =>
match t with
| Leaf => extract l
| Branch tl tr => untrie tl (map (untrie tr) (split 1 1 b))
end).
(*
* Example.
*)
Require Import ZArith.
(* This operation could stand to be a bit more expensive *)
(* op x y: the Cantor pairing (x + y)(x + y + 1)/2 + y, reduced modulo
   18446744073709551557 = 2^64 - 59 (presumably chosen as a large prime). *)
Definition op (x y : Z) : Z :=
  let s := (x + y)%Z in
  Z.modulo ((s + 1) * s / 2 + y) 18446744073709551557%Z.
(* foldOp t: fold the tree bottom-up. Every Leaf contributes 2147483647
   (= 2^31 - 1), and every Branch combines its two children's results with
   op. There is no memoization: both children are always recursed into, so
   the work is proportional to the size of the fully unfolded tree. *)
Fixpoint foldOp (t : binTree) : Z :=
  match t with
  | Leaf => 2147483647%Z
  | Branch l r => op (foldOp l) (foldOp r)
  end.
(* fullTree n: the complete binary tree of depth n. The subtree is bound with
   let and used for both children, so under call-by-value evaluation (such as
   vm_compute) it is built once and shared, rather than rebuilt for each
   child. *)
Fixpoint fullTree (n : nat) : binTree :=
  match n with
  | O => Leaf
  | S k =>
      let sub := fullTree k in
      Branch sub sub
  end.
(* A memo trie for foldOp, with no sharing for the recursive calls that
   foldOp makes on subtrees. Only the top-level query is tabulated; each
   table entry runs foldOp from scratch on its own tree. *)
Definition memoOp : memoBinTree Z :=
  trie (fun v : Vector.t binTree 1 => foldOp (Vector.hd v)).

(* memoFoldOp t: foldOp t, answered by looking t up in memoOp. *)
Definition memoFoldOp (t : binTree) : Z :=
  @untrie t Z memoOp.
(* Sanity check: look up the single-Leaf tree. *)
Eval vm_compute in (memoFoldOp (fullTree 0)).
(* Timing runs of memoized lookups. The queries for depths 12 and 11 are
   each issued twice as separate commands, presumably to compare first and
   repeated timings. *)
Time Eval vm_compute in memoFoldOp (fullTree 12).
Time Eval vm_compute in memoFoldOp (fullTree 11).
Time Eval vm_compute in memoFoldOp (fullTree 12).
Time Eval vm_compute in memoFoldOp (fullTree 11).
(* Two lookups of the same tree inside one evaluation, timed against two
   direct, unmemoized foldOp computations of the same tree. *)
Time Eval vm_compute in (memoFoldOp (fullTree 13),memoFoldOp (fullTree 13)).
Time Eval vm_compute in (foldOp (fullTree 13),foldOp (fullTree 13)).
(* memoOpX: the fully shared memo trie for foldOp that one would like to
 * write, built by tying the knot. The Leaf entry is 2147483647 (foldOp Leaf).
 * The Branch sub-table zips memoOpX with itself under op, mirroring
 * foldOp (Branch l r) = op (foldOp l) (foldOp r).
 * This memotrie doesn't satisfy the guard condition. The corecursive
 * occurrences of memoOpX are arguments of zip, not occurrences directly
 * under a constructor. Coq doesn't know that zip has a sufficient modulus of
 * continuity (roughly: how much of its inputs zip must inspect to produce
 * each part of its output).
 * The command is therefore prefixed with Fail. The expected rejection is
 * checked, and the script still processes to the end instead of stopping
 * with an error.
 *)
Fail CoFixpoint memoOpX : memoBinTree Z :=
  @Build_memoVBinTree Z 1 tt
    (NonZero _ _ (@Build_memoVBinTree Z 0 2147483647%Z (FilterZero _) (FilterZero _)))
    (NonZero _ _ (zip op memoOpX memoOpX)).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment