linalg_test.py · 2270 lines (2042 loc) · 84.1 KB · forked from jax-ml/jax
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the LAPAX linear algebra module."""
from functools import partial
import itertools
from typing import Iterator
from unittest import skipIf
import numpy as np
import scipy
import scipy.linalg
import scipy as osp
from absl.testing import absltest, parameterized
import jax
from jax import jit, grad, jvp, vmap
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src import config
from jax._src import deprecations
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl()
scipy_version = jtu.parse_version(scipy.version.version)
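# Batched matrix transpose: swaps the last two axes (no conjugation).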
T = lambda x: np.swapaxes(x, -1, -2)
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex
int_types = jtu.dtypes.all_integer
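# Best-effort check of the CUDA version reported by the XLA backend; the helper
# below returns False on ROCm or when the platform version string is unknown.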
def _is_required_cuda_version_satisfied(cuda_version):
version = xla_bridge.get_backend().platform_version
if version == "<unknown>" or "rocm" in version.split():
return False
else:
return int(version.split()[-1]) >= cuda_version
def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]:
"""
Generate a range of valid axis arguments for a reduction over
an array with a given number of dimensions.
"""
yield from (None, ())
if ndim > 0:
yield from (0, (-1,))
if ndim > 1:
yield from (1, (0, 1), (-1, 0))
if ndim > 2:
yield (-1, 0, 1)
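# For example:
#   >>> list(_axis_for_ndim(2))
#   [None, (), 0, (-1,), 1, (0, 1), (-1, 0)]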
def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
"""scipy.linalg.toeplitz with v1.17+ batching semantics."""
if scipy_version >= (1, 17, 0):
return scipy.linalg.toeplitz(c, r)
elif r is None:
c = np.atleast_1d(c)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c)
else:
c = np.atleast_1d(c)
r = np.atleast_1d(r)
return np.vectorize(
scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r)
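# In the fallback path, batching follows the gufunc signatures above: e.g. c of
# shape (b, m) and r of shape (b, n) yield a (b, m, n) result, matching the
# SciPy >= 1.17 batching behavior this shim emulates.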
class NumpyLinalgTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=[(1, 1), (4, 4), (2, 5, 5), (200, 200), (1000, 0, 0)],
dtype=float_types + complex_types,
upper=[True, False]
)
def testCholesky(self, shape, dtype, upper):
rng = jtu.rand_default(self.rng())
def args_maker():
factor_shape = shape[:-1] + (2 * shape[-1],)
a = rng(factor_shape, dtype)
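# a @ conj(a).T with a of shape (..., n, 2n) is Hermitian and (almost surely)
# positive definite, so the Cholesky factorization is well defined.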
return [np.matmul(a, jnp.conj(T(a)))]
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
def np_fun(x, upper=upper):
# Upper argument added in NumPy 2.0.0
if jtu.numpy_version() >= (2, 0, 0):
return np.linalg.cholesky(x, upper=upper)
result = np.linalg.cholesky(x)
if upper:
axes = list(range(x.ndim))
axes[-1], axes[-2] = axes[-2], axes[-1]
return np.transpose(result, axes).conj()
return result
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=1e-3)
self._CompileAndCheck(jnp_fun, args_maker)
if jnp.finfo(dtype).bits == 64:
jtu.check_grads(jnp.linalg.cholesky, args_maker(), order=2)
def testCholeskyGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), np.float32)
a = np.dot(a, a.T)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.cholesky), (a,), (a,))
@jtu.sample_product(
n=[0, 2, 3, 4, 5, 25], # TODO(mattjj): complex64 unstable on large sizes?
dtype=float_types + complex_types,
)
def testDet(self, n, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng((n, n), dtype)]
self._CheckAgainstNumpy(np.linalg.det, jnp.linalg.det, args_maker, tol=1e-3)
self._CompileAndCheck(jnp.linalg.det, args_maker,
rtol={np.float64: 1e-13, np.complex128: 1e-13})
def testDetOfSingularMatrix(self):
x = jnp.array([[-1., 3./2], [2./3, -1.]], dtype=np.float32)
self.assertAllClose(np.float32(0), jsp.linalg.det(x))
@jtu.sample_product(
shape=[(1, 1), (2, 2), (3, 3), (2, 2, 2), (2, 3, 3), (2, 4, 4), (5, 7, 7)],
dtype=float_types,
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jtu.skip_on_devices("tpu")
def testDetGrad(self, shape, dtype):
rng = jtu.rand_default(self.rng())
a = rng(shape, dtype)
jtu.check_grads(jnp.linalg.det, (a,), 2, atol=1e-1, rtol=1e-1)
# make sure there are no NaNs when a matrix is zero
if len(shape) == 2:
jtu.check_grads(
jnp.linalg.det, (jnp.zeros_like(a),), 1, atol=1e-1, rtol=1e-1)
else:
a[0] = 0
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
def testDetGradIssue6121(self):
f = lambda x: jnp.linalg.det(x).sum()
x = jnp.ones((16, 1, 1))
jax.grad(f)(x)
jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1)
def testDetGradOfSingularMatrixCorank1(self):
# Rank 2 matrix with nonzero gradient
a = jnp.array([[ 50, -30, 45],
[-30, 90, -81],
[ 45, -81, 81]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)
# TODO(phawkins): Test sometimes produces NaNs on TPU.
@jtu.skip_on_devices("tpu")
def testDetGradOfSingularMatrixCorank2(self):
# Rank 1 matrix with zero gradient
b = jnp.array([[ 36, -42, 18],
[-42, 49, -21],
[ 18, -21, 9]], dtype=jnp.float32)
jtu.check_grads(jnp.linalg.det, (b,), 1, atol=1e-1, rtol=1e-1, eps=1e-1)
@jtu.sample_product(
m=[1, 5, 7, 23],
nq=zip([2, 4, 6, 36], [(1, 2), (2, 2), (1, 2, 3), (3, 3, 1, 4)]),
dtype=float_types,
)
def testTensorsolve(self, m, nq, dtype):
rng = jtu.rand_default(self.rng())
# According to the NumPy docs, the shapes are as follows:
# the coefficient tensor (a) has shape b.shape + Q,
# and prod(Q) == prod(b.shape).
# Therefore, n = prod(q).
n, q = nq
b_shape = (n, m)
# To satisfy prod(Q) == prod(b.shape), we append the extra dimension m
# to the shape of Q.
Q = q + (m,)
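# For instance, with m=7 and nq=(2, (1, 2)): b has shape (2, 7), Q == (1, 2, 7),
# and a has shape (2, 7, 1, 2, 7).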
args_maker = lambda: [
rng(b_shape + Q, dtype), # = a
rng(b_shape, dtype)] # = b
a, b = args_maker()
result = jnp.linalg.tensorsolve(*args_maker())
self.assertEqual(result.shape, Q)
self._CheckAgainstNumpy(np.linalg.tensorsolve,
jnp.linalg.tensorsolve, args_maker,
tol={np.float32: 1e-2, np.float64: 1e-3})
self._CompileAndCheck(jnp.linalg.tensorsolve,
args_maker,
rtol={np.float64: 1e-13})
def testTensorsolveAxes(self):
a_shape = (2, 1, 3, 6)
b_shape = (1, 6)
axes = (0, 2)
dtype = "float32"
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(a_shape, dtype), rng(b_shape, dtype)]
np_fun = partial(np.linalg.tensorsolve, axes=axes)
jnp_fun = partial(jnp.linalg.tensorsolve, axes=axes)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
[dict(dtype=dtype, method=method)
for dtype in float_types + complex_types
for method in (["lu"] if jnp.issubdtype(dtype, jnp.complexfloating)
else ["lu", "qr"])
],
shape=[(0, 0), (1, 1), (3, 3), (4, 4), (10, 10), (200, 200), (2, 2, 2),
(2, 3, 3), (3, 2, 2)],
)
def testSlogdet(self, shape, dtype, method):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
slogdet = partial(jnp.linalg.slogdet, method=method)
self._CheckAgainstNumpy(np.linalg.slogdet, slogdet, args_maker,
tol=1e-3)
self._CompileAndCheck(slogdet, args_maker)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (5, 5), (2, 7, 7)],
dtype=float_types + complex_types,
)
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def testSlogdetGrad(self, shape, dtype):
rng = jtu.rand_default(self.rng())
a = rng(shape, dtype)
jtu.check_grads(jnp.linalg.slogdet, (a,), 2, atol=1e-1, rtol=2e-1)
def testIssue1213(self):
for n in range(5):
mat = jnp.array([np.diag(np.ones([5], dtype=np.float32))*(-.01)] * 2)
args_maker = lambda: [mat]
self._CheckAgainstNumpy(np.linalg.slogdet, jnp.linalg.slogdet, args_maker,
tol=1e-3)
@jtu.sample_product(
shape=[(0, 0), (4, 4), (5, 5), (50, 50), (2, 6, 6)],
dtype=float_types + complex_types,
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
@jtu.run_on_devices("cpu", "gpu")
def testEig(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
# Norm, adjusted for dimension and type.
def norm(x):
norm = np.linalg.norm(x, axis=(-2, -1))
return norm / ((n + 1) * jnp.finfo(dtype).eps)
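# i.e. residuals are measured in units of (n + 1) * eps, so the bound of 100
# used below allows a backward error of 100 * (n + 1) * eps.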
def check_right_eigenvectors(a, w, vr):
self.assertTrue(
np.all(norm(np.matmul(a, vr) - w[..., None, :] * vr) < 100))
def check_left_eigenvectors(a, w, vl):
rank = len(a.shape)
aH = jnp.conj(a.transpose(list(range(rank - 2)) + [rank - 1, rank - 2]))
wC = jnp.conj(w)
check_right_eigenvectors(aH, wC, vl)
a, = args_maker()
results = lax.linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
w = results[0]
if compute_left_eigenvectors:
check_left_eigenvectors(a, w, results[1])
if compute_right_eigenvectors:
check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors])
self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3)
@jtu.sample_product(
shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)],
dtype=float_types + complex_types,
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
@jtu.run_on_devices("cpu", "gpu")
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
a = jnp.full(shape, jnp.nan, dtype)
results = lax.linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
for result in results:
self.assertTrue(np.all(np.isnan(result)))
@jtu.sample_product(
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
dtype=float_types + complex_types,
)
@jtu.run_on_devices("cpu", "gpu")
def testEigvalsGrad(self, shape, dtype):
# This test sometimes fails for large matrices. I (@j-towns) suspect, but
# haven't checked, that this might be because perturbations cause the
# ordering of the eigenvalues to change, which trips up check_grads. So we
# just test on small-ish matrices.
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
tol = 1e-4 if dtype in (np.float64, np.complex128) else 1e-1
jtu.check_grads(lambda x: jnp.linalg.eigvals(x), (a,), order=1,
modes=['fwd', 'rev'], rtol=tol, atol=tol)
@jtu.sample_product(
shape=[(4, 4), (5, 5), (50, 50)],
dtype=float_types + complex_types,
)
@jtu.run_on_devices("cpu", "gpu")
def testEigvals(self, shape, dtype):
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
w1, _ = jnp.linalg.eig(a)
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14})
@jtu.run_on_devices("cpu", "gpu")
def testEigvalsInf(self):
# https://github.com/jax-ml/jax/issues/2661
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
x = jnp.array([[jnp.inf]])
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
@jtu.sample_product(
shape=[(1, 1), (4, 4), (5, 5)],
dtype=float_types + complex_types,
)
@jtu.run_on_devices("cpu", "gpu")
def testEigBatching(self, shape, dtype):
if jtu.test_device_matches(["gpu"]) and jtu.jaxlib_version() <= (0, 4, 35):
self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
shape = (10,) + shape
args = rng(shape, dtype)
ws, vs = vmap(jnp.linalg.eig)(args)
self.assertTrue(np.all(np.linalg.norm(
np.matmul(args, vs) - ws[..., None, :] * vs) < 1e-3))
@jtu.sample_product(
n=[0, 4, 5, 50, 512],
dtype=float_types + complex_types,
lower=[True, False],
)
def testEigh(self, n, dtype, lower):
rng = jtu.rand_default(self.rng())
eps = np.finfo(dtype).eps
args_maker = lambda: [rng((n, n), dtype)]
uplo = "L" if lower else "U"
a, = args_maker()
a = (a + np.conj(a.T)) / 2
w, v = jnp.linalg.eigh(np.tril(a) if lower else np.triu(a),
UPLO=uplo, symmetrize_input=False)
w = w.astype(v.dtype)
tol = 2 * n * eps
self.assertAllClose(
np.eye(n, dtype=v.dtype),
np.matmul(np.conj(T(v)), v),
atol=tol,
rtol=tol,
)
with jax.numpy_rank_promotion('allow'):
tol = 100 * eps
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a)
)
self._CompileAndCheck(
partial(jnp.linalg.eigh, UPLO=uplo), args_maker, rtol=eps
)
# Compare eigenvalues against NumPy using double precision. We do not compare
# eigenvectors because they are not uniquely defined, but the two checks above
# guarantee that they satisfy the conditions for being eigenvectors.
double_type = dtype
if dtype == np.float32:
double_type = np.float64
if dtype == np.complex64:
double_type = np.complex128
w_np = np.linalg.eigvalsh(a.astype(double_type))
tol = 8 * eps
self.assertAllClose(
w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol
)
@jtu.sample_product(
start=[0, 1, 63, 64, 65, 255],
end=[1, 63, 64, 65, 256],
)
@jtu.run_on_devices("tpu")  # TODO(rmlarsen): enable on other devices
def testEighSubsetByIndex(self, start, end):
if start >= end:
return
dtype = np.float32
n = 256
rng = jtu.rand_default(self.rng())
eps = np.finfo(dtype).eps
args_maker = lambda: [rng((n, n), dtype)]
subset_by_index = (start, end)
k = end - start
(a,) = args_maker()
a = (a + np.conj(a.T)) / 2
v, w = lax.linalg.eigh(
a, symmetrize_input=False, subset_by_index=subset_by_index
)
w = w.astype(v.dtype)
self.assertEqual(v.shape, (n, k))
self.assertEqual(w.shape, (k,))
with jax.numpy_rank_promotion("allow"):
tol = 200 * eps
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v), tol * np.linalg.norm(a)
)
tol = 3 * n * eps
self.assertAllClose(
np.eye(k, dtype=v.dtype),
np.matmul(np.conj(T(v)), v),
atol=tol,
rtol=tol,
)
self._CompileAndCheck(partial(jnp.linalg.eigh), args_maker, rtol=eps)
# Compare eigenvalues against NumPy. We do not compare eigenvectors because
# they are not uniquely defined, but the two checks above guarantee that
# they satisfy the conditions for being eigenvectors.
double_type = dtype
if dtype == np.float32:
double_type = np.float64
if dtype == np.complex64:
double_type = np.complex128
w_np = np.linalg.eigvalsh(a.astype(double_type))[
subset_by_index[0] : subset_by_index[1]
]
tol = 20 * eps
self.assertAllClose(
w_np.astype(w.dtype), w, atol=tol * np.linalg.norm(a), rtol=tol
)
def testEighZeroDiagonal(self):
a = np.array([[0., -1., -1., 1.],
[-1., 0., 1., -1.],
[-1., 1., 0., -1.],
[1., -1., -1., 0.]], dtype=np.float32)
w, v = jnp.linalg.eigh(a)
w = w.astype(v.dtype)
eps = jnp.finfo(a.dtype).eps
with jax.numpy_rank_promotion('allow'):
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v), 2 * eps * np.linalg.norm(a)
)
def testEighTinyNorm(self):
rng = jtu.rand_default(self.rng())
a = rng((300, 300), dtype=np.float32)
eps = jnp.finfo(a.dtype).eps
a = eps * (a + np.conj(a.T))
w, v = jnp.linalg.eigh(a)
w = w.astype(v.dtype)
with jax.numpy_rank_promotion("allow"):
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v), 80 * eps * np.linalg.norm(a)
)
@jtu.sample_product(
rank=[1, 3, 299],
)
def testEighRankDeficient(self, rank):
rng = jtu.rand_default(self.rng())
eps = jnp.finfo(np.float32).eps
a = rng((300, rank), dtype=np.float32)
a = a @ np.conj(a.T)
w, v = jnp.linalg.eigh(a)
w = w.astype(v.dtype)
with jax.numpy_rank_promotion("allow"):
self.assertLessEqual(
np.linalg.norm(np.matmul(a, v) - w * v),
85 * eps * np.linalg.norm(a),
)
@jtu.sample_product(
n=[0, 4, 5, 50, 512],
dtype=float_types + complex_types,
lower=[True, False],
)
def testEighIdentity(self, n, dtype, lower):
tol = np.finfo(dtype).eps
uplo = "L" if lower else "U"
a = jnp.eye(n, dtype=dtype)
w, v = jnp.linalg.eigh(a, UPLO=uplo, symmetrize_input=False)
w = w.astype(v.dtype)
self.assertLessEqual(
np.linalg.norm(np.eye(n) - np.matmul(np.conj(T(v)), v)), tol
)
with jax.numpy_rank_promotion('allow'):
self.assertLessEqual(np.linalg.norm(np.matmul(a, v) - w * v),
tol * np.linalg.norm(a))
@jtu.sample_product(
shape=[(4, 4), (5, 5), (50, 50)],
dtype=float_types + complex_types,
)
def testEigvalsh(self, shape, dtype):
rng = jtu.rand_default(self.rng())
n = shape[-1]
def args_maker():
a = rng((n, n), dtype)
a = (a + np.conj(a.T)) / 2
return [a]
self._CheckAgainstNumpy(
np.linalg.eigvalsh, jnp.linalg.eigvalsh, args_maker, tol=2e-5
)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (5, 5), (50, 50), (2, 10, 10)],
dtype=float_types + complex_types,
lower=[True, False],
)
def testEighGrad(self, shape, dtype, lower):
rng = jtu.rand_default(self.rng())
self.skipTest("Test fails with numeric errors.")
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + np.conj(T(a))) / 2
ones = np.ones((a.shape[-1], a.shape[-1]), dtype=dtype)
a *= np.tril(ones) if lower else np.triu(ones)
# Gradient checks will fail without symmetrization as the eigh jvp rule
# is only correct for tangents in the symmetric subspace, whereas the
# checker checks against unconstrained (co)tangents.
if dtype not in complex_types:
f = partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)
else: # only check eigenvalue grads for complex matrices
f = lambda a: partial(jnp.linalg.eigh, UPLO=uplo, symmetrize_input=True)(a)[0]
jtu.check_grads(f, (a,), 2, rtol=1e-5)
@jtu.sample_product(
shape=[(1, 1), (4, 4), (5, 5), (50, 50)],
dtype=complex_types,
lower=[True, False],
eps=[1e-5],
)
def testEighGradVectorComplex(self, shape, dtype, lower, eps):
rng = jtu.rand_default(self.rng())
# Special case to test for complex eigenvector grad correctness.
# Exact eigenvector coordinate gradients are hard to test numerically for complex
# eigensystem solvers given the extra degrees of per-eigenvector phase freedom.
# Instead, we numerically verify the eigensystem properties on the perturbed
# eigenvectors. You only ever want to optimize eigenvector directions, not coordinates!
uplo = "L" if lower else "U"
a = rng(shape, dtype)
a = (a + np.conj(a.T)) / 2
a = np.tril(a) if lower else np.triu(a)
a_dot = eps * rng(shape, dtype)
a_dot = (a_dot + np.conj(a_dot.T)) / 2
a_dot = np.tril(a_dot) if lower else np.triu(a_dot)
# Evaluate the eigenvector gradient and the ground-truth eigensystem for the
# perturbed input matrix.
f = partial(jnp.linalg.eigh, UPLO=uplo)
(w, v), (dw, dv) = jvp(f, primals=(a,), tangents=(a_dot,))
self.assertTrue(jnp.issubdtype(w.dtype, jnp.floating))
self.assertTrue(jnp.issubdtype(dw.dtype, jnp.floating))
new_a = a + a_dot
new_w, new_v = f(new_a)
new_a = (new_a + np.conj(new_a.T)) / 2
new_w = new_w.astype(new_a.dtype)
# Assert that the relative eigenvalue error of the perturbed eigenvectors,
# measured against the new true eigenvalues, is within RTOL.
RTOL = 1e-2
with jax.numpy_rank_promotion('allow'):
assert np.max(
np.abs((np.diag(np.dot(np.conj((v+dv).T), np.dot(new_a,(v+dv)))) - new_w) / new_w)) < RTOL
# Largely redundant with the above, but also assert the relative error of the
# eigenvector property against the new true eigenvalues.
assert np.max(
np.linalg.norm(np.abs(new_w*(v+dv) - np.dot(new_a, (v+dv))), axis=0) /
np.linalg.norm(np.abs(new_w*(v+dv)), axis=0)
) < RTOL
def testEighGradPrecision(self):
rng = jtu.rand_default(self.rng())
a = rng((3, 3), np.float32)
jtu.assert_dot_precision(
lax.Precision.HIGHEST, partial(jvp, jnp.linalg.eigh), (a,), (a,))
@jtu.sample_product(
shape=[(1, 1), (4, 4), (5, 5), (300, 300)],
dtype=float_types + complex_types,
)
def testEighBatching(self, shape, dtype):
rng = jtu.rand_default(self.rng())
shape = (10,) + shape
args = rng(shape, dtype)
args = (args + np.conj(T(args))) / 2
ws, vs = vmap(jsp.linalg.eigh)(args)
ws = ws.astype(vs.dtype)
norm = np.max(np.linalg.norm(np.matmul(args, vs) - ws[..., None, :] * vs))
self.assertLess(norm, 1.4e-2)
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
def testLuPivotsToPermutation(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
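# Applying pivots [2p-1, 2p-2, ..., p] to the identity permutation of size 2p
# swaps entry i with entry 2p-1-i, i.e. it reverses the permutation; that
# reversal is the expected result below.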
actual = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
expected = jnp.broadcast_to(expected, actual.shape)
self.assertArraysEqual(actual, expected)
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
def testLuPivotsToPermutationBatching(self, shape, dtype):
shape = (10,) + shape
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
batched_fn = vmap(
lambda x: lax.linalg.lu_pivots_to_permutation(x, permutation_size))
actual = batched_fn(pivots)
expected = jnp.arange(permutation_size - 1, -1, -1, dtype=dtype)
expected = jnp.broadcast_to(expected, actual.shape)
self.assertArraysEqual(actual, expected)
@jtu.sample_product(
[dict(axis=axis, shape=shape, ord=ord)
for axis, shape in [
(None, (1,)), (None, (7,)), (None, (5, 8)),
(0, (9,)), (0, (4, 5)), ((1,), (10, 7, 3)), ((-2,), (4, 8)),
(-1, (6, 3)), ((0, 2), (3, 4, 5)), ((2, 0), (7, 8, 9)),
(None, (7, 8, 11))]
for ord in (
[None] if axis is None and len(shape) > 2
else [None, 0, 1, 2, 3, -1, -2, -3, jnp.inf, -jnp.inf]
if (axis is None and len(shape) == 1) or
isinstance(axis, int) or
(isinstance(axis, tuple) and len(axis) == 1)
else [None, 'fro', 1, 2, -1, -2, jnp.inf, -jnp.inf, 'nuc'])
],
keepdims=[False, True],
dtype=float_types + complex_types,
)
def testNorm(self, shape, dtype, ord, axis, keepdims):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fn = partial(np.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
jnp_fn = partial(jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)
def testStringInfNorm(self):
err, msg = ValueError, r"Invalid order 'inf' for vector norm."
with self.assertRaisesRegex(err, msg):
jnp.linalg.norm(jnp.array([1.0, 2.0, 3.0]), ord="inf")
@jtu.sample_product(
shape=[(2, 3), (4, 2, 3), (2, 3, 4, 5)],
dtype=float_types + complex_types,
keepdims=[True, False],
ord=[1, -1, 2, -2, np.inf, -np.inf, 'fro', 'nuc'],
)
def testMatrixNorm(self, shape, dtype, keepdims, ord):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fn = partial(np.linalg.norm, ord=ord, keepdims=keepdims, axis=(-2, -1))
else:
np_fn = partial(np.linalg.matrix_norm, ord=ord, keepdims=keepdims)
jnp_fn = partial(jnp.linalg.matrix_norm, ord=ord, keepdims=keepdims)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
self._CompileAndCheck(jnp_fn, args_maker)
@skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
@jtu.sample_product(
[
dict(shape=shape, axis=axis)
for shape in [(3,), (3, 4), (2, 3, 4, 5)]
for axis in _axis_for_ndim(len(shape))
],
dtype=float_types + complex_types,
keepdims=[True, False],
ord=[1, -1, 2, -2, np.inf, -np.inf],
)
def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
tol = 1E-3 if jtu.test_device_matches(['tpu']) else None
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
# jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
@jtu.sample_product(
[
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=0),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=1),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axis=-1),
],
dtype=int_types + float_types + complex_types
)
@jax.default_matmul_precision("float32")
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testVecdot(self, lhs_shape, rhs_shape, axis, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
np_fn = jtu.numpy_vecdot if jtu.numpy_version() < (2, 0, 0) else np.linalg.vecdot
np_fn = jtu.promote_like_jnp(partial(np_fn, axis=axis))
jnp_fn = partial(jnp.linalg.vecdot, axis=axis)
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
# smoke-test for optional kwargs.
jnp_fn = partial(jnp.linalg.vecdot, axis=axis,
precision=lax.Precision.HIGHEST,
preferred_element_type=dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
# jnp.linalg.matmul is an alias of jnp.matmul; do a minimal test here.
@jtu.sample_product(
[
dict(lhs_shape=(3,), rhs_shape=(3,)), # vec-vec
dict(lhs_shape=(2, 3), rhs_shape=(3,)), # mat-vec
dict(lhs_shape=(3,), rhs_shape=(3, 4)), # vec-mat
dict(lhs_shape=(2, 3), rhs_shape=(3, 4)), # mat-mat
],
dtype=float_types + complex_types
)
@jax.default_matmul_precision("float32")
def testMatmul(self, lhs_shape, rhs_shape, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
np_fn = jtu.promote_like_jnp(
np.matmul if jtu.numpy_version() < (2, 0, 0) else np.linalg.matmul)
jnp_fn = jnp.linalg.matmul
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
# smoke-test for optional kwargs.
jnp_fn = partial(jnp.linalg.matmul,
precision=lax.Precision.HIGHEST,
preferred_element_type=dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
# jnp.linalg.tensordot is an alias of jnp.tensordot; do a minimal test here.
@jtu.sample_product(
[
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=0),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=1),
dict(lhs_shape=(2, 2, 2), rhs_shape=(2, 2), axes=2),
],
dtype=float_types + complex_types
)
@jax.default_matmul_precision("float32")
def testTensordot(self, lhs_shape, rhs_shape, axes, dtype):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)]
np_fn = jtu.promote_like_jnp(
partial(
np.tensordot if jtu.numpy_version() < (2, 0, 0) else np.linalg.tensordot,
axes=axes))
jnp_fn = partial(jnp.linalg.tensordot, axes=axes)
tol = {np.float16: 1e-2, np.float32: 2e-2, np.float64: 1e-12,
np.complex128: 1e-12}
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
# smoke-test for optional kwargs.
jnp_fn = partial(jnp.linalg.tensordot, axes=axes,
precision=lax.Precision.HIGHEST,
preferred_element_type=dtype)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
@jtu.sample_product(
[
dict(m=m, n=n, full_matrices=full_matrices, hermitian=hermitian)
for (m, n), full_matrices in (
list(
itertools.product(
itertools.product([0, 2, 7, 29, 32, 53], repeat=2),
[False, True],
)
)
+
# Test cases that ensure we are economical when computing the SVD
# and its gradient. If we form a 400kx400k matrix explicitly we
# will OOM.
[((400000, 2), False), ((2, 400000), False)]
)
for hermitian in ([False, True] if m == n else [False])
],
b=[(), (3,), (2, 3)],
dtype=float_types + complex_types,
compute_uv=[False, True],
)
@jax.default_matmul_precision("float32")
def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(b + (m, n), dtype)]
def compute_max_backward_error(operand, reconstructed_operand):
error_norm = np.linalg.norm(operand - reconstructed_operand,
axis=(-2, -1))
backward_error = (error_norm /
np.linalg.norm(operand, axis=(-2, -1)))
max_backward_error = np.amax(backward_error)
return max_backward_error
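# For a backward-stable SVD, the relative backward error
# ||A - U S V^H|| / ||A|| should be a modest multiple of eps; the tolerances
# below allow for that.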
tol = 100 * jnp.finfo(dtype).eps
reconstruction_tol = 2 * tol
unitariness_tol = 3 * tol
a, = args_maker()
if hermitian:
a = a + np.conj(T(a))
out = jnp.linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv,
hermitian=hermitian)
if compute_uv:
# Check the reconstructed matrices
out = list(out)
out[1] = out[1].astype(out[0].dtype) # for strict dtype promotion.
if m and n:
if full_matrices:
k = min(m, n)
if m < n:
max_backward_error = compute_max_backward_error(
a, np.matmul(out[1][..., None, :] * out[0], out[2][..., :k, :]))
self.assertLess(max_backward_error, reconstruction_tol)
else:
max_backward_error = compute_max_backward_error(
a, np.matmul(out[1][..., None, :] * out[0][..., :, :k], out[2]))
self.assertLess(max_backward_error, reconstruction_tol)
else:
max_backward_error = compute_max_backward_error(
a, np.matmul(out[1][..., None, :] * out[0], out[2]))
self.assertLess(max_backward_error, reconstruction_tol)
# Check the unitary properties of the singular vector matrices.
unitary_mat = np.real(np.matmul(np.conj(T(out[0])), out[0]))
eye_slice = np.eye(out[0].shape[-1], dtype=unitary_mat.dtype)
self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
unitary_mat, rtol=unitariness_tol,
atol=unitariness_tol)
if m >= n:
unitary_mat = np.real(np.matmul(np.conj(T(out[2])), out[2]))
eye_slice = np.eye(out[2].shape[-1], dtype=unitary_mat.dtype)
self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
unitary_mat, rtol=unitariness_tol,
atol=unitariness_tol)
else:
unitary_mat = np.real(np.matmul(out[2], np.conj(T(out[2]))))
eye_slice = np.eye(out[2].shape[-2], dtype=unitary_mat.dtype)
self.assertAllClose(np.broadcast_to(eye_slice, b + eye_slice.shape),
unitary_mat, rtol=unitariness_tol,
atol=unitariness_tol)
else:
self.assertTrue(np.allclose(np.linalg.svd(a, compute_uv=False),
np.asarray(out), atol=1e-4, rtol=1e-4))
self._CompileAndCheck(partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv),
args_maker)
if not compute_uv and a.size < 100000:
svd = partial(jnp.linalg.svd, full_matrices=full_matrices,
compute_uv=compute_uv)
# TODO(phawkins): these tolerances seem very loose.
if dtype == np.complex128:
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=1e-4, atol=1e-4,
eps=1e-8)
else:
jtu.check_jvp(svd, partial(jvp, svd), (a,), rtol=5e-2, atol=2e-1)
if compute_uv and (not full_matrices):
b, = args_maker()
def f(x):
u, s, v = jnp.linalg.svd(
a + x * b,
full_matrices=full_matrices,
compute_uv=compute_uv)
vdiag = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')
return jnp.matmul(jnp.matmul(u, vdiag(s).astype(u.dtype)), v).real
_, t_out = jvp(f, (1.,), (1.,))
if dtype == np.complex128:
atol = 2e-13
else:
atol = 6e-4
self.assertArraysAllClose(t_out, b.real, atol=atol)
def testJspSVDBasic(self):
# Since jax.scipy.linalg.svd is almost the same as jax.numpy.linalg.svd,
# we do not check its functionality here.
jsp.linalg.svd(np.ones((2, 2), dtype=np.float32))
@jtu.sample_product(
shape=[(0, 2), (2, 0), (3, 4), (3, 3), (4, 3)],
dtype=[np.float32],
mode=["reduced", "r", "full", "complete", "raw"],
)
def testNumpyQrModes(self, shape, dtype, mode):
rng = jtu.rand_default(self.rng())
jnp_func = partial(jax.numpy.linalg.qr, mode=mode)
np_func = partial(np.linalg.qr, mode=mode)
if mode == "full":
np_func = jtu.ignore_warning(category=DeprecationWarning, message="The 'full' option.*")(np_func)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_func, jnp_func, args_maker, rtol=1e-5, atol=1e-5,
check_dtypes=(mode != "raw"))
self._CompileAndCheck(jnp_func, args_maker)
@jtu.sample_product(
shape=[(0, 0), (2, 0), (0, 2), (3, 3), (3, 4), (2, 10, 5),
(2, 200, 100), (64, 16, 5), (33, 7, 3), (137, 9, 5), (20000, 2, 2)],
dtype=float_types + complex_types,
full_matrices=[False, True],
)
@jax.default_matmul_precision("float32")
def testQr(self, shape, dtype, full_matrices):
if (jtu.test_device_matches(["cuda"]) and
_is_required_cuda_version_satisfied(12000)):
self.skipTest("Triggers a bug in cuda-12 b/287345077")
rng = jtu.rand_default(self.rng())
m, n = shape[-2:]
if full_matrices:
mode, k = "complete", m
else:
mode, k = "reduced", min(m, n)
a = rng(shape, dtype)
lq, lr = jnp.linalg.qr(a, mode=mode)
# np.linalg.qr doesn't support batch dimensions. But it seems like an
# inevitable extension so we support it in our version.
nq = np.zeros(shape[:-2] + (m, k), dtype)
nr = np.zeros(shape[:-2] + (k, n), dtype)
for index in np.ndindex(*shape[:-2]):
nq[index], nr[index] = np.linalg.qr(a[index], mode=mode)
max_rank = max(m, n)
# Norm, adjusted for dimension and type.
def norm(x):