-
Notifications
You must be signed in to change notification settings - Fork 1
/
LinAlg.hs
266 lines (200 loc) · 8.36 KB
/
LinAlg.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE DataKinds, KindSignatures, PolyKinds, TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{- | This module provides an interface for specifying immutable
linear algebra computations.
This is done with the 'Matr' typeclass.
-}
module Numeric.LinAlg where
import Data.List (transpose, intercalate)
import Data.Type.Equality ((:~:))
import GHC.TypeLits (Nat)
import qualified Numeric.LinAlg.Vect as V
import Numeric.LinAlg.Vect (Vect)
import Numeric.LinAlg.SNat
-- | Shape of an array, promoted to the type level (via @DataKinds@) so that
-- 'Matr' arrays can carry their size in their type.
data Dim
  = V Nat      -- ^ A vector with the given number of elements.
  | M Nat Nat  -- ^ A matrix with the given numbers of rows and columns.
-- | Infix, overloaded versions of the linear-system solvers of 'Matr'.
--
-- Every operator takes a square coefficient matrix on the left and a
-- right-hand side of shape @dim@ on the right, and returns a solution of
-- the same shape @dim@. The functional dependency @dim -> n@ means the size
-- of the coefficient matrix is determined by the shape of the right-hand
-- side.
class Solve (n :: Nat) (dim :: Dim) | dim -> n where
  -- | Infix form of 'linearSolve' (general square system).
  (<\>) :: Matr k arr => arr (M n n) k -> arr dim k -> arr dim k
  infixl 7 <\>

  -- | Infix form of 'utriSolve' (upper-triangular system).
  (^\) :: Matr k arr => arr (M n n) k -> arr dim k -> arr dim k
  infixl 7 ^\

  -- | Infix form of 'ltriSolve' (lower-triangular system).
  (.\) :: Matr k arr => arr (M n n) k -> arr dim k -> arr dim k
  infixl 7 .\

  -- | Infix form of 'posdefSolve' (positive-definite symmetric system).
  (\\) :: Matr k arr => arr (M n n) k -> arr dim k -> arr dim k
  infixl 7 \\
-- | Matrix right-hand side: each operator is exactly the corresponding
-- 'Matr' solver.
instance Solve n (M n p) where
  a <\> b = linearSolve a b
  a ^\  b = utriSolve a b
  a .\  b = ltriSolve a b
  a \\  b = posdefSolve a b
-- | Vector right-hand side: the vector is treated as a single-column
-- matrix, solved with the corresponding 'Matr' solver, and the
-- single-column result is converted back to a vector.
instance Solve n (V n) where
  a <\> v = asColVec $ linearSolve a (asColMat v)
  a ^\  v = asColVec $ utriSolve a (asColMat v)
  a .\  v = asColVec $ ltriSolve a (asColMat v)
  a \\  v = asColVec $ posdefSolve a (asColMat v)
-- | Overloaded product of two arrays with compatible dimensions. The
-- dimension of the result is given by the associated type family 'Prod'.
class Product (a :: Dim) (b :: Dim) where
  -- | Dimension of the product of an @a@-shaped and a @b@-shaped array.
  type Prod a b :: Dim

  -- | Multiply two arrays.
  (><) :: Matr k arr
       => arr a k
       -> arr b k
       -> arr (Prod a b) k
  infixl 7 ><
-- | Matrix times matrix: an @m@-by-@n@ matrix times an @n@-by-@p@ matrix
-- gives an @m@-by-@p@ matrix, computed by 'mXm'.
instance Product (M m n) (M n p) where
  type Prod (M m n) (M n p) = M m p
  (><) = mXm
-- | Vector times matrix: the length-@n@ vector acts as a 1-by-@n@ row
-- matrix, so the product with an @n@-by-@p@ matrix is a length-@p@ vector.
instance Product (V n) (M n p) where
  type Prod (V n) (M n p) = V p
  vec >< mat = asColVec (trans rowResult)
    where
      -- 1-by-n row matrix built from the vector.
      rowVec = trans (asColMat vec)
      -- 1-by-p row result; transposed above to a p-by-1 column.
      rowResult = mXm rowVec mat
-- | Matrix times vector: the vector is treated as an @n@-by-1 column
-- matrix, and the @m@-by-1 product is converted back to a vector.
instance Product (M m n) (V n) where
  type Prod (M m n) (V n) = V m
  mat >< vec = asColVec $ mXm mat (asColMat vec)
