Last active
May 14, 2024 15:32
-
-
Save roconnor/52375c3a81c927c5a8234be858f9830b to your computer and use it in GitHub Desktop.
Memo trie for functions over binary trees in Coq (no support for memoizing structurally recursive functions, though)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* In relation to https://cstheory.stackexchange.com/questions/40631/how-can-you-build-a-coinductive-memoization-table-for-recursive-functions-over-b *) | |
Require Import Vectors.Vector. | |
Import VectorNotations. | |
Set Primitive Projections. | |
Set Implicit Arguments. | |
(* Shape-only binary trees: a tree is either a bare leaf or a node
   joining two subtrees.  No data is stored at leaves or nodes. *)
Inductive binTree : Set :=
  | Leaf
  | Branch (l r : binTree).
(* filterZero A n is an optional A-value indexed by n:
   - at index 0 there is nothing to store (FilterZero);
   - at index S m it stores one value of type A m (NonZero). *)
Inductive filterZero (A : nat -> Type) : nat -> Type :=
  | FilterZero : filterZero A 0
  | NonZero (m : nat) (a : A m) : filterZero A (S m).
(* Projects the payload out of a filterZero whose index is known to be a
   successor.  The return clause maps index 0 to unit, so the FilterZero
   branch only has to produce tt; at index S n the clause reduces to A n. *)
Definition getNonZero A n (w : filterZero A (S n)) : A n :=
  match w in filterZero _ k
        return match k with O => unit | S j => A j end
  with
  | @FilterZero _ => tt
  | @NonZero _ _ a => a
  end.
(* A coinductive memo table for functions of n binTree arguments. *)
(* memoVBinTree A n is intended to be isomorphic to Vector.t binTree n -> A *) | |
CoInductive memoVBinTree (A : Type) (n : nat) : Type := | |
(* extract: on a nat, [if n then _ else _] takes the [then] branch at 0, so
   this field holds an A when n = 0 and a dummy unit otherwise. *)
{ extract : if n then A else unit | |
(* memoLeaf: empty (FilterZero) when n = 0; when n = S m it holds a table
   with m arguments.  trie fills it with the function specialised to
   a first argument of Leaf. *)
; memoLeaf : filterZero (memoVBinTree A) n | |
(* memoBranch: empty when n = 0; when n = S m it holds a table with
   S n = S (S m) arguments.  trie fills it with the function whose first
   argument is the Branch of the first two arguments of the sub-table. *)
; memoBranch : filterZero (fun _ => memoVBinTree A (S n)) n | |
}. | |
(* memoBinTree A is the one-argument instance of memoVBinTree; it is
   intended to be isomorphic to binTree -> A. *)
Definition memoBinTree (A : Type) : Type := memoVBinTree A 1.
(* Continuation-passing eliminator for a table with at least one argument.
   It unwraps both children of t and hands them to k:
   - first the memoLeaf child, a table with n arguments;
   - then the memoBranch child, a table with S (S n) arguments. *)
Definition branch {A R n}
  (k : memoVBinTree A n -> memoVBinTree A (S (S n)) -> R)
  (t : memoVBinTree A (S n)) : R :=
  let viaLeaf := getNonZero (memoLeaf t) in
  let viaBranch := getNonZero (memoBranch t) in
  k viaLeaf viaBranch.
(* map f t rebuilds the table t with f applied to every stored value.
   Each field analyses n and re-abstracts over t (the match is applied to t
   at the end), so that the type of t is refined to the matched n inside
   each branch. *)
CoFixpoint map {A B n} (f : A -> B) (t : memoVBinTree A n) : memoVBinTree B n := | |
(* n = 0: the table holds a value, so apply f to it; otherwise the slot is unit. *)
{| extract := (if n return (memoVBinTree A n -> if n then B else unit) | |
then fun t => f (extract t) | |
else fun _ => tt) t | |
(* n = S n0: corecursively map over the memoLeaf child; no child exists at 0. *)
; memoLeaf := (match n return (memoVBinTree A n -> filterZero (memoVBinTree B) n) with | |
| 0 => fun _ => FilterZero _ | |
| S n0 => fun t => NonZero _ _ (map f (getNonZero (memoLeaf t))) | |
end) t | |
(* Same treatment for the memoBranch child: empty at 0, mapped otherwise. *)
; memoBranch := (if n return (memoVBinTree A n -> filterZero (fun _ => memoVBinTree B (S n)) n) | |
then fun _ => FilterZero _ | |
else fun t => NonZero _ _ (map f (getNonZero (memoBranch t)))) t | |
|}. | |
(* zip f ta tb combines a table ta with n arguments and a table tb with m
   arguments into a table with n + m arguments.
   - While n is a successor, it descends into the children of ta and
     leaves tb untouched.
   - Once n = 0, ta holds a single value and the remaining structure is
     that of map (f (extract ta)) tb. *)
