Created
May 14, 2018 18:31
-
-
Save roconnor/286d0f21af36c2e97e74338f10a4315b to your computer and use it in GitHub Desktop.
Memoization for structurally recursive functions over binary trees in Coq (with poor performance).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Require Import Arith. | |
Require Import Omega. | |
Set Primitive Projections. | |
Set Implicit Arguments. | |
(* Plain binary trees: a tree is either a Leaf or a Branch with two subtrees.
 * Neither leaves nor branches carry any data. *)
Inductive binTree : Set := | |
| Leaf : binTree | |
| Branch : binTree -> binTree -> binTree. | |
(* We write several arithmetic lemmas whose equalities we need to be transparent for evaluation purposes. *) | |
(* arith1 : a + b - a = b.
 * Proved by recursion on a and closed with Defined, so the equality proof
 * stays transparent. *)
Fixpoint arith1 a b : a + b - a = b. | |
destruct a. | |
(* a = 0: split on b so that both sides reduce to the same term. *)
destruct b; reflexivity. | |
(* a = S a': after simplification the goal is the statement for a'. *)
simpl. | |
apply arith1. | |
Defined. | |
(* arith2 : a <= b -> a + S (b - a + c) = S b + c.
 * Recursion on a. The equality is assembled from reflexivity and f_equal;
 * the inequality side conditions are discharged by omega under abstract. *)
