(* Created October 30, 2024 14:17

   Verified implementation of an efficient (O(n)) fulcrum algorithm: given a
   sequence seq, it returns an index i that minimizes
   |sum(seq[..i]) - sum(seq[i..])|. *)
Require Import Coq.Lists.List.
Require Import Coq.ZArith.ZArith.
Require Import Lia.
Import ListNotations.
(** Mapping a function over a non-empty list yields a non-empty list. *)
Theorem map_nonempty {A B} {f : A -> B} {xs} : xs <> [] -> map f xs <> [].
Proof.
  intros nonempty collapsed.
  (* It suffices to show that [xs] itself would have to be empty. *)
  apply nonempty.
  destruct xs as [| x rest].
  - reflexivity.
  - (* A mapped cons cell is a cons cell, so it cannot equal nil. *)
    simpl in collapsed. discriminate collapsed.
Qed.
(** [sum xs] adds up the integers of [xs], starting from 0 and folding from
    the right. *)
Definition sum (xs : list Z) : Z := fold_right Z.add 0%Z xs.
(** [scan_left f z xs] lists every intermediate accumulator of a left fold of
    [f] over [xs] seeded with [z].  The seed comes first, so the result has
    one more element than [xs].  It is expressed as a [fold_right] that builds
    a function awaiting the current accumulator. *)
Definition scan_left {A B} (f : B -> A -> B) (z : B) (xs : list A) : list B
  := fold_right
       (fun item rest state => state :: rest (f state item))
       (fun state => [state])
       xs z.
(** Unfolding equation for [scan_left] on a cons cell: emit the current
    accumulator, then continue with the accumulator updated by the head. *)
Theorem scan_left_step {A B} (f : B -> A -> B) z x xs
  : scan_left f z (x :: xs) = z :: scan_left f (f z x) xs.
Proof.
  (* Both sides are convertible once [scan_left] is unfolded. *)
  unfold scan_left. simpl. reflexivity.
Qed.
(** If [g] commutes with every step of [f], then applying [g] to the seed of a
    scan is the same as mapping [g] over the whole scan. *)
Theorem scan_left_assoc {A B} g (f : B -> A -> B) z xs
  : (forall acc x, f (g acc) x = g (f acc x))
    -> scan_left f (g z) xs = map g (scan_left f z xs).
Proof.
  intros e. revert z.
  (* The seed changes at every step, so the induction hypothesis must be
     general in [z]; the empty case is closed by [auto]. *)
  induction xs as [ | x xs IH]; auto.
  simpl. intros z. f_equal.
  (* Push [g] outwards through one step of [f], then use the hypothesis. *)
  rewrite e. apply IH.
Qed.
(** A scan has exactly one more element than its input list. *)
Theorem scan_left_length {A B} (f : B -> A -> B) z xs
  : length (scan_left f z xs) = S (length xs).
Proof.
  (* Generalize the seed before the induction, since it changes at each step. *)
  revert z. induction xs; simpl; auto.
Qed.
(** [fulcrum_candidates xs] is the scan that starts at the negated total of
    [xs] and adds twice each element in turn.  [fulcrum_candidates_nth] shows
    that its n-th entry is the sum of the first [n] elements minus the sum of
    the remaining ones. *)
Definition fulcrum_candidates (xs : list Z) : list Z
  := scan_left (fun balance item => balance + 2 * item)%Z (Z.opp (sum xs)) xs.
(** For every split position [n] within bounds, the n-th candidate is the sum
    of the first [n] elements minus the sum of the remaining elements. *)
Theorem fulcrum_candidates_nth xs n
  : (n <= length xs)
    -> nth_error (fulcrum_candidates xs) n
       = Some (sum (firstn n xs) - sum (skipn n xs))%Z.
Proof.
  revert n. induction xs as [ | x xs IH].
  - (* Empty list: only position 0 is within bounds. *)
    destruct n; auto.
    simpl. intros ?. contradiction (Nat.nle_succ_0 n).
  - (* Position 0 holds by computation; otherwise peel off the head. *)
    destruct n as [ | n]; auto. intros p.
    unfold fulcrum_candidates.
    simpl sum.
    rewrite scan_left_step.
    (* Rewrite the new seed as [x + (- sum xs)] so that the addition of [x]
       can be pulled out of the scan as a [map]. *)
    replace (-(x + sum xs) + 2 * x)%Z with (x + -sum xs)%Z.
    2:{ lia. }
    rewrite (scan_left_assoc (fun z => x + z)%Z).
    2:{ lia. }
    fold (fulcrum_candidates xs).
    rewrite <- Z.add_sub_assoc.
    (* Step the lookup past the head so the goal has the exact shape of
       [map_nth_error]. *)
    simpl nth_error.
    apply map_nth_error.
    apply IH.
    now apply le_S_n.
Qed.
(** The candidate list is never empty: a scan always contains its seed. *)
Theorem fulcrum_candidates_nonempty {xs : list Z}
  : fulcrum_candidates xs <> [].
Proof.
  unfold fulcrum_candidates.
  (* In both cases the scan starts with a cons cell. *)
  destruct xs as [ | x rest]; simpl; discriminate.
Qed.
(** There is one candidate per split position, i.e. one more than the number
    of elements of [xs]. *)
Definition fulcrum_candidates_length (xs : list Z)
  : length (fulcrum_candidates xs) = S (length xs).
Proof.
  unfold fulcrum_candidates.
  now rewrite scan_left_length.
Qed.
(** [minimum_by le xs p] folds over the non-empty list [xs] from the left,
    keeping the current candidate whenever [le candidate x] is [true] and
    switching to [x] otherwise.  The proof [p] rules out the empty list. *)
Definition minimum_by {A} (le : A -> A -> bool) (xs : list A) (p : xs <> []) : A
  := match xs, p with
     (* Matching on [p] alongside [xs] refines its type to [[] <> []] here,
        which is absurd. *)
     | [], _ => ltac:(contradiction)
     | x :: xs, _ => fold_left (fun acc x => if le acc x then acc else x) xs x
     end.
(** The element selected by [minimum_by] is a member of the input list. *)
Theorem minimum_by_In {A} (le : A -> A -> bool) xs p
  : In (minimum_by le xs p) xs.
Proof.
  destruct xs as [ | x0 xs].
  - (* The empty list contradicts [p]. *)
    contradiction.
  - simpl minimum_by.
    (* The running candidate [x0] changes along the fold, so generalize it. *)
    clear p. revert x0.
    induction xs as [ | x xs IH].
    + simpl; auto.
    + intros x0. simpl fold_left.
      destruct (le x0 x).
      * (* The candidate stays [x0]. *)
        specialize (IH x0).
        destruct IH as [<- | IH].
        ** now left.
        ** right. now right.
      * (* The candidate becomes [x]. *)
        specialize (IH x).
        now right.
Qed.
(** When [le] is connex and transitive, the element selected by [minimum_by]
    is below every member of the list with respect to [le]. *)
Theorem minimum_by_optimal {A} (le : A -> A -> bool) xs p a
  (connex : forall x y, le x y = true \/ le y x = true)
  (trans : forall x y z, le x y = true -> le y z = true -> le x z = true)
  : In a xs -> le (minimum_by le xs p) a = true.
Proof.
  (* Connexity instantiated at [x, x] gives reflexivity. *)
  assert (refl : forall x, le x x = true).
  { intros x.
    now destruct (connex x x). }
  destruct xs as [ | x0 xs].
  - contradiction.
  - simpl. clear p.
    (* Generalize both the target and the running candidate. *)
    revert a x0. induction xs as [ | x xs IH].
    + now intros a x0 [-> | []].
    + intros a x0 q. simpl.
      destruct (le x0 x) eqn:e.
      * (* The candidate stays [x0]; only the case [a = x] needs
           transitivity through [x0]. *)
        destruct q as [ | [-> | ]]; auto.
        apply (trans _ x0); auto.
      * (* The candidate becomes [x]; only the case [a = x0] is left, where
           connexity and [e] give [le x a = true]. *)
        destruct q as [-> | ]; auto.
        apply (trans _ x); auto.
        destruct (connex a x); auto.
        rewrite e in *. discriminate.
Qed.
(** [enumerate_from n xs] pairs each element of [xs] with a counter that
    starts at [n] and increases by one per element. *)
Fixpoint enumerate_from {A} (n : nat) (xs : list A) : list (nat * A) :=
  match xs with
  | [] => []
  | x :: xs => (n, x) :: enumerate_from (S n) xs
  end.
(** Enumerating a list does not change its length. *)
Theorem enumerate_from_length {A} (xs : list A) n
  : length (enumerate_from n xs) = length xs.
Proof.
  (* The starting counter changes in the recursive call, so generalize it. *)
  revert n. induction xs; simpl; auto.
Qed.
(** Looking up position [k] in [enumerate_from n xs] yields the k-th element
    of [xs] (if any) tagged with the index [k + n]. *)
Theorem enumerate_from_nth_error {A} n (xs : list A) k
  : nth_error (enumerate_from n xs) k = option_map (pair (k + n)) (nth_error xs k).
Proof.
  revert n xs. induction k as [| k IH]; destruct xs as [ | x xs]; auto.
  (* Remaining case: successor position in a cons cell.  Simplify to expose
     [S (k + n)], then rewrite it to [k + S n] so that the induction
     hypothesis applies with starting counter [S n]. *)
  simpl.
  now rewrite <- Nat.add_succ_r.
Qed.
(** [enumerate xs] pairs each element of [xs] with its zero-based position. *)
Definition enumerate {A} (xs : list A) : list (nat * A)
  := @enumerate_from A 0 xs.
(** [enumerate] preserves the length of its argument. *)
Theorem enumerate_length {A} (xs : list A) : length (enumerate xs) = length xs.
Proof.
  unfold enumerate.
  now rewrite enumerate_from_length.
Qed.
(** Looking up position [n] in [enumerate xs] yields the n-th element of [xs]
    (if any) tagged with [n] itself. *)
Theorem enumerate_nth_error {A} (xs : list A) n
  : nth_error (enumerate xs) n = option_map (pair n) (nth_error xs n).
Proof.
  unfold enumerate.
  rewrite (enumerate_from_nth_error 0).
  (* The general lemma tags with [n + 0]; normalize that to [n]. *)
  now rewrite Nat.add_0_r.
Qed.
(** A pair occurs in [enumerate xs] exactly when its second component is the
    element of [xs] at the position given by its first component. *)
Theorem enumerate_In_nth_error {A} (xs : list A) p
  : In p (enumerate xs) <-> nth_error xs (fst p) = Some (snd p).
Proof.
  destruct p as [i x]. simpl.
  split.
  - intros q.
    (* Membership gives some position [j] holding the pair. *)
    destruct (In_nth_error _ _ q) as [j e].
    rewrite enumerate_nth_error in *.
    (* The tag stored at position [j] is [j] itself, so [i = j]. *)
    assert (i = j) as <-.
    { destruct nth_error; simpl in *.
      - now injection e.
      - discriminate. }
    destruct nth_error; simpl in *.
    + injection e. now intros ->.
    + discriminate.
  - intros q.
    apply (nth_error_In _ i).
    now rewrite enumerate_nth_error, q.
Qed.
(** Enumerating a non-empty list yields a non-empty list. *)
Theorem enumerate_nonempty {A} {xs : list A} (p : xs <> []) : enumerate xs <> [].
Proof.
  now destruct xs.
Qed.
(** [min_index_by le xs p] is the position paired with the element that
    [minimum_by] selects from [enumerate xs], where pairs are compared by
    applying [le] to their second components only.  The proof [p] states that
    [xs] is non-empty. *)
Definition min_index_by {A} (le : A -> A -> bool) (xs : list A) (p : xs <> [])
  : nat
  := fst (minimum_by
            (fun '(_, a) '(_, b) => le a b)
            (enumerate xs)
            (enumerate_nonempty p)).
(** The index returned by [min_index_by] is a valid position in [xs]. *)
Theorem min_index_by_valid {A} (le : A -> A -> bool) xs p
  : min_index_by le xs p < length xs.
Proof.
  unfold min_index_by.
  set (m := minimum_by _ _ _).
  (* Restate the bound as: the lookup in [enumerate xs] at [fst m] succeeds. *)
  rewrite <- enumerate_length.
  apply nth_error_Some.
  rewrite enumerate_nth_error.
  rewrite (proj1 (enumerate_In_nth_error _ _)).
  - simpl. discriminate.
  - (* Side condition of the rewrite: [m] is a member of [enumerate xs]. *)
    apply minimum_by_In.
Qed.
(** When [le] is connex and transitive, the element stored at the index
    returned by [min_index_by] is below every member [y] of [xs]. *)