CoFixpoint zip {A B C n m} (f : A -> B -> C) | |
(ta : memoVBinTree A n) (tb : memoVBinTree B m) : memoVBinTree C (n + m) := | |
(* A value is present only when both n and m are 0; then it is
   f applied to the two stored values.  In every other case the slot is unit. *)
{| extract := (if n return (memoVBinTree A n -> if (n + m) then C else unit) | |
then fun ta => (if m return (memoVBinTree B m -> if (0 + m) then C else unit) | |
then fun tb => f (extract ta) (extract tb) | |
else fun _ => tt) tb | |
else fun _ => tt) ta | |
(* n = 0: reuse the memoLeaf child of the mapped tb (0 + m reduces to m).
   n = S n0: zip the memoLeaf child of ta against the whole of tb. *)
; memoLeaf := (match n return (memoVBinTree A n -> filterZero (memoVBinTree C) (n + m)) with | |
| 0 => fun ta => memoLeaf (map (f (extract ta)) tb) | |
| S n0 => fun ta => NonZero _ _ (zip f (getNonZero (memoLeaf ta)) tb) | |
end) ta | |
(* Same scheme for the memoBranch child. *)
; memoBranch := (if n return (memoVBinTree A n -> filterZero (fun _ => memoVBinTree C (S n + m)) (n + m)) | |
then fun ta => memoBranch (map (f (extract ta)) tb) | |
else fun ta => NonZero _ _ (zip f (getNonZero (memoBranch ta)) tb)) ta | |
|}. | |
(* split n m t re-reads a table with n + m arguments as a table with n
   arguments whose stored values are themselves tables with m arguments
   (a curried view of t).  n and m are explicit arguments. *)
CoFixpoint split {A} n m (t : memoVBinTree A (n + m)) : memoVBinTree (memoVBinTree A m) n := | |
(* n = 0: the outer table has no arguments left, so its value is t itself
   (0 + m reduces to m). *)
{| extract := (if n return (memoVBinTree A (n + m) -> if n then (memoVBinTree A m) else unit) | |
then fun t => t | |
else fun _ => tt) t | |
(* n = S n0: split the memoLeaf child of t; the two sizes are inferred. *)
; memoLeaf := (match n return (memoVBinTree A (n + m) -> filterZero (memoVBinTree (memoVBinTree A m)) n) with | |
| 0 => fun _ => FilterZero _ | |
| S n0 => fun t => NonZero _ _ (split _ _ (getNonZero (memoLeaf t))) | |
end) t | |
(* n = S n0: the memoBranch child has one more argument than t, so it is
   split with outer size S (S n0) and the same inner size m. *)
; memoBranch := (match n return (memoVBinTree A (n + m) -> filterZero (fun _ => memoVBinTree (memoVBinTree A m) (S n)) n) with | |
| 0 => fun _ => FilterZero _ | |
| S n0 => fun t => NonZero _ _ (split (S (S n0)) m (getNonZero (memoBranch t))) | |
end) t | |
|}. | |
(* trie f tabulates a function f over length-n vectors of binTree as a
   corecursively defined memoVBinTree. *)
CoFixpoint trie {A n} (f : Vector.t binTree n -> A) : memoVBinTree A n := | |
(* n = 0: the only possible input is the empty vector, so store f nil. *)
{| extract := (if n return ((Vector.t binTree n -> A) -> if n then A else unit) | |
then fun f => f (nil binTree) | |
else fun _ => tt) f | |
(* n = S n0: child table for inputs whose first component is Leaf;
   the remaining components v are passed through unchanged. *)
; memoLeaf := (match n return ((Vector.t binTree n -> A) -> filterZero (memoVBinTree A) n) with | |
| 0 => fun _ => FilterZero _ | |
| S n0 => fun f => NonZero _ _ (trie (fun v => f (cons _ Leaf _ v))) | |
end) f | |
(* n = S n0: child table for inputs whose first component is a Branch.
   The child takes one extra argument: its first two components become the
   two subtrees of that Branch, and the rest of v follows. *)
; memoBranch := (if n return ((Vector.t binTree n -> A) -> filterZero (fun _ => memoVBinTree A (S n)) n) | |
then fun _ => FilterZero _ | |
else fun f => NonZero _ _ (trie (fun v => f (cons _ (Branch (hd v) (hd (tl v))) _ (tl (tl v)))))) f | |
|}. | |
(* untrie t tbl looks the tree t up in the one-argument memo table tbl.
   The recursion is structural on t; the type A is an implicit argument
   bound after t. *)
