Created
August 11, 2025 08:48
-
-
Save mukeshtiwari/eb88e19f10d3cdb8bbbc9f7f5d2b7516 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Section Listmat.

  Context {Node R : Type}
    (dec_node : forall (c d : Node), {c = d} + {c <> d}).

  (* [cross_product la lb] lists every pair [(a, b)] with [a] taken from
     [la] and [b] taken from [lb]; pairs are grouped by [a] (in the order
     of [la]) and, within a group, follow the order of [lb]. *)
  Fixpoint cross_product (la lb : list Node) : list (Node * Node) :=
    match la with
    | [] => []
    | a :: rest =>
        List.map (fun b => (a, b)) lb ++ cross_product rest lb
    end.

  (* [fun_to_list l m] tabulates [m] on all pairs of [l]: one entry
     [((i, j), m i j)] for each pair produced by [cross_product l l]. *)
  Definition fun_to_list (l : list Node) (m : Node -> Node -> R)
    : list ((Node * Node) * R) :=
    List.map (fun '(i, j) => ((i, j), m i j)) (cross_product l l).

  (* [list_to_fun c d m l] returns the value of the first entry of [l]
     whose key equals [(c, d)] (keys compared with [dec_node]); if no
     entry matches, it falls back to [m c d]. *)
  Fixpoint list_to_fun (c d : Node) (m : Node -> Node -> R)
    (l : list ((Node * Node) * R)) : R :=
    match l with
    | [] => m c d
    | ((x, y), v) :: tl =>
        match dec_node c x, dec_node d y with
        | left _, left _ => v
        | _, _ => list_to_fun c d m tl
        end
    end.

  (* Reading a tabulated function back is the identity, for any
     enumeration [l]: a hit returns the stored [m c d], and a miss falls
     back to [m c d] itself. *)
  Theorem equiv_list_fun :
    forall (l : list Node) (m : Node -> Node -> R) (c d : Node),
      list_to_fun c d m (fun_to_list l m) = m c d.
  Proof.
    intros l m c d. unfold fun_to_list.
    induction (cross_product l l) as [| [x y] tl IH]; cbn.
    - reflexivity.
    - destruct (dec_node c x) as [Hx | Hx];
        destruct (dec_node d y) as [Hy | Hy].
      + subst. reflexivity.
      + exact IH.
      + exact IH.
      + exact IH.
  Qed.

End Listmat.
Section Effdef_opt.

  Variables
    (Node : Type)
    (dec_node : forall (c d : Node), {c = d} + {c <> d})
    (finN : list Node).

  Variables
    (R : Type)
    (zeroR oneR : R) (* 0 and 1 *)
    (plusR mulR : R -> R -> R)
    (eqR : R -> R -> Prop).

  Local Notation "0" := zeroR.
  Local Notation "1" := oneR.
  Local Infix "+" := plusR.
  Local Infix "*" := mulR.
  Local Infix "=r=" := eqR (at level 70).

  (* Tabulate [m] over all pairs of [finN].  The operations below bind
     this table in a [let] outside the closure they return, so each
     operand is tabulated once per operation instead of once per read. *)
  Definition entries_of (m : Node -> Node -> R) : list ((Node * Node) * R) :=
    fun_to_list finN m.

  (* Look up [(c, d)] in the table [l], falling back to [m c d] when the
     key is absent.  This is exactly [list_to_fun] from [Listmat];
     delegating to it, rather than duplicating the same scan, lets
     [equiv_list_fun] apply directly (see [lookup_from_entries_sound]). *)
  Definition lookup_from_entries (c d : Node) (m : Node -> Node -> R)
    (l : list ((Node * Node) * R)) : R :=
    list_to_fun dec_node c d m l.

  (* Reading an entry back from the table agrees with the matrix itself,
     so the table-based operations below compute the same values as the
     direct definitions would. *)
  Lemma lookup_from_entries_sound :
    forall (m : Node -> Node -> R) (c d : Node),
      lookup_from_entries c d m (entries_of m) = m c d.
  Proof.
    intros m c d. unfold lookup_from_entries, entries_of.
    apply equiv_list_fun.
  Qed.

  (* All-zero matrix; it needs no table. *)
  Definition zero_matrix_eff : Node -> Node -> R := fun _ _ => 0.

  (* Identity matrix: 1 where [dec_node c d] decides [c = d], else 0. *)
  Definition I_eff : Node -> Node -> R := fun c d =>
    match dec_node c d with
    | left _ => 1
    | _ => 0
    end.

  (* Pointwise sum; both operand tables are built before the closure
     [fun c d => ...] is returned. *)
  Definition matrix_add_eff (m1 m2 : Node -> Node -> R) : Node -> Node -> R :=
    let e1 := entries_of m1 in
    let e2 := entries_of m2 in
    fun c d => lookup_from_entries c d m1 e1 + lookup_from_entries c d m2 e2.

  (* [sum_fn_eff f l] right-folds [fun x acc => f x + acc] over [l],
     starting from 0. *)
  Definition sum_fn_eff (f : Node -> R) (l : list Node) : R :=
    List.fold_right (fun x y => f x + y) 0 l.

  (* Generalised product: entry (c, d) sums [m1 c y * m2 y d] over the
     [y] in [nodes].  Operand tables are built once per product. *)
  Definition matrix_mul_gen_eff (m1 m2 : Node -> Node -> R) (nodes : list Node)
    : Node -> Node -> R :=
    let e1 := entries_of m1 in
    let e2 := entries_of m2 in
    fun c d =>
      sum_fn_eff (fun y =>
        let a := lookup_from_entries c y m1 e1 in
        let b := lookup_from_entries y d m2 e2 in
        a * b) nodes.

  (* Product with the summation index ranging over [finN]. *)
  Definition matrix_mul_eff (m1 m2 : Node -> Node -> R) :=
    matrix_mul_gen_eff m1 m2 finN.

  (* Naive exponentiation: [n] successive left multiplications,
     i.e. [m * (m * ( ... * I_eff))].  See [matrix_exp_binary_eff] for
     the logarithmic version. *)
  Fixpoint matrix_exp_unary_eff (m : Node -> Node -> R) (n : nat) : Node -> Node -> R :=
    match n with
    | O => I_eff
    | S n' => matrix_mul_eff m (matrix_exp_unary_eff m n')
    end.

  (* Exponentiation by squaring over the binary digits of [n]:
     e^1 = e, e^(2p) = (e^p)^2 and e^(2p+1) = e * (e^p)^2. *)
  Fixpoint repeat_op_ntimes_rec_eff (e : Node -> Node -> R) (n : positive) : Node -> Node -> R :=
    match n with
    | xH => e
    | xO p => let ret := repeat_op_ntimes_rec_eff e p in matrix_mul_eff ret ret
    | xI p =>
        let reta := repeat_op_ntimes_rec_eff e p in
        let retb := matrix_mul_eff reta reta in
        matrix_mul_eff e retb
    end.

  (* [e] raised to [n : N]; [N0] gives the identity matrix. *)
  Definition matrix_exp_binary_eff (e : Node -> Node -> R) (n : N) :=
    match n with
    | N0 => I_eff
    | Npos p => repeat_op_ntimes_rec_eff e p
    end.

End Effdef_opt.
(* Demo instance: nodes are [nat] (compared with [Nat.eq_dec]), entries
   are [nat] with [Nat.max] as the addition and [Nat.min] as the
   multiplication, and every entry of the input matrix [mat] is 20.
   [ext x] is binary exponentiation of [mat] over the node list [x]. *)
Definition mat : nat -> nat -> nat := fun _ _ => 20.

Definition ext (x : list nat) :=
  matrix_exp_binary_eff nat Nat.eq_dec x nat 0 1 Nat.max Nat.min mat.

(* Print the Haskell translation of [ext] and everything it depends on. *)
From Stdlib Require Import Extraction ExtrHaskellBasic
  ExtrHaskellZInteger ExtrHaskellNatInteger.
Extraction Language Haskell.
Recursive Extraction ext.
Section Effdef.

  Variables
    (Node : Type)
    (dec_node : forall (c d : Node), {c = d} + {c <> d})
    (finN : list Node).

  (* carrier set and the operators *)
  Variables
    (R : Type)
    (zeroR oneR : R) (* 0 and 1 *)
    (plusR mulR : binary_op R)
    (eqR : brel R).

  Declare Scope Mat_scope.
  Delimit Scope Mat_scope with R.
  Bind Scope Mat_scope with R.
  Local Open Scope Mat_scope.
  Local Notation "0" := zeroR : Mat_scope.
  Local Notation "1" := oneR : Mat_scope.
  Local Infix "+" := plusR : Mat_scope.
  Local Infix "*" := mulR : Mat_scope.
  Local Infix "=r=" := eqR (at level 70) : Mat_scope.

  (* Every operation below reads its operands through [list_to_fun] on a
     table [fun_to_list finN m].  The table is bound by a [let] OUTSIDE
     the returned closure.  Building it inside the closure would
     re-enumerate the whole operand on every entry read, and in
     [matrix_mul_gen_eff] on every summand as well.  Both forms are
     convertible (zeta), so the matrices computed are unchanged. *)

  (* Returns row [c] of [m]: the function [d => m c d], read from the table. *)
  Definition get_row_eff (m : Matrix Node R) (c : Node) : Node -> R :=
    let tbl := fun_to_list finN m in
    fun d => list_to_fun dec_node c d m tbl.

  (* Returns column [c] of [m]: the function [d => m d c], read from the table. *)
  Definition get_col_eff (m : Matrix Node R) (c : Node) : Node -> R :=
    let tbl := fun_to_list finN m in
    fun d => list_to_fun dec_node d c m tbl.

  (* zero matrix, additive identity of plus *)
  Definition zero_matrix_eff : Matrix Node R :=
    fun _ _ => 0.

  (* Identity matrix, multiplicative identity of mul: 1 where
     [dec_node c d] decides [c = d], 0 elsewhere. *)
  Definition I_eff : Matrix Node R :=
    fun (c d : Node) =>
      match dec_node c d with
      | left _ => 1
      | _ => 0
      end.

  (* Transpose: entry (c, d) of the result is entry (d, c) of [m]. *)
  Definition transpose_eff (m : Matrix Node R) : Matrix Node R :=
    let tbl := fun_to_list finN m in
    fun (c d : Node) => list_to_fun dec_node d c m tbl.

  (* Pointwise addition of two matrices; each operand is tabulated once. *)
  Definition matrix_add_eff (m₁ m₂ : Matrix Node R) : Matrix Node R :=
    let e1 := fun_to_list finN m₁ in
    let e2 := fun_to_list finN m₂ in
    fun c d => list_to_fun dec_node c d m₁ e1 +
               list_to_fun dec_node c d m₂ e2.

  (* [sum_fn_eff f l] right-folds [fun x acc => f x + acc] over [l],
     starting from 0. *)
  Definition sum_fn_eff (f : Node -> R) (l : list Node) : R :=
    List.fold_right (fun x y => f x + y) 0 l.

  (* Generalised matrix multiplication: entry (c, d) sums
     [m₁ c y * m₂ y d] over the [y] in [l].  Each operand is tabulated
     once per product, not once per summand. *)
  Definition matrix_mul_gen_eff (m₁ m₂ : Matrix Node R)
    (l : list Node) : Matrix Node R :=
    let e1 := fun_to_list finN m₁ in
    let e2 := fun_to_list finN m₂ in
    fun (c d : Node) =>
      sum_fn_eff (fun y =>
        list_to_fun dec_node c y m₁ e1 *
        list_to_fun dec_node y d m₂ e2) l.

  (* Specialised form of general multiplication: sum over [finN]. *)
  Definition matrix_mul_eff (m₁ m₂ : Matrix Node R) :=
    matrix_mul_gen_eff m₁ m₂ finN.

  (* Naive exponentiation: [n] successive left multiplications by [m]. *)
  Fixpoint matrix_exp_unary_eff (m : Matrix Node R) (n : nat) : Matrix Node R :=
    match n with
    | 0%nat => I_eff
    | S n' => matrix_mul_eff m (matrix_exp_unary_eff m n')
    end.

  (* Exponentiation by squaring over the binary digits of [n]:
     e^1 = e, e^(2p) = (e^p)^2 and e^(2p+1) = e * (e^p)^2. *)
  Fixpoint repeat_op_ntimes_rec_eff (e : Matrix Node R) (n : positive) : Matrix Node R :=
    match n with
    | xH => e
    | xO p => let ret := repeat_op_ntimes_rec_eff e p in matrix_mul_eff ret ret
    | xI p =>
        let reta := repeat_op_ntimes_rec_eff e p in
        let retb := matrix_mul_eff reta reta in
        matrix_mul_eff e retb
    end.

  (* [e] raised to [n : N]; [N0] gives the identity matrix. *)
  Definition matrix_exp_binary_eff (e : Matrix Node R) (n : N) :=
    match n with
    | N0 => I_eff
    | Npos p => repeat_op_ntimes_rec_eff e p
    end.

End Effdef.
(* Demo instance: nodes are [nat] (compared with [Nat.eq_dec]), entries
   are [nat] with [Nat.max] as the addition and [Nat.min] as the
   multiplication, and every entry of the input matrix [mat] is 20.
   [ext x] is binary exponentiation of [mat] over the node list [x]. *)
Definition mat : nat -> nat -> nat := fun _ _ => 20.

Definition ext (x : list nat) :=
  matrix_exp_binary_eff nat Nat.eq_dec x nat 0 1 Nat.max Nat.min mat.

(* Print the Haskell translation of [ext] and everything it depends on. *)
From Stdlib Require Import Extraction ExtrHaskellBasic
  ExtrHaskellZInteger ExtrHaskellNatInteger.
Extraction Language Haskell.
Recursive Extraction ext.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment