Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some simple refactoring #19

Merged
merged 3 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions accuracy_proofs/gem_defs.v
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ From LAProof.accuracy_proofs Require Import common
float_acc_lems
list_lemmas
float_tactics.
Set Warnings "-notation-overriden, -parsing".

Set Warnings "-notation-overridden,-ambiguous-paths,-overwriting-delimiting-key".
From mathcomp Require all_ssreflect.

(* General list matrix and vector definitions *)
Section MVGenDefs.
Expand Down Expand Up @@ -65,8 +67,8 @@ end.
Definition in_matrix {T : Type} (A : list (list T)) (a : T) :=
let A' := flat_map (fun x => x) A in In a A'.

Definition matrix_index {A} (m: matrix) (i j: nat) (zero: A) : A :=
nth j (nth i m nil) zero.
Definition matrix_index {A} (zero: A) (m: matrix) (i j: nat) : A :=
List.nth j (List.nth i m nil) zero.

Definition eq_size {T1 T2}
(A : list (list T1)) (B : list (list T2)) := length A = length B /\
Expand Down Expand Up @@ -256,7 +258,7 @@ Notation "A -m B" := (mat_sumR A (map_mat Ropp B)) (at level 40).
Notation "A +m B" := (mat_sumR A B) (at level 40).

Notation "E _( i , j )" :=
(matrix_index E i j 0%R) (at level 15).
(matrix_index 0%R E i j) (at level 15).

Section MVLems.

Expand Down Expand Up @@ -384,7 +386,6 @@ set (z := (zero_vector (length (a::v)) (Zconst t 0))).
rewrite vec_sum_cons.
simpl. unfold vec_sumF in IHv.
rewrite IHv. f_equal.
Search Binary.B2R FT2R.
rewrite <-!B2R_float_of_ftype;
unfold BPLUS, BINOP.
rewrite float_of_ftype_of_float.
Expand Down Expand Up @@ -552,8 +553,8 @@ induction l0; auto.
simpl. assert False by auto; contradiction.
Qed.

Lemma matrix_index_nil {A} (i j: nat) (zero: A) :
matrix_index [] i j zero = zero.
Lemma matrix_index_nil {A} (zero: A) (i j: nat) :
matrix_index zero [] i j = zero.
Proof. unfold matrix_index. destruct i; destruct j; simpl; auto. Qed.

Lemma vec_sumR_nth :
Expand Down Expand Up @@ -617,17 +618,7 @@ Qed.

End MVLems.


From mathcomp Require Import all_ssreflect all_algebra ssrnum.
Require Import VST.floyd.functional_base.

Open Scope R_scope.
Open Scope ring_scope.

Delimit Scope ring_scope with Ri.
Delimit Scope R_scope with R.

Import Order.TTheory GRing.Theory Num.Def Num.Theory.
Import all_ssreflect.

Section SIZEDEFS.

Expand Down Expand Up @@ -933,17 +924,25 @@ Qed.
End MxLems.

Section MMLems.
Lemma nth_map':
forall {A B} (f: A -> B) (d: B) (d': A) i al,
(i < List.length al)%coq_nat ->
List.nth i (List.map f al) d = f (List.nth i al d').
Proof.
induction i; destruct al; simpl; intros; try lia; auto.
apply IHi; lia.
Qed.

Lemma nth_mul' : forall (A : list (list R)) b i j
( Hj : (j < length b)%nat),
(nth 0 (nth i A []) 0%R * nth j b 0%R =
nth j (nth i (map (fun a0 : R => map (Rmult a0) b) (map (hd 0%R) A)) []) 0%R)%R.
(List.nth 0 (List.nth i A []) 0%R * List.nth j b 0%R =
List.nth j (List.nth i (map (fun a0 : R => map (Rmult a0) b) (map (hd 0%R) A)) []) 0%R)%R.
Proof.
move => A. elim: A => [b i j H| a A IH b i j Hj] /=.
destruct i; destruct j => /=; ring.
destruct i => /= //.
rewrite hd_nth => /=.
rewrite (nth_map' (Rmult (nth 0 a 0%R)) 0%R 0%R j b) => //=.
rewrite (nth_map' (Rmult (List.nth 0 a 0%R)) 0%R 0%R j b) => //=.
apply /ssrnat.ltP => //.
specialize (IH b i j Hj). rewrite -IH => //.
Qed.
Expand Down
66 changes: 32 additions & 34 deletions accuracy_proofs/gemm_acc.v
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
Require Import vcfloat.VCFloat.
Require Import List.
Import ListNotations.
Set Warnings "-notation-overridden,-ambiguous-paths,-overwriting-delimiting-key".
From LAProof.accuracy_proofs Require Import common op_defs dotprod_model sum_model.
From LAProof.accuracy_proofs Require Import dot_acc float_acc_lems list_lemmas.
From LAProof.accuracy_proofs Require Import gem_defs mv_mathcomp gemv_acc vec_op_acc.

From LAProof.accuracy_proofs Require Import gem_defs mv_mathcomp gemv_acc(* vec_op_acc*).
From mathcomp.analysis Require Import Rstruct.
Set Warnings "-notation-overriden, -parsing".
From mathcomp Require Import all_ssreflect ssralg ssrnum.
(* From LAProof.accuracy_proofs Require Import mc_extra2. *)

From Coq Require Import ZArith Reals Psatz.
From Coq Require Import Arith.Arith.

Open Scope R_scope.
Open Scope ring_scope.

Delimit Scope ring_scope with Ri.
Delimit Scope R_scope with Re.
Set Warnings "notation-overridden,ambiguous-paths,overwriting-delimiting-key".

Require Import LAProof.accuracy_proofs.vec_op_acc.

Import Order.TTheory GRing.Theory Num.Def Num.Theory.

Expand Down Expand Up @@ -245,10 +243,10 @@ Theorem sMMC_error:
/\ (forall i j, (i < p)%nat -> (j < m)%nat ->
Rabs (eta1 _(i,j)) <= g1 n n)
/\ forall i j : nat,(i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index eta1 i j 0%Re) <= g1 n n
Rabs (matrix_index 0%Re eta1 i j) <= g1 n n
/\ forall i j : nat, (i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index E i j 0%Re) <=
g m * Rabs (matrix_index ((MMCR Ar Br +m E1) +m eta1) i j 0%Re)
Rabs (matrix_index 0%Re E i j) <=
g m * Rabs (matrix_index 0%Re ((MMCR Ar Br +m E1) +m eta1) i j)
/\ size_col E1 m p
/\ size_col eta1 m p
/\ eq_size E (MMCF A B)
Expand Down Expand Up @@ -431,25 +429,25 @@ Theorem mat_axpby_error:
mat_sumR (scaleMR (FT2R x) (Ar +m EA) +m eta1 +m ea)
(scaleMR (FT2R y) (Br +m EB) +m eta2 +m eb)
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
Rabs (matrix_index EA i j 0%R) <=
g n * Rabs (matrix_index Ar i j 0%R))
Rabs (matrix_index 0%R EA i j) <=
g n * Rabs (matrix_index 0%R Ar i j))
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
Rabs (matrix_index EB i j 0%R) <=
g n * Rabs (matrix_index Br i j 0%R))
Rabs (matrix_index 0%R EB i j) <=
g n * Rabs (matrix_index 0%R Br i j))
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
exists d,
matrix_index ea i j 0%R =
matrix_index (scaleMR (FT2R x) (Ar +m EA) +m eta1) i j 0%R * d
matrix_index 0%R ea i j =
matrix_index 0%R (scaleMR (FT2R x) (Ar +m EA) +m eta1) i j * d
/\ Rabs d <= @default_rel t)
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
exists d,
matrix_index eb i j 0%R =
matrix_index (scaleMR (FT2R y) (Br +m EB) +m eta2) i j 0%R * d
matrix_index 0%R eb i j =
matrix_index 0%R (scaleMR (FT2R y) (Br +m EB) +m eta2) i j * d
/\ Rabs d <= @default_rel t)
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
Rabs (matrix_index eta1 i j 0%Re) <= g1 n n)
Rabs (matrix_index 0%Re eta1 i j) <= g1 n n)
/\ (forall i j : nat, (i < m)%nat -> (j < n)%nat ->
Rabs (matrix_index eta2 i j 0%Re) <= g1 n n)
Rabs (matrix_index 0%Re eta2 i j) <= g1 n n)
/\ eq_size EA A
/\ eq_size EB A
/\ eq_size ea A
Expand Down Expand Up @@ -550,34 +548,34 @@ Theorem GEMM_error:
exists E0 : matrix,
List.nth k ab1 [::] = E0 *r List.nth k Br [] /\
(forall i j : nat, (i < m)%nat -> (j < n)%nat ->
Rabs (matrix_index E0 i j 0%Re) <= g n * Rabs (matrix_index Ar i j 0%Re)))
Rabs (matrix_index 0%Re E0 i j) <= g n * Rabs (matrix_index 0%Re Ar i j)))
/\ (forall i j : nat, (i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index ab2 i j 0%Re) <= g1 n n)
Rabs (matrix_index 0%Re ab2 i j) <= g1 n n)
/\ (forall i j : nat, (i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index ab3 i j 0%Re) <=
g m * Rabs (matrix_index ((MMCR Ar Br +m ab1) +m ab2) i j 0%Re))
Rabs (matrix_index 0%Re ab3 i j) <=
g m * Rabs (matrix_index 0%Re ((MMCR Ar Br +m ab1) +m ab2) i j))
/\ (forall i j : nat,
(i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index y1 i j 0%Re) <=
g m * Rabs (matrix_index Yr i j 0%Re))
Rabs (matrix_index 0%Re y1 i j) <=
g m * Rabs (matrix_index 0%Re Yr i j))
/\ (forall i j : nat, (i < p)%nat -> (j < m)%nat ->
exists d,
matrix_index ab5 i j 0%Re =
matrix_index
matrix_index 0%Re ab5 i j =
matrix_index 0%Re
(scaleMR (FT2R s1)
(((MMCR Ar Br +m ab1) +m ab2) +m ab3) +m ab4) i j 0%Re * d
(((MMCR Ar Br +m ab1) +m ab2) +m ab3) +m ab4) i j * d
/\ Rabs d <= @default_rel t)
/\ (forall i j : nat, (i < p)%nat -> (j < m)%nat ->
exists d ,
matrix_index y3 i j 0%Re =
matrix_index
(scaleMR (FT2R s2) (Yr +m y1) +m y2) i j 0%Re * d
matrix_index 0%Re y3 i j =
matrix_index 0%Re
(scaleMR (FT2R s2) (Yr +m y1) +m y2) i j * d
/\ Rabs d <= @default_rel t)
/\ (forall i j : nat,
(i < p)%nat -> (j < m)%nat ->
Rabs (matrix_index ab4 i j 0%Re) <= g1 m m)
Rabs (matrix_index 0%Re ab4 i j) <= g1 m m)
/\ (forall i0 j0 : nat, (i0 < p)%nat -> (j0 < m)%nat ->
Rabs (matrix_index y2 i0 j0 0%Re) <= g1 m m).
Rabs (matrix_index 0%Re y2 i0 j0) <= g1 m m).
Proof.
(* len hyps for composing errors *)
have Hlen1 : forall v : seq (ftype t),
Expand Down
45 changes: 18 additions & 27 deletions accuracy_proofs/gemv_acc.v
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ From LAProof.accuracy_proofs Require Import common op_defs dotprod_model sum_mod
dot_acc float_acc_lems list_lemmas
gem_defs mv_mathcomp.
From mathcomp.analysis Require Import Rstruct.
Set Warnings "-notation-overriden, -parsing".
Set Warnings "-notation-overridden,-ambiguous-paths,-overwriting-delimiting-key".
From mathcomp Require Import all_ssreflect ssralg ssrnum.
(* From LAProof.accuracy_proofs Require Import mc_extra2. *)

From Coq Require Import ZArith Reals Psatz.
From Coq Require Import Arith.Arith.

Expand All @@ -21,6 +19,8 @@ Delimit Scope R_scope with Re.
Import Order.TTheory GRing.Theory Num.Def Num.Theory.

From mathcomp.algebra_tactics Require Import ring.
Set Bullet Behavior "Strict Subproofs".


Section MixedErrorList.
(* mixed error bounds over lists *)
Expand Down Expand Up @@ -180,8 +180,6 @@ From mathcomp Require Import matrix all_algebra bigop.

Section MixedErrorMath.

Import VST.floyd.functional_base.

Open Scope R_scope.
Open Scope ring_scope.

Expand All @@ -207,16 +205,13 @@ Notation vr := (vector_to_vc (n.+1) (map FT2R v)).
Hypothesis Hfin : is_finite_vec (A *f v).
Hypothesis Hlen : forall x, In x A -> length x = n.+1.

Notation " i ' " := (Ordinal i) (at level 40).

Notation Av := (vector_to_vc (m.+1) (A *fr v)).

Lemma mat_vec_mul_mixed_error':
exists (E : 'M[R]_(m.+1,n.+1)) (eta : 'cV[R]_m.+1),
Av = (Ar + E) *m vr + eta
/\ (forall i j (Hi : (i < m.+1)%nat) (Hj : (j < n.+1)%nat),
Rabs (E (Hi ') (Hj ')) <= g n.+1 * Rabs (Ar (Hi ') (Hj ')))
/\ forall i (Hi: (i < m.+1)%nat), Rabs (eta (Hi ') 0) <= g1 n.+1 n.+1 .
/\ (forall i j, Rabs (E i j) <= g n.+1 * Rabs (Ar i j))
/\ forall i, Rabs (eta i 0) <= g1 n.+1 n.+1 .
Proof.
have Hlen' : forall x : seq.seq (ftype t), In x A -> Datatypes.length x = length v.
move => x Hin. rewrite Hlen => //. lia.
Expand All @@ -237,8 +232,8 @@ have Hin1 :
apply matrix_sum_preserves_length'.
destruct H4. intros.
rewrite map_length.
set (y := nth 0 A []).
have Hy : In y A. subst y; apply nth_In; lia.
set (y := List.nth 0 A []).
have Hy : In y A. subst y; apply List.nth_In; lia.
specialize (H0 x y H4 Hy); rewrite H0.
apply Hlen'; auto.
move => x Hx.
Expand Down Expand Up @@ -272,20 +267,17 @@ rewrite H6. apply Hlen; auto.
destruct H4.
rewrite /map_mat/mat_sumR/mat_sum/map2 !map_length combine_length
map_length; lia.
split.
{ move => i j Hi Hj.
rewrite -(matrix_to_mx_index E i j).
rewrite -(matrix_to_mx_index (map_mat FT2R A) i j).
have HA : (length A = m.+1) by (subst m; lia).
have Hv : (length v = n.+1) by (subst m; lia).
rewrite HA Hv in H2.
specialize (H2 i j Hi Hj).
subst n => /= //. }
move => i Hi.
rewrite vector_to_vc_index => /= //.
have Hv : (length v = n.+1) by (subst m; lia).
split.
{ move => [i Hi] [j Hj].
rewrite !mxE /=.
rewrite HA Hv in H2.
by apply H2. }
move => [i Hi].
rewrite /vector_to_vc mxE /=.
rewrite Hv in H3.
apply H3. apply nth_In. lia.
apply H3. apply List.nth_In. lia.
Qed.

End MixedErrorMath.
Expand All @@ -306,17 +298,16 @@ Let m := (length A - 1)%nat.
Hypothesis Hlenv1: (length v - 1)%nat = m.

Notation Ar := (matrix_to_mx m.+1 m.+1 (map_mat FT2R A)).
Notation vr := (vector_to_vc m.+1 (map FT2R v)).
Notation vr := (vector_to_vc m.+1 (List.map FT2R v)).

Hypothesis Hfin : is_finite_vec (A *f v).
Hypothesis Hlen : forall x, In x A -> length x = m.+1.

Notation " i ' " := (Ordinal i) (at level 40).

Notation Av' := (vector_to_vc m.+1 (map FT2R (mvF A v))).

Notation "| u |" := (normv u) (at level 40).


Theorem forward_error :
|Av' - (Ar *m vr)| <= (g m.+1 * normM Ar * |vr|) + g1 m.+1 m.+1.
Proof.
Expand All @@ -338,7 +329,7 @@ apply normv_pos.
rewrite /normM mulrC big_max_mul.
apply: le_bigmax2 => i0 _.
rewrite /sum_abs.
rewrite big_mul => [ | i b | ]; try ring.
rewrite big_mul => [ | i b | ]; [ | ring | ].
apply ler_sum => i _.
rewrite mulrC.
destruct i0. destruct i. apply H1.
Expand Down
Loading
Loading