-- | An instance of @'Matr' k arr@ means that @arr@ is an array type whose
-- elements belong to the field @k@: @arr ('V' n) k@ is a vector of length
-- @n@ and @arr ('M' m n) k@ is a matrix with @m@ rows and @n@ columns.
-- Because sizes appear in the types, operations such as 'mXm' can only be
-- applied to arguments with compatible dimensions.
class (Floating k, Scale k arr)
    => Matr (k :: *) (arr :: Dim -> * -> *) where
  infixl 7 >.<

  --
  -- Data transfer
  --

  -- | Convert a length-indexed 'Vect' of elements to a vector.
  fromVect :: Vect n k -> arr (V n) k
  -- | Convert a vector to a length-indexed 'Vect' of its elements.
  toVect :: arr (V n) k -> Vect n k
  -- | Convert a row-major 'Vect' of rows, each a 'Vect' of elements, to a
  -- matrix containing those elements. The types ensure every row has the
  -- same length @n@.
  fromVects :: Vect m (Vect n k) -> arr (M m n) k
  -- | Convert a matrix to a 'Vect' of its rows, each given as a 'Vect' of
  -- elements.
  toVects :: arr (M m n) k -> Vect m (Vect n k)
  -- | Convert a matrix to a 'Vect' of its rows.
  toRows :: arr (M m n) k -> Vect m (arr (V n) k)
  -- | Convert a matrix to a 'Vect' of its columns.
  -- Default: the rows of the transposed matrix.
  toColumns :: arr (M m n) k -> Vect n (arr (V m) k)
  toColumns = toRows . trans
  -- | Convert a 'Vect' of vectors to a matrix having those vectors as rows.
  fromRows :: Vect m (arr (V n) k) -> arr (M m n) k
  -- | Convert a 'Vect' of vectors to a matrix having those vectors as
  -- columns.
  fromColumns :: Vect n (arr (V m) k) -> arr (M m n) k
  -- | Regard a vector as a matrix with a single column.
  asColMat :: arr (V n) k -> arr (M n 1) k
  -- | Convert a single-column matrix (the @1@ in its type) to a vector.
  -- Default: the first (and only) column.
  asColVec :: arr (M n 1) k -> arr (V n) k
  asColVec = V.head . toColumns
  -- | Produce a diagonal matrix with the given vector along its diagonal
  -- (and zeros elsewhere).
  fromDiag :: arr (V n) k -> arr (M n n) k
  -- | Return a vector of the elements along the diagonal of a square
  -- matrix.
  takeDiag :: arr (M n n) k -> arr (V n) k

  --
  -- Core operations
  --

  -- | Dimensions of a matrix (rows, columns).
  dim :: arr (M m n) k -> (SNat m, SNat n)
  -- | The number of rows in a matrix. Default: first component of 'dim'.
  rows :: arr (M m n) k -> SNat m
  rows = fst . dim
  -- | The number of columns in a matrix. Default: second component of
  -- 'dim'.
  cols :: arr (M m n) k -> SNat n
  cols = snd . dim
  -- | The length of a vector.
  len :: arr (V n) k -> SNat n
  -- | Transpose a matrix.
  trans :: arr (M m n) k -> arr (M n m) k
  -- | Construct the identity matrix of a given dimension.
  ident :: SNat n -> arr (M n n) k
  -- | Compute the outer product of two vectors.
  outer :: arr (V m) k -> arr (V n) k -> arr (M m n) k
  -- | Compute the dot product (i.e., inner product) of two vectors.
  (>.<) :: arr (V n) k -> arr (V n) k -> k
  -- | Multiply an @m@-by-@n@ matrix by an @n@-by-@p@ matrix.
  mXm :: arr (M m n) k -> arr (M n p) k -> arr (M m p) k
  -- | Compute the elementwise product (i.e., Hadamard product) of
  -- two matrices.
  elementwiseprod :: arr (M m n) k -> arr (M m n) k -> arr (M m n) k

  --
  -- Solving linear systems
  --

  -- | General matrix inverse.
  -- Default: solve against the identity matrix with 'linearSolve'.
  inv :: arr (M n n) k -> arr (M n n) k
  inv m = linearSolve m (ident (rows m))
  -- | Inverse of a lower-triangular matrix. Default: 'inv', which treats
  -- the argument as a general matrix.
  invL :: arr (M n n) k -> arr (M n n) k
  invL = inv
  -- | Inverse of an upper-triangular matrix. Default: 'inv', which treats
  -- the argument as a general matrix.
  invU :: arr (M n n) k -> arr (M n n) k
  invU = inv

  --
  -- Functions related to Cholesky decomposition
  --

  -- | Cholesky decomposition of a positive-definite symmetric matrix.
  -- Returns the lower-triangular factor @l@ of the decomposition. May not
  -- necessarily zero out the upper portion of the matrix.
  chol :: arr (M n n) k -> arr (M n n) k
  -- | Invert a positive-definite symmetric system using a precomputed
  -- Cholesky decomposition. That is, if @ l == 'chol' a @ and
  -- @ b == 'cholInv' l @, then @ b @ is the inverse of @ a @.
  --
  -- Default: with @il = 'invL' l@, computes @'trans' il '><' il@, because
  -- @a = l * l^T@ implies @a^-1 = (l^-1)^T * l^-1@.
  cholInv :: arr (M n n) k -> arr (M n n) k
  cholInv l = let il = invL l in trans il >< il
  -- | Compute the log-determinant of a positive-definite symmetric matrix
  -- using its precomputed Cholesky decomposition. That is, if
  -- @ l == 'chol' a @ and @ d == 'cholLnDet' l @, then @ d @ is the
  -- log-determinant of @ a @.
  --
  -- Default: twice the sum of the logarithms of the diagonal entries of
  -- @l@. Only the diagonal is read, so entries above it do not matter.
  cholLnDet :: arr (M n n) k -> k
  cholLnDet = (2*) . sum . map log . V.toList . toVect . takeDiag

  --
  -- Other functions
  --

  -- | If matrices @ a @ and @ b @ are symmetric, then @ trsymprod a b @ is
  -- the trace of @ a '><' b @. Note that in this case, the trace of
  -- @ a '><' b @ is just the sum of the elements in the element-wise
  -- product of @ a @ and @ b @.
  --
  -- Default: exactly that sum, computed with 'elementwiseprod'. The result
  -- equals the trace only for symmetric arguments.
  trsymprod :: arr (M n n) k -> arr (M n n) k -> k
  trsymprod a b =
    sum (concatMap V.toList (V.toList (toVects (elementwiseprod a b))))
  -- | Create a vector of a given length whose elements all have the same
  -- value.
  constant :: k -> SNat n -> arr (V n) k

  --
  -- Solving linear equations
  --

  -- | Solve a general system. If there is some @x@ such that
  -- @ m '><' x == b @, then @ x == 'linearSolve' m b @.
  linearSolve :: arr (M n n) k -> arr (M n p) k -> arr (M n p) k
  -- | Solve a lower-triangular system. If @l@ is a lower-triangular
  -- matrix, then @ x == 'ltriSolve' l b @ means that @ l '><' x == b @.
  -- Default: 'linearSolve', i.e. the argument is treated as a general
  -- (not necessarily triangular) matrix.
  ltriSolve :: arr (M n n) k -> arr (M n p) k -> arr (M n p) k
  ltriSolve = linearSolve
  -- | Solve an upper-triangular system. If @u@ is an upper-triangular
  -- matrix, then @ x == 'utriSolve' u b @ means that @ u '><' x == b @.
  -- Default: 'linearSolve', i.e. the argument is treated as a general
  -- (not necessarily triangular) matrix.
  utriSolve :: arr (M n n) k -> arr (M n p) k -> arr (M n p) k
  utriSolve = linearSolve
  -- | Solve a positive-definite symmetric system. That is, if @ a @ is a
  -- positive-definite symmetric matrix and @ x == 'posdefSolve' a b @, then
  -- @ a '><' x == b @. Default: 'linearSolve'.
  posdefSolve :: arr (M n n) k -> arr (M n p) k -> arr (M n p) k
  posdefSolve = linearSolve
  -- | Solve a positive-definite symmetric system using a precomputed
  -- Cholesky decomposition. If @ l == 'chol' a @ and
  -- @ x == l \`'cholSolve'\` b @, then @ a '><' x == b @.
  --
  -- Default: since @a = l * l^T@, first solve @l * y = b@ with 'ltriSolve',
  -- then solve @l^T * x = y@ with 'utriSolve'.
  --
  -- NOTE(review): this default was previously flagged as "may be broken /
  -- not working well". The algebra is correct, but 'chol' is allowed to
  -- leave non-zero entries above the diagonal, and the default 'ltriSolve'
  -- and 'utriSolve' treat their argument as a general matrix. Unless an
  -- instance's 'chol' zeroes the upper triangle, or its triangular solvers
  -- ignore the unused triangle, those stray entries enter the solve. This
  -- is the likely cause; confirm against the concrete instances.
  cholSolve :: arr (M n n) k -> arr (M n p) k -> arr (M n p) k
  l `cholSolve` b = trans l ^\ (l .\ b)
-- | Check whether a matrix is square. If it is, return 'Just' a proof that
-- its row count @m@ equals its column count @n@. Pattern-matching on that
-- proof ('Refl') lets the matrix be used at type @'M' n n@. Otherwise,
-- return 'Nothing'. The comparison itself is done by 'cmp' on the two
-- dimensions reported by 'dim'.
square :: Matr k arr => arr (M m n) k -> Maybe (m :~: n)
square = uncurry cmp . dim
-- | Render a matrix as text, one line per row. Within a row, elements are
-- rendered with 'show' and separated by tabs. Rows are separated by
-- newlines, with no trailing newline. (Actual prettiness not guaranteed!)
showMat :: (Show k, Matr k arr) => arr (M m n) k -> String
showMat mat =
  intercalate "\n" [intercalate "\t" (map show row) | row <- toLists mat]
-- | Convert a matrix to an ordinary list of its rows, each row being a
-- list of its elements (built from 'toVects').
toLists :: Matr k arr => arr (M m n) k -> [[k]]
toLists mat = [V.toList row | row <- V.toList (toVects mat)]
-- | Scalar multiplication for vector spaces: scaling an array (vector or
-- matrix) with elements of type @k@ by a scalar of type @k@. Every 'Matr'
-- instance also requires 'Scale'.
class Scale k arr where
  -- | Scale an array by a scalar. Left-associative at precedence 7, the
  -- same as ordinary multiplication.
  (.*) :: k -> arr dim k -> arr dim k
  infixl 7 .*