Fixpoint untrie (t : binTree) : forall {A}, (memoBinTree A) -> A := | |
(* branch exposes both children of the table: l is the memoLeaf child
   (no arguments left) and b is the memoBranch child (two arguments). *)
fun A => branch (fun l b => | |
match t with | |
(* A leaf: the answer is the value stored in the memoLeaf child. *)
| Leaf => extract l | |
(* A branch: view b as a table of tables with split 1 1, resolve every
   inner table at tr with map, then look tl up in the resulting table. *)
| Branch tl tr => untrie tl (map (untrie tr) (split 1 1 b)) | |
end). | |
(* | |
* Example. | |
*) | |
Require Import ZArith. | |
(* Binary operation used by the example fold: computes
   (x + y + 1) * (x + y) / 2 + y and reduces it modulo
   18446744073709551557.
   Note: this operation could stand to be a bit more expensive. *)
Definition op (x y : Z) : Z :=
  (((x + y + 1) * (x + y) / 2 + y) mod 18446744073709551557)%Z.
(* Folds a tree down to an integer: every leaf contributes the constant
   2147483647 and every branch combines the folds of its two subtrees
   with op. *)
Fixpoint foldOp (t : binTree) : Z :=
  match t with
  | Branch l r => op (foldOp l) (foldOp r)
  | Leaf => 2147483647%Z
  end.
(* Builds the tree of depth n in which both subtrees of every branch are
   the same tree of depth n - 1; depth 0 is a single leaf.  The subtree is
   bound once with let and used for both children. *)
Fixpoint fullTree (n : nat) : binTree :=
  match n with
  | O => Leaf
  | S k =>
      let sub := fullTree k in
      Branch sub sub
  end.
(* Memo table for foldOp, built with trie over one-element vectors.
   Each entry is produced by a plain call to foldOp, so the recursive
   calls inside foldOp are not shared through the table. *)
Definition memoOp : memoBinTree Z :=
  trie (fun args : Vector.t binTree 1 => foldOp (Vector.hd args)).
(* Evaluates the fold of t by looking t up in the memoOp table with untrie
   rather than calling foldOp on t directly. *)
Definition memoFoldOp (t : binTree) : Z := @untrie t Z memoOp.
(* Smallest case: look up the single-leaf tree in the memo table. *)
Eval vm_compute in (memoFoldOp (fullTree 0)). | |
(* Timed lookups for full trees of depth 12 and 11, each issued twice
   in alternation. *)
Time Eval vm_compute in memoFoldOp (fullTree 12). | |
Time Eval vm_compute in memoFoldOp (fullTree 11). | |
Time Eval vm_compute in memoFoldOp (fullTree 12). | |
Time Eval vm_compute in memoFoldOp (fullTree 11). | |
(* Timed pair of two identical memoised lookups at depth 13 within a
   single evaluation. *)
Time Eval vm_compute in (memoFoldOp (fullTree 13),memoFoldOp (fullTree 13)). | |
(* The same pair computed with plain foldOp; presumably the baseline the
   memoised timing above is meant to be compared against. *)
Time Eval vm_compute in (foldOp (fullTree 13),foldOp (fullTree 13)). | |
(* This memo trie does not satisfy the guard condition.
 * Coq does not know that zip has a sufficient modulus of continuity: the
 * corecursive occurrences of memoOpX are arguments of zip instead of
 * sitting directly under constructors.
 *
 * The intended table stores the leaf constant 2147483647 in its memoLeaf
 * child and obtains its memoBranch child by zipping the table with itself
 * using op.
 *
 * The command is wrapped in Fail so that the rejection is checked and
 * reported without aborting compilation of the file.  Fail itself raises
 * an error should the definition ever be accepted.
 *)
Fail CoFixpoint memoOpX : memoBinTree Z :=
  @Build_memoVBinTree Z 1 tt
    (NonZero _ _ (@Build_memoVBinTree Z 0 2147483647%Z (FilterZero _) (FilterZero _)))
    (NonZero _ _ (zip op memoOpX memoOpX)).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment