Created
September 25, 2018 23:34
-
-
Save erutuf/5523067c5ff6792430c63c4a2cd29bd3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Require Import List PeanoNat Arith. | |
Import ListNotations. | |
Set Implicit Arguments. | |
(** Fusion law: folding over a mapped list is the same as folding over the
    original list with the mapping function pushed into the step function.
    Used below as a left-to-right rewrite rule to remove the intermediate
    [map]. *)
Lemma fold_left_map : forall (A B C : Type)(f : A -> B) g l (init:C),
  fold_left g (map f l) init = fold_left (fun x y => g x (f y)) l init.
Proof.
  intros A B C f g l.
  (* The accumulator is left universally quantified so that the induction
     hypothesis applies to the updated accumulator in the cons case. *)
  induction l as [| hd tl IH]; intros acc; simpl.
  - reflexivity.
  - apply IH.
Qed.
(** [sum xs] adds up the elements of [xs] with a left fold starting from 0. *)
Definition sum (xs : list nat) : nat :=
  fold_left Nat.add xs 0.

(** [sum_of_square xs] squares every element of [xs] and sums the results.
    This is the naive two-pass specification that [optimize] later fuses
    into a single fold. *)
Definition sum_of_square (xs : list nat) : nat :=
  sum (map Nat.square xs).
(* Unfolding database: the definitions [optimize] is allowed to expand
   before rewriting. *)
Hint Unfold sum : my_db.
Hint Unfold sum_of_square : my_db.

(* Rewriting database: fusion laws applied left-to-right by [autorewrite].
   Kept as a single command so the relative order of the two rules is
   exactly as originally declared. *)
Hint Rewrite fold_left_map map_map : my_rewrite_db.
(** [optimize db rdb] discharges a goal of the shape
    [{ prog | forall args, spec args = prog args }] by synthesising [prog]:

    - unfold the definitions registered in the unfold database [db];
    - introduce an existential variable for [prog] and the arguments;
    - wrap the equation as [P lhs], where [P := fun y => y = rhs], so that
      the goal handed to [autorewrite] has the form [P lhs]
      (presumably so that rewriting acts on the left-hand side only and
      leaves the side holding the existential variable alone);
    - rewrite with the rules in [rdb], then close the goal by
      [reflexivity], which instantiates [prog] with the rewritten term. *)
Ltac optimize db rdb :=
  repeat autounfold with db;
  eexists; intros;
  match goal with
  | |- ?lhs = ?rhs =>
      set (P := fun y => y = rhs);
      enough (P lhs) by auto
  end;
  autorewrite with rdb;
  reflexivity.
(** A program [prog] packaged with the proof that it agrees with
    [sum_of_square] on every input. The witness is not written by hand:
    [optimize] derives it from the unfold and rewrite databases. *)
Definition ss_optimized_sig :
  { prog : list nat -> nat | forall xs, sum_of_square xs = prog xs }.
Proof. optimize my_db my_rewrite_db. Defined.
(* Closed with [Defined] rather than [Qed] so that the witness stays
   transparent and can be extracted by [proj1_sig] below. *)

(** The derived program itself, simplified at definition time. *)
Definition ss_optimized := Eval simpl in (proj1_sig ss_optimized_sig).

Print ss_optimized.
(* Output reported for the command above:
[[[
ss_optimized =
fun xs : list nat => fold_left (fun x y : nat => x + y * y) xs 0
     : list nat -> nat
]]]
   The intermediate [map] is gone: the squaring happens inside the fold. *)
(** [fold_nat f start len init] folds [f] over the [len] consecutive
    naturals [start], [start + 1], ... in increasing order, threading an
    accumulator that starts at [init]. No list is built. Recursion is on
    [len]. *)
Fixpoint fold_nat (A : Type) (f : A -> nat -> A) (start len : nat) (init : A) : A :=
  match len with
  | 0 => init
  | S remaining => fold_nat f (S start) remaining (f init start)
  end.
(** [[n ,, m]] is the list of consecutive naturals starting at [n] with
    length [S (m - n)], i.e. [n; n+1; ...; m] when [n <= m].
    NOTE(review): subtraction on [nat] truncates at 0, so when [m < n] this
    is the singleton [[n]], not the empty list. Confirm this is intended. *)
Notation "[ n ,, m ]" := (seq n (S (m - n))).

(** Folding over the range [[st ,, ed]] equals the list-free [fold_nat]
    over the same start and length. Used as a rewrite rule to eliminate the
    intermediate [seq]. *)
Lemma fold_seq_fold_nat : forall (A : Type) st ed (f : A -> nat -> A) init,
  fold_left f [st,,ed] init = fold_nat f st (S (ed - st)) init.
Proof.
  intros A st ed f init.
  (* Forget that the length has the shape [S (ed - st)]: the statement
     holds for an arbitrary length, and that is what induction needs. *)
  remember (S (ed - st)) as len eqn:Hlen.
  clear Hlen.
  revert st init.
  induction len as [| len' IH]; intros st init; simpl.
  - reflexivity.
  - apply IH.
Qed.

Hint Rewrite fold_seq_fold_nat : my_rewrite_db.
(** [example n] sums the squares of the elements of [[1 ,, n]], written
    naively as build-list / map / fold.
    NOTE(review): [[1 ,, n]] is [seq 1 (S (n - 1))], so for [n = 0] the
    list is [[1]] rather than empty. *)
Definition example (n : nat) : nat :=
  sum (map Nat.square [1,,n]).

Hint Unfold example : my_db.

(** A program equal to [example] on every input, derived by [optimize]
    from the unfold and rewrite databases. *)
Definition example_optimized_sig :
  { prog : nat -> nat | forall n, example n = prog n }.
Proof. optimize my_db my_rewrite_db. Defined.
(* [Defined] keeps the witness transparent for [proj1_sig] below. *)

(** The derived program itself, simplified at definition time. *)
Definition example_optimized := Eval simpl in (proj1_sig example_optimized_sig).

Print example_optimized.
(* Output reported for the command above:
[[[
example_optimized =
fun n : nat => fold_nat (fun x y : nat => x + Nat.square y) 2 (n - 1) (Nat.square 1)
     : nat -> nat
]]]
   Neither the [seq] list nor the mapped list appears any more. *)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment