-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path: la_solve.fypp
146 lines (122 loc) · 5.02 KB
/
la_solve.fypp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#:include "common.fypp"
module la_solve
    use la_constants
    use la_blas
    use la_lapack
    use la_state_type
    use iso_fortran_env,only:real32,real64,real128,int8,int16,int32,int64,stderr => error_unit
    implicit none(type,external)
    private

    !> @brief Solve a system of linear equations A * X = B.
    !!
    !! Computes the solution to the linear system
    !!
    !! \f$ A \cdot X = B \f$
    !!
    !! where A is an `n x n` square matrix, and B is either a vector (`n`) or a matrix (`n x nrhs`).
    !! One specific procedure is generated for every rank in `ALL_RANKS` and every kind/type in
    !! `ALL_KINDS_TYPES`. The solution X is returned as an allocatable array.
    !!
    !! @param[in,out] A The input square matrix of size `n x n`. If `overwrite_a` is true,
    !!                  the contents of A may be modified during computation.
    !! @param[in] B The right-hand side vector (size `n`) or matrix (size `n x nrhs`).
    !! @param[in] overwrite_a (Optional) If true, A may be overwritten and destroyed. Default is false.
    !! @param[out] err (Optional) A state return flag. If an error occurs and `err` is not provided,
    !!                 the function will stop execution.
    !!
    !! @return Solution X of size `n` (for a single right-hand side) or `n x nrhs`.
    !!
    !! @note This function relies on the LAPACK LU decomposition based solvers `*GESV`.
    !!
    !! @warning If `overwrite_a` is enabled, the original contents of A may be lost.
    !!
    public :: solve
    interface solve
#:for nd,ndsuf,nde in ALL_RANKS
#:for rk,rt,ri in ALL_KINDS_TYPES
        module procedure la_${ri}$solve${ndsuf}$
#:endfor
#:endfor
    end interface solve

    !> Procedure name reported in every la_state raised by this module
    character(*), parameter :: this = 'solve'

    contains

    !> Translate the INFO flag returned by `*GESV` into an `la_state`.
    !!
    !! Per LAPACK, INFO = -i flags an illegal value in the i-th `*GESV` argument
    !! (1: N, 2: NRHS, 4: LDA, 7: LDB). INFO > 0 means U(i,i) is exactly zero,
    !! i.e. the matrix is singular. On success (INFO = 0) `err` is not assigned;
    !! as an intent(out) dummy it presumably takes la_state's default (success)
    !! value — NOTE(review): confirm against la_state_type.
    elemental subroutine handle_gesv_info(info,lda,n,ldb,nrhs,err)
        integer(ilp), intent(in) :: info,lda,n,ldb,nrhs
        type(la_state), intent(out) :: err

        select case (info)
            case (0)
                ! Success
            case (-1)
                err = la_state(this,LINALG_VALUE_ERROR,'invalid problem size n=',n)
            case (-2)
                err = la_state(this,LINALG_VALUE_ERROR,'invalid rhs size nrhs=',nrhs)
            case (-4)
                err = la_state(this,LINALG_VALUE_ERROR,'invalid matrix size a=',[lda,n])
            case (-7)
                ! Argument 7 of *GESV is LDB: blame the right-hand side, not A
                err = la_state(this,LINALG_VALUE_ERROR,'invalid rhs matrix size b=',[ldb,nrhs])
            case (1:)
                err = la_state(this,LINALG_ERROR,'singular matrix')
            case default
                err = la_state(this,LINALG_INTERNAL_ERROR,'catastrophic error')
        end select

    end subroutine handle_gesv_info

#:for nd,ndsuf,nde in ALL_RANKS
#:for rk,rt,ri in ALL_KINDS_TYPES
    !> Linear system solve, ${ndsuf}$, ${rt}$
    !!
    !! Solves A*X = B with one `gesv` call. A is copied first unless
    !! `overwrite_a=.true.`; B is intent(in) and X starts as a copy of B,
    !! which `gesv` overwrites in place with the solution.
    function la_${ri}$solve${ndsuf}$(a,b,overwrite_a,err) result(x)
        !> Input matrix a[n,n]
        ${rt}$, intent(inout), target :: a(:,:)
        !> Right hand side vector or array, b[n] or b[n,nrhs]
        ${rt}$, intent(in) :: b(${nd}$)
        !> [optional] Can A data be overwritten and destroyed?
        logical(lk), optional, intent(in) :: overwrite_a
        !> [optional] state return flag. On error if not requested, the code will stop
        type(la_state), optional, intent(out) :: err
        !> Result array/matrix x[n] or x[n,nrhs]
        ${rt}$, allocatable, target :: x(${nd}$)

        ! Local variables
        type(la_state) :: err0
        integer(ilp) :: lda,n,ldb,nrhs,info
        integer(ilp), allocatable :: ipiv(:)
        logical(lk) :: copy_a
        ${rt}$, pointer :: xmat(:,:),amat(:,:)

        ! Problem sizes
        lda  = size(a,1,kind=ilp)
        n    = size(a,2,kind=ilp)
        ldb  = size(b,1,kind=ilp)
        ! An empty rhs (ldb == 0) has size(b) == 0: guard the divisor so nrhs = 0
        ! instead of an integer division by zero before the size check below.
        nrhs = size(b,kind=ilp)/max(1_ilp,ldb)

        if (lda<1 .or. n<1 .or. ldb<1 .or. lda/=n .or. ldb/=n) then
            err0 = la_state(this,LINALG_VALUE_ERROR,'invalid sizes: a=[',lda,',',n,'],',&
                            'b=[',ldb,',',nrhs,']')
            allocate(x(${nde}$))
            call err0%handle(err)
            return
        end if

        ! Can A be overwritten? By default, do not overwrite
        if (present(overwrite_a)) then
            copy_a = .not.overwrite_a
        else
            copy_a = .true._lk
        end if

        ! Pivot indices
        allocate(ipiv(n))

        ! gesv destroys its matrix argument: work on a private copy unless allowed
        if (copy_a) then
            allocate(amat(lda,n),source=a)
        else
            amat => a
        end if

        ! Initialize solution with the rhs, then view it as an [n,nrhs] matrix so
        ! the same gesv call serves both the vector and the matrix rhs case
        allocate(x,source=b)
        xmat(1:n,1:nrhs) => x

        ! Solve system
        call gesv(n,nrhs,amat,lda,ipiv,xmat,ldb,info)

        ! Process output
        call handle_gesv_info(info,lda,n,ldb,nrhs,err0)

        if (copy_a) deallocate(amat)

        ! Process output and return
        call err0%handle(err)

    end function la_${ri}$solve${ndsuf}$
#:endfor
#:endfor

end module la_solve