Theorem min_index_by_optimal {A} (le : A -> A -> bool) xs p y
  (connex : forall x y, le x y = true \/ le y x = true)
  (trans : forall x y z, le x y = true -> le y z = true -> le x z = true)
  : In y xs -> exists x, nth_error xs (min_index_by le xs p) = Some x
                         /\ le x y = true.
Proof.
  unfold min_index_by. intros q.
  set (le' := fun '(_, a) '(_, b) => le a b).
  set (m := minimum_by _ _ _).
  (* The witness is the value component of the selected pair. *)
  exists (snd m).
  split.
  - apply enumerate_In_nth_error.
    apply minimum_by_In.
  - destruct (In_nth_error _ _ q) as [j r].
    (* Compare the selected pair with the pair [(j, y)] in the enumeration. *)
    enough (le' m (j, y) = true).
    { destruct m. now simpl in *. }
    apply minimum_by_optimal.
    + (* [le'] is connex because [le] is. *)
      intros [] []. simpl. auto.
    + (* [le'] is transitive because [le] is. *)
      intros [] [] []. simpl. eauto.
    + now apply enumerate_In_nth_error.
Qed.
(** [fulcrum xs] is the index, chosen by [min_index_by] with [Z.leb], of a
    smallest entry among the absolute values of [fulcrum_candidates xs]. *)
Definition fulcrum (xs : list Z) : nat
  := min_index_by
       Z.leb
       (map Z.abs (fulcrum_candidates xs))
       (@map_nonempty Z Z Z.abs (fulcrum_candidates xs)
          (@fulcrum_candidates_nonempty xs)).
(** [fulcrum_metric xs n] is the absolute difference between the sum of the
    first [n] elements of [xs] and the sum of the remaining elements. *)
Definition fulcrum_metric (xs : list Z) (n : nat) : Z
  := Z.abs (Z.sub (sum (firstn n xs)) (sum (skipn n xs))).
(** Correctness of [fulcrum]: the returned index is a valid split position,
    and no valid split position has a smaller [fulcrum_metric]. *)
Theorem fulcrum_valid (xs : list Z)
  : fulcrum xs <= length xs /\ forall j,
      j <= length xs -> (fulcrum_metric xs (fulcrum xs) <= fulcrum_metric xs j)%Z.
Proof.
  unfold fulcrum.
  set (cand := map Z.abs _).
  set (p := _ : cand <> []).
  (* Every entry of [cand] is the metric of its own position. *)
  assert (forall j v, nth_error cand j = Some v -> fulcrum_metric xs j = v) as e.
  { intros j v q.
    assert (nth_error cand j <> None) as r.
    { intros ?. rewrite q in *. discriminate. }
    unfold cand, fulcrum_metric in *.
    erewrite map_nth_error in q.
    2:{ (* Side condition: the underlying candidate at position [j]. *)
        apply fulcrum_candidates_nth.
        apply nth_error_Some in r.
        rewrite map_length in r.
        rewrite fulcrum_candidates_length in r.
        now apply Nat.lt_succ_r. }
    now inversion q. }
  split.
  - (* Validity: the index is below the number of candidates, which is
       [S (length xs)]. *)
    apply Nat.lt_succ_r.
    rewrite <- fulcrum_candidates_length.
    erewrite <- map_length.
    apply min_index_by_valid.
  - (* Optimality against an arbitrary valid position [j]. *)
    intros j q.
    assert (j < length cand) as r.
    { unfold cand.
      rewrite map_length, fulcrum_candidates_length.
      now apply <- Nat.lt_succ_r. }
    assert (exists v, nth_error cand j = Some v) as [v].
    { apply nth_error_Some in r.
      destruct nth_error; eauto.
      now destruct r. }
    destruct (min_index_by_optimal Z.leb cand p v) as [w []].
    + (* [Z.leb] is connex. *)
      intros x y.
      rewrite Z.leb_compare, Z.leb_compare, Z.compare_antisym.
      destruct (_ ?= _)%Z; auto.
    + (* [Z.leb] is transitive. *)
      apply Zle_bool_trans.
    + eapply nth_error_In. eauto.
    + (* Translate both metrics into entries of [cand] and compare them. *)
      erewrite e, e; eauto.
      now apply Z.leb_le.
Qed.