Fixpoint arith2 a b c (Ha : a <= b) : (a + S (b - a + c)) = S b + c. | |
destruct a. | |
(* a = 0. *)
destruct b; reflexivity. | |
(* a = S a'. *)
destruct b. | |
(* b = 0 is impossible given Ha. *)
abstract omega. | |
simpl. | |
f_equal. | |
apply arith2. | |
(* Side condition of the recursive call: a' <= b'. *)
abstract omega. | |
Defined. | |
(* arith4 : a <= b -> a + (b - a + c) = b + c.
 * Same proof shape as arith2: recursion on a, with omega (under abstract)
 * used only for the inequality side conditions. *)
Fixpoint arith4 a b c (Ha : a <= b) : a + (b - a + c) = b + c. | |
destruct a. | |
(* a = 0. *)
destruct b; reflexivity. | |
(* a = S a'. *)
destruct b. | |
(* b = 0 is impossible given Ha. *)
abstract omega. | |
simpl. | |
f_equal. | |
apply arith4. | |
(* Side condition of the recursive call: a' <= b'. *)
abstract omega. | |
Defined. | |
(* arith5 : a + S b = S (a + b).
 * Recursion on a; transparent (Defined). *)
Fixpoint arith5 a b : a + S b = S (a + b). | |
destruct a. | |
(* a = 0: both sides compute to S b. *)
reflexivity. | |
(* a = S a': strip the outer S and recurse. *)
simpl. | |
f_equal. | |
apply arith5. | |
Defined. | |
(* arith6 : a <= b -> a + (b - a + S c) = S b + c.
 * Obtained by rewriting with arith4 and then closing the goal with arith5;
 * the side condition a <= b produced by the rewrite is the hypothesis Ha. *)
Lemma arith6 a b c (Ha : a <= b) : a + (b - a + S c) = S b + c. | |
rewrite arith4. | |
apply arith5. | |
(* Remaining goal is the premise of arith4. *)
assumption. | |
Defined. | |
(* arith7 : a + 0 = a.
 * Proved by induction on a; transparent (Defined). *)
Lemma arith7 a : a + 0 = a. | |
induction a. | |
(* Base case. *)
reflexivity. | |
(* Step case: strip the outer S and use the induction hypothesis. *)
simpl. | |
f_equal. | |
apply IHa. | |
Defined. | |
(* arith8 : a - a = 0.
 * Recursion on a; transparent (Defined). *)
Fixpoint arith8 a : a - a = 0. | |
destruct a. | |
(* a = 0. *)
reflexivity. | |
(* a = S a': the goal simplifies to the statement for a'. *)
simpl. | |
apply arith8. | |
Defined. | |
(* arith9 : a <= b -> b <= a -> a = b   (antisymmetry of <= on nat).
 * Case analysis on both numbers; the two mixed cases contradict one of the
 * hypotheses, and the successor case recurses on the predecessors. *)
Fixpoint arith9 a b (Ha : a <= b) (Hb : b <= a) : a = b. | |
destruct a; destruct b. | |
- reflexivity. | |
- abstract omega. | |
- abstract omega. | |
- f_equal. | |
(* Both premises of the recursive call follow from Ha and Hb. *)
apply arith9; abstract omega. | |
Defined. | |
(* arith10 : b <= a -> S (a - b) = S a - b.
 * Case analysis on a and b. The cases that hold by computation are closed by
 * reflexivity; two goals remain and are handled by the lines that follow. *)
Fixpoint arith10 a b (Hb : b <= a) : S (a - b) = S a - b. | |
destruct a; destruct b; try reflexivity. | |
(* a = 0, b = S b': contradicts Hb. *)
abstract omega. | |
(* a = S a', b = S b': recurse on the predecessors. *)
apply (arith10 a b). | |
(* Premise of the recursive call: b' <= a'. *)
abstract omega. | |
Defined. | |
(* The type (catalanTriangle A n k) is isomorphic to A^C(n,k) where C(n,k) is an entry from | |
* Catalan's triangle (see https://en.wikipedia.org/wiki/Catalan%27s_triangle). | |
*) | |
(* One layer of a table indexed by a pair (n, k). The recursive positions are
 * abstracted as the parameter rec.
 *  - ctbTip   : at index (n, 0), stores one value of type A.
 *  - ctbNode  : at index (S n, S k) with k <= n, stores the sub-structures
 *               rec (S n) k and rec n (S k).
 *  - ctbEmpty : at index (n, S k) with n <= k, stores nothing besides the proof. *)
Inductive catalanTriangleBody (A : Set) (rec : nat -> nat -> Set) : nat -> nat -> Set := | |
| ctbTip : forall n, A -> catalanTriangleBody A rec n 0 | |
| ctbNode : forall n k, k <= n -> rec (S n) k -> rec n (S k) -> catalanTriangleBody A rec (S n) (S k) | |
| ctbEmpty : forall n k, n <= k -> catalanTriangleBody A rec n (S k). | |
(* Making this coinductive in an attempt to keep evaluation lazy *) | |
(* catalanTriangle A n k ties the recursive knot of catalanTriangleBody: it is a
 * coinductive record with the single field ctBody, whose recursive positions
 * are again catalanTriangle A. *)
CoInductive catalanTriangle (A : Set) (n k : nat) : Set := | |
{ ctBody : catalanTriangleBody A (catalanTriangle A) n k }. | |
(* ctbMap: maps over a single table layer.
 * A value stored at a tip is sent through f, the two recursive positions of a
 * node are sent through frec, and inequality proofs are passed along unchanged. *)
Definition ctbMap (A B: Set) (f : A -> B) (rec1 rec2 : nat -> nat -> Set)
    (frec : forall n k, rec1 n k -> rec2 n k) n k
    (ctb : catalanTriangleBody A rec1 n k) : catalanTriangleBody B rec2 n k :=
  match ctb with
  | ctbEmpty _ _ bound => ctbEmpty _ _ bound
  | ctbNode _ _ bound sameRow prevRow =>
      ctbNode _ _ bound (frec _ _ sameRow) (frec _ _ prevRow)
  | ctbTip _ row x => ctbTip rec2 row (f x)
  end.
(* ctMap: applies f to the values stored in a table, corecursively.
 * Each unfolding rebuilds one record whose body is the ctbMap image of the
 * input's body, with ctMap f itself used for the recursive positions. *)
CoFixpoint ctMap (A B: Set) (f : A -> B) n k (table : catalanTriangle A n k) : catalanTriangle B n k := | |
{| ctBody := ctbMap f (catalanTriangle B) (ctMap f) (ctBody table) |}. | |
(* The type (catalan A n) is isomorphic to A^C(n) where C(n) is the nth Catalan number. | |
*) | |
(* catalan A n is the table type taken at the diagonal index (n, n). *)
Definition catalan (A : Set) (n : nat) : Set := catalanTriangle A n n.
(* ctTip: extracts the stored value from a table whose second index is 0.
 * Among the constructors of catalanTriangleBody only ctbTip has second index 0,
 * so inversion leaves the stored value as a hypothesis. *)
Definition ctTip A n (table : catalanTriangle A n 0) : A. | |
(* Expose the body of the record. *)
destruct table as [table]. | |
inversion_clear table; assumption. | |
Defined. | |
(* ctClose: turns a table at index (S n, k) into a table at index (n, k). *)
Definition ctClose A n k (table : catalanTriangle A (S n) k) : catalanTriangle A n k. | |
(* Expose the body of the record and split on its constructor. *)
destruct table as [table]. | |
inversion_clear table as [n0 a|n0 k0 H0 rec1 rec2|n0 k0 H0]. | |
(* Tip: rebuild a tip holding the same value. *)
- constructor; constructor. | |
exact a. | |
(* Node: the second sub-table already has the required index. *)
- exact rec2. | |
(* Empty: the result is empty as well; the weaker bound follows from H0. *)
- constructor; constructor. | |
abstract omega. | |
Defined. | |
(* ctOpen: given S k <= n, turns a table at index (n, S k) into a table at
 * index (n, k). *)
Definition ctOpen A n k (H : S k <= n) (table : catalanTriangle A n (S k)) : catalanTriangle A n k. | |
(* Expose the body of the record and split on its constructor; only two goals
 * are handled below (the tip constructor has second index 0, not S k). *)
destruct table as [table]. | |
inversion_clear table as [n0 a|n0 k0 H0 rec1 rec2|n0 k0 H0]. | |
(* Node: the first sub-table already has the required index. *)
- exact rec1. | |
(* Empty: its bound n <= k contradicts H. *)
- abstract omega. | |
Defined. | |
(* Binary trees parameterized by the number of branches (which is one less than the number of leaves). *) | |
(* ornBinTree s: binary trees indexed by a natural number.
 *  - OrnLeaf has index 0.
 *  - OrnBranch n m, given m <= n, combines a subtree of index m with a subtree
 *    of index n - m into a tree of index S n. *)
Inductive ornBinTree : nat -> Set := | |
| OrnLeaf : ornBinTree 0 | |
| OrnBranch : forall n m, (m <= n) -> ornBinTree m -> ornBinTree (n - m) -> ornBinTree (S n). | |
(* lookup: uses a tree of index s to navigate from a table at index
 * (s + n, s + k) to a table at index (n, k).
 * The script is goal-directed, so read bottom-up the branch case computes
 *   ctClose (lookup tree2 (ctOpen (lookup tree1 table')))
 * where table' is the input table after its indices are rewritten with
 * arith6 and arith2. *)
Fixpoint lookup A n k s (tree : ornBinTree s) : catalanTriangle A (s + n) (s + k) -> catalanTriangle A n k. | |
intros table. | |
(* Decide between k <= n and the opposite strict inequality. *)
destruct (le_le_S_dec k n). | |
destruct tree as [|n0 m0 H tree1 tree2]. | |
(* Leaf: s = 0, so the table already has the requested index. *)
exact table. | |
(* Branch: see the header comment for the composition being built. *)
apply ctClose. | |
apply (lookup _ _ _ _ tree2). | |
apply ctOpen. | |
abstract omega. | |
apply (lookup _ _ _ _ tree1). | |
rewrite (arith6 n H). | |
rewrite (arith2 k H). | |
exact table. | |
(* Second case of the decision above: the result is an empty table. *)
constructor. | |
destruct k. | |
(* k = 0 contradicts the inequality in scope. *)
abstract omega. | |
apply ctbEmpty. | |
abstract omega. | |
Defined. | |
(* lookupCatalan: reads the value selected by tree out of a table of type
 * catalan A n. It runs lookup with target index (0, 0) and takes the tip of
 * the result; arith7 rewrites n + 0 to n so that table fits lookup's input type. *)
Definition lookupCatalan A n (table : catalan A n) (tree : ornBinTree n) : A. | |
eapply ctTip. | |
eapply (lookup 0 0);[apply tree|]. | |
rewrite arith7. | |
apply table. | |
Defined. | |
(* construct: corecursively builds a table at index (S n, S k).
 *   f : for every m <= k, a table at index (n - m, k - m) whose entries are
 *       themselves tables of type catalan A m;
 *   t : a table at index (n, S k).
 * The script first applies the record constructor and then splits on whether
 * k <= n: if not, the layer is ctbEmpty; otherwise it is a ctbNode whose two
 * sub-tables are produced by the two bullets below.
 * NOTE(review): the script depends on the exact goal order and on an
 * auto-generated hypothesis name (l), so it is documented rather than restyled. *)
CoFixpoint construct (A : Set) n k | |
(f : forall m, m <= k -> catalanTriangle (catalan A m) (n - m) (k - m)) | |
(t : catalanTriangle A n (S k)) : catalanTriangle A (S n) (S k). | |
constructor. | |
destruct (le_le_S_dec k n) as [Hk|Hk];[|exact (ctbEmpty _ _ Hk)]. | |
apply (ctbNode _ _ Hk). | |
(* First sub-table, at index (S n, k). *)
- destruct k as [|k]. | |
(* k = 0: the sub-table is a tip; its value is read from f when n = 0 and
 * from t otherwise. *)
constructor; constructor. | |
destruct n. | |
specialize (f 0 (le_n 0)). | |
exact (ctTip (ctTip f)). | |
eapply ctTip. | |
refine (ctOpen _ t). | |
abstract omega. | |
(* k = S k': n = 0 is ruled out by Hk; the sub-table comes from a corecursive
 * call whose two arguments are built by the lines that follow. *)
destruct n as [|n];[abstract omega|]. | |
apply construct. | |
intros m Hm. | |
apply ctOpen. abstract omega. | |
rewrite (arith10 Hm). | |
apply f. | |
exact (le_S _ _ Hm). | |
destruct (le_le_S_dec (S n) (S k)). | |
specialize (f _ (le_n (S k))). | |
rewrite (arith9 l Hk). | |
rewrite (arith8 (S k)) in f. | |
exact (ctTip f). | |
apply ctOpen. | |
assumption. | |
exact t. | |
(* Second sub-table, at index (n, S k). *)
- destruct (le_le_S_dec n k). | |
(* n <= k: the sub-table is empty. *)
constructor. | |
apply ctbEmpty. | |
assumption. | |
(* Otherwise: n = 0 is impossible; corecursive call with both f and t passed
 * through ctClose. *)
destruct n;[abstract omega|]. | |
apply construct. | |
intros m Hm. | |
apply ctClose. | |
rewrite (@arith10 n m) by (abstract omega). | |
apply f. | |
apply Hm. | |
apply ctClose. | |
apply t. | |
Defined. | |
(* convolution: builds a table of type catalan A (S n) from a family rec that
 * gives, for every m <= n, a table of type catalan (catalan A m) (n - m).
 * It is construct applied to rec and to an empty table (ctbEmpty justified by
 * le_n). *)
Definition convolution (A : Set) n (rec : forall m, m <= n -> catalan (catalan A m) (n - m)): | |
catalan A (S n). | |
apply construct. | |
(* The family argument f of construct. *)
- apply rec. | |
(* The table argument t of construct: an empty table. *)
- constructor; constructor. | |
apply le_n. | |
Defined. | |
(* zip: nests two tables. Every entry x of t1 is replaced by the table obtained
 * by mapping (f x) over t2. *)
Definition zip (A B C: Set) n m (f : A -> B -> C)
    (t1 : catalan A n) (t2 : catalan B m) : catalan (catalan C m) n :=
  ctMap (fun x => ctMap (f x) t2) t1.
(* catalans A n: a length-indexed list of tables.
 *  - cnil has index 0.
 *  - ccons extends a catalans A n with one table of type catalan A n, giving
 *    index S n. *)
Inductive catalans (A : Set) : nat -> Set := | |
| cnil : catalans A 0 | |
| ccons : forall n, catalans A n -> catalan A n -> catalans A (S n). | |
(* lookupCatalans: given m < n, returns the table of type catalan A m stored in
 * cs. Proceeds by induction on cs.
 * NOTE(review): the script uses the auto-generated names c and IHcs. *)
Definition lookupCatalans A n m (Hm : m < n) (cs : catalans A n) : catalan A m. | |
induction cs. | |
(* cnil: m < 0 is impossible. *)
abstract omega. | |
(* ccons: compare the index n of the head table with m. *)
destruct (le_lt_dec n m). | |
(* n <= m: together with Hm this gives m = n, so the head table c is the answer. *)
rewrite (@arith9 m n) by (try assumption; abstract omega). | |
exact c. | |
(* m < n: look further down the list. *)
apply IHcs; assumption. | |
Defined. | |
(* catalanStreamBody A n: a coinductive stream of tables. Its head is a table
 * of type catalan A n and its tail is the stream for index S n. *)
CoInductive catalanStreamBody (A : Set) (n : nat) : Set := | |
{ csbHead : catalan A n | |
; csbTail : catalanStreamBody A (S n) | |
}. | |
(* A full stream of tables is a stream body that starts at index 0. *)
Definition catalanStream (A : Set) : Set := catalanStreamBody A 0.
(* lookupCatalanStreamBody: returns the table found m steps down a stream that
 * starts at index n, i.e. a table of type catalan A (n + m).
 * Induction on m with n and cs generalized. *)
Definition lookupCatalanStreamBody A n m (cs : catalanStreamBody A n) : catalan A (n + m). | |
revert n cs. | |
induction m; | |
intros n cs. | |
(* m = 0: rewrite n + 0 to n and return the head. *)
rewrite arith7 . | |
exact (csbHead cs). | |
(* m = S m': rewrite n + S m' to S (n + m') and recurse on the tail. *)
rewrite arith5. | |
exact (IHm _ (csbTail cs)). | |
Defined. | |
(* Fetch the table of type catalan A n from a stream that starts at index 0,
 * by walking n steps down the stream. *)
Definition lookupCatalanStream A n (cs : catalanStream A) : catalan A n :=
  @lookupCatalanStreamBody A 0 n cs.
(* Section Memo: builds the stream of tables for a leaf value a and a binary
 * operation op on an arbitrary carrier A. *)
Section Memo. | |
Variable A : Set. | |
Variable a : A. | |
Variable op : A -> A -> A. | |
(* step: given the tables for all indices below n, produces the table for index n.
 *  - n = 0: a table built from constructors around the single value a.
 *  - n = S n': convolution of, for each m <= n', the zip under op of two tables
 *    fetched from recs with lookupCatalans. *)
Definition step n (recs : catalans A n) : catalan A n. | |
destruct n. | |
repeat constructor. | |
exact a. | |
apply convolution. | |
intros m Hm. | |
apply (zip op); refine (lookupCatalans _ recs); abstract omega. | |
Defined. | |
(* stream: the head is the table computed by step from recs; the tail continues
 * with recs extended by that same table (bound once as next). *)
CoFixpoint stream n (recs: catalans A n) : catalanStreamBody A n := | |
let next := step recs in | |
{| csbHead := next | |
; csbTail := stream (ccons recs next) | |
|}. | |
(* memo: the stream started from the empty list of tables. *)
Definition memo : catalanStream A := stream (cnil A). | |
End Memo. | |
(* decorate: converts a binTree into an ornBinTree packaged with its index.
 * Induction on t.
 * NOTE(review): the script uses the auto-generated names IHt1 and IHt2. *)
Definition decorate (t : binTree) : {n : nat & ornBinTree n}. | |
induction t. | |
(* Leaf: index 0. *)
exists 0. | |
constructor. | |
(* Branch: unpack both recursive results; the index is S (n1 + n2). *)
destruct IHt1 as [n1 ot1]. | |
destruct IHt2 as [n2 ot2]. | |
exists (S (n1 + n2)). | |
refine (OrnBranch _ ot1 _). | |
(* Side condition of OrnBranch: n1 <= n1 + n2. *)
abstract omega. | |
(* The second subtree must have index n1 + n2 - n1; arith1 rewrites it to n2. *)
rewrite (arith1 n1 n2). | |
exact ot2. | |
Defined. | |
(* unmemo: evaluates a binTree against a stream of tables. The tree is
 * decorated to obtain its index n, the table for n is fetched from the stream,
 * and the decorated tree is used to look the value up in that table. *)
Definition unmemo A (table : catalanStream A) (t : binTree) : A. | |
destruct (decorate t) as [n ot]. | |
exact (lookupCatalan (lookupCatalanStream _ table) ot). | |
Defined. | |
(* Section SmallCheck: evaluates the memoized lookup on small trees with a and
 * op left abstract (section variables).
 * The Eval commands cover the trees with 0, 1, 2, 3 and 4 branches
 * (1, 1, 2, 5 and 14 trees respectively). *)
Section SmallCheck. | |
Variable A : Set. | |
Variable a : A. | |
Variable op : A -> A -> A. | |
Let table : catalanStream A := memo a op. | |
Let check (t : binTree) : A := unmemo table t. | |
(* 0 branches. *)
Eval vm_compute in check Leaf. | |
(* 1 branch. *)
Eval vm_compute in check (Branch Leaf Leaf). | |
(* 2 branches. *)
Eval vm_compute in check (Branch Leaf (Branch Leaf Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf Leaf) Leaf). | |
(* 3 branches. *)
Eval vm_compute in check (Branch Leaf (Branch Leaf (Branch Leaf Leaf))). | |
Eval vm_compute in check (Branch Leaf (Branch (Branch Leaf Leaf) Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf Leaf) (Branch Leaf Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf (Branch Leaf Leaf)) Leaf). | |
Eval vm_compute in check (Branch (Branch (Branch Leaf Leaf) Leaf) Leaf). | |
(* 4 branches. *)
Eval vm_compute in check (Branch Leaf (Branch Leaf (Branch Leaf (Branch Leaf Leaf)))). | |
Eval vm_compute in check (Branch Leaf (Branch Leaf (Branch (Branch Leaf Leaf) Leaf))). | |
Eval vm_compute in check (Branch Leaf (Branch (Branch Leaf Leaf) (Branch Leaf Leaf))). | |
Eval vm_compute in check (Branch Leaf (Branch (Branch Leaf (Branch Leaf Leaf)) Leaf)). | |
Eval vm_compute in check (Branch Leaf (Branch (Branch (Branch Leaf Leaf) Leaf) Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf Leaf) (Branch Leaf (Branch Leaf Leaf))). | |
Eval vm_compute in check (Branch (Branch Leaf Leaf) (Branch (Branch Leaf Leaf) Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf (Branch Leaf Leaf)) (Branch Leaf Leaf)). | |
Eval vm_compute in check (Branch (Branch (Branch Leaf Leaf) Leaf) (Branch Leaf Leaf)). | |
Eval vm_compute in check (Branch (Branch Leaf (Branch Leaf (Branch Leaf Leaf))) Leaf). | |
Eval vm_compute in check (Branch (Branch Leaf (Branch (Branch Leaf Leaf) Leaf)) Leaf). | |
Eval vm_compute in check (Branch (Branch (Branch Leaf Leaf) (Branch Leaf Leaf)) Leaf). | |
Eval vm_compute in check (Branch (Branch (Branch Leaf (Branch Leaf Leaf)) Leaf) Leaf). | |
Eval vm_compute in check (Branch (Branch (Branch (Branch Leaf Leaf) Leaf) Leaf) Leaf). | |
End SmallCheck. | |
(* | |
* Example. | |
*) | |
Require Import ZArith. | |
(* Leaf value and combining operation for the integer example.
 * This operation could stand to be a bit more expensive. *)
Definition a : Z := 2147483647%Z.
Definition op (x y : Z) : Z :=
  (((x + y + 1) * (x + y) / 2 + y) mod 18446744073709551557)%Z.
(* foldOp: direct (non-memoized) fold of a tree, using a at every leaf and op
 * at every branch. *)
Fixpoint foldOp (t : binTree) : Z :=
  match t with
  | Branch lhs rhs => op (foldOp lhs) (foldOp rhs)
  | Leaf => a
  end.
(* fullTree n: the tree of depth n in which both children of every branch are
 * the same subtree of depth n - 1; depth 0 is a single Leaf.
 * The subtree is bound once and used for both children. *)
Fixpoint fullTree (n : nat) : binTree :=
  match n with
  | O => Leaf
  | S depth => let sub := fullTree depth in Branch sub sub
  end.
(* Show the decorated form of the full tree of depth 4. *)
Eval vm_compute in decorate (fullTree 4).
(* memoOp is the stream of tables for the example a and op; memoFoldOp
 * evaluates a tree by looking it up in that stream. *)
Definition memoOp : catalanStream Z := @memo Z a op.
Definition memoFoldOp (t : binTree) : Z := @unmemo Z memoOp t.
(* | |
Eval vm_compute in foldOp (Branch Leaf (Branch (Branch Leaf Leaf) Leaf)). | |
Eval vm_compute in memoFoldOp (Branch Leaf (Branch (Branch Leaf Leaf) Leaf)). | |
*) | |
(* | |
Time Eval lazy beta delta iota zeta in (memoFoldOp (fullTree 8)). | |
Time Eval lazy beta delta iota zeta in (unmemo memoOp (fullTree 8)). | |
Time Eval lazy beta delta iota zeta in (memoFoldOp (fullTree 8),memoFoldOp (fullTree 8)). | |
Time Eval lazy beta delta iota zeta in (let tbl := memoOp in (unmemo tbl (fullTree 8),unmemo tbl (fullTree 8))). | |
*) | |
(* | |
Time Eval vm_compute in (memoFoldOp (fullTree 8)). | |
*) | |
(* Timing runs of the memoized fold on full trees, alternating depth 9 and
 * depth 8; each command is timed independently. *)
Time Eval vm_compute in (memoFoldOp (fullTree 9)). | |
Time Eval vm_compute in (memoFoldOp (fullTree 8)). | |
Time Eval vm_compute in (memoFoldOp (fullTree 9)). | |
Time Eval vm_compute in (memoFoldOp (fullTree 8)). | |
(* iterOp n: starts from 2147483647 and applies op to two copies of the
 * previous result, n times. The previous result is bound once and used for
 * both arguments. *)
Fixpoint iterOp (n : nat) : Z :=
  match n with
  | S m => let prev := iterOp m in op prev prev
  | O => 2147483647%Z
  end.
(* Timing runs: iterOp at depth 9 under lazy reduction, then the memoized fold
 * on the full tree of depth 10, twice. *)
Time Eval lazy beta delta iota zeta in (iterOp 9). | |
Time Eval vm_compute in (memoFoldOp (fullTree 10)). | |
Time Eval vm_compute in (memoFoldOp (fullTree 10)). |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment