forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lapack.pyx
1814 lines (1537 loc) · 60.2 KB
/
lapack.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: language_level=2
# distutils: language = c++
# Shims that allow the XLA CPU backend to call scipy-provided LAPACK kernels
# via CustomCallWithLayout.
from __future__ import print_function
cdef extern from "<cmath>" namespace "std":
bint isnan(float x) nogil
bint isnan(double x) nogil
from libc.stdlib cimport malloc, free
from libc.stdint cimport int32_t, int64_t, uint8_t
from libc.string cimport memcpy
from libcpp cimport bool as bool_t
from libcpp.string cimport string
from cpython.pycapsule cimport PyCapsule_New
from scipy.linalg.cython_blas cimport strsm, dtrsm, ctrsm, ztrsm
from scipy.linalg.cython_lapack cimport sgetrf, dgetrf, cgetrf, zgetrf
from scipy.linalg.cython_lapack cimport sgeqrf, dgeqrf, cgeqrf, zgeqrf
from scipy.linalg.cython_lapack cimport sorgqr, dorgqr, cungqr, zungqr
from scipy.linalg.cython_lapack cimport spotrf, dpotrf, cpotrf, zpotrf
from scipy.linalg.cython_lapack cimport sgesdd, dgesdd, cgesdd, zgesdd
from scipy.linalg.cython_lapack cimport ssyevd, dsyevd, cheevd, zheevd
from scipy.linalg.cython_lapack cimport sgeev, dgeev, cgeev, zgeev
import numpy as np
from jaxlib import xla_client
# Short aliases for the XLA client APIs used throughout this module.
Shape = xla_client.Shape
_ops = xla_client.ops

# Largest value of LAPACK's 32-bit integer type (2**31 - 1); used to clamp
# computed workspace sizes so they remain representable.
cdef int _int32_max = 2147483647
cdef register_cpu_custom_call_target(fn_name, void* fn):
  # Wraps the raw C function pointer `fn` in a PyCapsule and registers it with
  # XLA under `fn_name`. The capsule name is a C string literal with static
  # lifetime, since PyCapsule keeps the name pointer rather than a copy.
  cdef const char* capsule_name = "xla._CUSTOM_CALL_TARGET"
  capsule = PyCapsule_New(fn, capsule_name, NULL)
  xla_client.register_custom_call_target(fn_name, capsule)
def _constant_s32_scalar(c, x):
  """Returns an XLA constant op on builder `c` holding `x` as an int32 scalar."""
  value = np.int32(x)
  return _ops.Constant(c, value)
# TODO(phawkins): remove after we no longer need to support old jax releases.
def _unpack_builder(c):
# If `c` is a ComputationBuilder object, extracts the underlying XlaBuilder.
return getattr(c, "_builder", c)
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve

# Custom-call kernel: batched single-precision triangular solve via BLAS strsm.
# Operand slots in `data`: 0 left_side, 1 lower, 2 trans_a (0 = 'N', 1 = 'T',
# 2 = 'C'), 3 unit-diagonal flag, 4 m, 5 n, 6 batch count (all int32 scalars),
# 7 alpha, 8 the triangular matrices A, 9 the right-hand sides B (m x n each,
# leading dimension m). The solutions are written to `out`.
cdef void blas_strsm(void* out, void** data) nogil:
  cdef int32_t side_flag = (<int32_t*>(data[0]))[0]
  cdef int32_t lower_flag = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_code = (<int32_t*>(data[2]))[0]
  cdef int32_t unit_diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef float* alpha = <float*>(data[7])
  cdef float* a = <float*>(data[8])
  cdef float* b = <float*>(data[9])
  cdef float* x = <float*>(out)
  cdef char cside = 'R'
  cdef char cuplo = 'U'
  cdef char ctransa = 'N'
  cdef char cdiag = 'N'
  cdef int lda = n
  cdef int ldb = m
  cdef int64_t x_step = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_step
  cdef int idx
  # The output buffer may alias B; copy B into it only when they differ,
  # then solve in place.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * x_step * sizeof(float))
  if side_flag:
    cside = 'L'
    lda = m
  if lower_flag:
    cuplo = 'L'
  if trans_code == 1:
    ctransa = 'T'
  elif trans_code == 2:
    ctransa = 'C'
  if unit_diag:
    cdiag = 'U'
  # A is square with order lda (m for left-side solves, n otherwise).
  a_step = <int64_t>(lda) * <int64_t>(lda)
  for idx in range(batch):
    strsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_step
    a += a_step
register_cpu_custom_call_target(b"blas_strsm", <void*>(blas_strsm))
# Custom-call kernel: batched double-precision triangular solve via BLAS dtrsm.
# Operand slots in `data`: 0 left_side, 1 lower, 2 trans_a (0 = 'N', 1 = 'T',
# 2 = 'C'), 3 unit-diagonal flag, 4 m, 5 n, 6 batch count (all int32 scalars),
# 7 alpha, 8 the triangular matrices A, 9 the right-hand sides B (m x n each,
# leading dimension m). The solutions are written to `out`.
cdef void blas_dtrsm(void* out, void** data) nogil:
  cdef int32_t side_flag = (<int32_t*>(data[0]))[0]
  cdef int32_t lower_flag = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_code = (<int32_t*>(data[2]))[0]
  cdef int32_t unit_diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef double* alpha = <double*>(data[7])
  cdef double* a = <double*>(data[8])
  cdef double* b = <double*>(data[9])
  cdef double* x = <double*>(out)
  cdef char cside = 'R'
  cdef char cuplo = 'U'
  cdef char ctransa = 'N'
  cdef char cdiag = 'N'
  cdef int lda = n
  cdef int ldb = m
  cdef int64_t x_step = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_step
  cdef int idx
  # The output buffer may alias B; copy B into it only when they differ,
  # then solve in place.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * x_step * sizeof(double))
  if side_flag:
    cside = 'L'
    lda = m
  if lower_flag:
    cuplo = 'L'
  if trans_code == 1:
    ctransa = 'T'
  elif trans_code == 2:
    ctransa = 'C'
  if unit_diag:
    cdiag = 'U'
  # A is square with order lda (m for left-side solves, n otherwise).
  a_step = <int64_t>(lda) * <int64_t>(lda)
  for idx in range(batch):
    dtrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_step
    a += a_step
register_cpu_custom_call_target(b"blas_dtrsm", <void*>(blas_dtrsm))
# Custom-call kernel: batched single-precision complex triangular solve via
# BLAS ctrsm. Operand slots in `data`: 0 left_side, 1 lower, 2 trans_a
# (0 = 'N', 1 = 'T', 2 = 'C'), 3 unit-diagonal flag, 4 m, 5 n, 6 batch count
# (all int32 scalars), 7 alpha, 8 the triangular matrices A, 9 the right-hand
# sides B (m x n each, leading dimension m). The solutions are written to
# `out`.
cdef void blas_ctrsm(void* out, void** data) nogil:
  cdef int32_t side_flag = (<int32_t*>(data[0]))[0]
  cdef int32_t lower_flag = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_code = (<int32_t*>(data[2]))[0]
  cdef int32_t unit_diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef float complex* alpha = <float complex*>(data[7])
  cdef float complex* a = <float complex*>(data[8])
  cdef float complex* b = <float complex*>(data[9])
  cdef float complex* x = <float complex*>(out)
  cdef char cside = 'R'
  cdef char cuplo = 'U'
  cdef char ctransa = 'N'
  cdef char cdiag = 'N'
  cdef int lda = n
  cdef int ldb = m
  cdef int64_t x_step = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_step
  cdef int idx
  # The output buffer may alias B; copy B into it only when they differ,
  # then solve in place.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * x_step * sizeof(float complex))
  if side_flag:
    cside = 'L'
    lda = m
  if lower_flag:
    cuplo = 'L'
  if trans_code == 1:
    ctransa = 'T'
  elif trans_code == 2:
    ctransa = 'C'
  if unit_diag:
    cdiag = 'U'
  # A is square with order lda (m for left-side solves, n otherwise).
  a_step = <int64_t>(lda) * <int64_t>(lda)
  for idx in range(batch):
    ctrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_step
    a += a_step
register_cpu_custom_call_target(b"blas_ctrsm", <void*>(blas_ctrsm))
# Custom-call kernel: batched double-precision complex triangular solve via
# BLAS ztrsm. Operand slots in `data`: 0 left_side, 1 lower, 2 trans_a
# (0 = 'N', 1 = 'T', 2 = 'C'), 3 unit-diagonal flag, 4 m, 5 n, 6 batch count
# (all int32 scalars), 7 alpha, 8 the triangular matrices A, 9 the right-hand
# sides B (m x n each, leading dimension m). The solutions are written to
# `out`.
cdef void blas_ztrsm(void* out, void** data) nogil:
  cdef int32_t side_flag = (<int32_t*>(data[0]))[0]
  cdef int32_t lower_flag = (<int32_t*>(data[1]))[0]
  cdef int32_t trans_code = (<int32_t*>(data[2]))[0]
  cdef int32_t unit_diag = (<int32_t*>(data[3]))[0]
  cdef int m = (<int32_t*>(data[4]))[0]
  cdef int n = (<int32_t*>(data[5]))[0]
  cdef int batch = (<int32_t*>(data[6]))[0]
  cdef double complex* alpha = <double complex*>(data[7])
  cdef double complex* a = <double complex*>(data[8])
  cdef double complex* b = <double complex*>(data[9])
  cdef double complex* x = <double complex*>(out)
  cdef char cside = 'R'
  cdef char cuplo = 'U'
  cdef char ctransa = 'N'
  cdef char cdiag = 'N'
  cdef int lda = n
  cdef int ldb = m
  cdef int64_t x_step = <int64_t>(m) * <int64_t>(n)
  cdef int64_t a_step
  cdef int idx
  # The output buffer may alias B; copy B into it only when they differ,
  # then solve in place.
  if x != b:
    memcpy(x, b, <int64_t>(batch) * x_step * sizeof(double complex))
  if side_flag:
    cside = 'L'
    lda = m
  if lower_flag:
    cuplo = 'L'
  if trans_code == 1:
    ctransa = 'T'
  elif trans_code == 2:
    ctransa = 'C'
  if unit_diag:
    cdiag = 'U'
  # A is square with order lda (m for left-side solves, n otherwise).
  a_step = <int64_t>(lda) * <int64_t>(lda)
  for idx in range(batch):
    ztrsm(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb)
    x += x_step
    a += a_step
register_cpu_custom_call_target(b"blas_ztrsm", <void*>(blas_ztrsm))
def trsm(c, alpha, a, b, left_side=False, lower=False, trans_a=False,
         conj_a=False, diag=False):
  """Builds an XLA custom call to a batched BLAS triangular solve (?trsm).

  Args:
    c: the XLA builder (or a wrapper exposing it via `_builder`).
    alpha: scalar operand with the same element type as `b`.
    a: triangular matrices of shape batch_dims + (k, k), where k is m when
      `left_side` is true and n otherwise.
    b: right-hand sides of shape batch_dims + (m, n).
    left_side: solve op(A) X = alpha B if true, X op(A) = alpha B otherwise.
    lower: whether the lower (true) or upper (false) triangle of A is used.
    trans_a: whether op(A) transposes A.
    conj_a: together with `trans_a`, use the conjugate transpose of A.
    diag: whether A is treated as having a unit diagonal.

  Returns:
    The solution X, with the same shape as `b`.

  Raises:
    ValueError: if the shape or element type of `a` does not match `b`.
    NotImplementedError: for unsupported element types, or if `conj_a` is
      requested without `trans_a`.
  """
  c = _unpack_builder(c)
  a_shape = c.get_shape(a)
  b_shape = c.get_shape(b)
  dtype = b_shape.element_type()
  dims = b_shape.dimensions()
  m, n = dims[-2:]
  k = m if left_side else n
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  num_b = 1
  for extent in batch_dims:
    num_b *= extent

  expected_a_dims = batch_dims + (k, k)
  if (expected_a_dims != a_shape.dimensions() or
      a_shape.element_type() != dtype):
    raise ValueError("Argument mismatch for trsm, got {} and {}".format(
        a_shape, b_shape))

  for candidate, target in ((np.float32, b"blas_strsm"),
                            (np.float64, b"blas_dtrsm"),
                            (np.complex64, b"blas_ctrsm"),
                            (np.complex128, b"blas_ztrsm")):
    if dtype == candidate:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  if conj_a and not trans_a:
    raise NotImplementedError("Conjugation without transposition not supported")

  # Transpose code understood by the blas_?trsm kernels: 0 = 'N', 1 = 'T',
  # 2 = 'C'.
  if not trans_a:
    trans_code = 0
  elif conj_a:
    trans_code = 2
  else:
    trans_code = 1

  # Minor-to-major order: each matrix column-major, batch dims outermost.
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  i32 = np.dtype(np.int32)
  scalar_args = (int(left_side), int(lower), trans_code, int(diag), m, n,
                 num_b)
  return _ops.CustomCallWithLayout(
      c, fn,
      operands=tuple([_constant_s32_scalar(c, v) for v in scalar_args]) + (
          alpha, a, b),
      shape_with_layout=Shape.array_shape(dtype, b_shape.dimensions(), layout),
      operand_shapes_with_layout=tuple(
          [Shape.array_shape(i32, (), ()) for _ in range(7)]) + (
              Shape.array_shape(dtype, (), ()),
              Shape.array_shape(dtype, a_shape.dimensions(), layout),
              Shape.array_shape(dtype, b_shape.dimensions(), layout),
          ))

jax_trsm = trsm
# ?getrf: LU decomposition

# Custom-call kernel: batched single-precision LU factorization (sgetrf).
# Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n (int32
# scalars), 3 input matrices, stored as b consecutive m x n blocks with
# leading dimension m. Outputs in `out_tuple`: 0 factored matrices,
# 1 pivot indices (min(m, n) per matrix), 2 one LAPACK info code per matrix.
cdef void lapack_sgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* a_in = <float*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements,
  # even though m and n each fit LAPACK's 32-bit integer interface.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float))
  for i in range(b):
    sgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += a_stride
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgetrf", <void*>(lapack_sgetrf))
# Custom-call kernel: batched double-precision LU factorization (dgetrf).
# Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n (int32
# scalars), 3 input matrices, stored as b consecutive m x n blocks with
# leading dimension m. Outputs in `out_tuple`: 0 factored matrices,
# 1 pivot indices (min(m, n) per matrix), 2 one LAPACK info code per matrix.
cdef void lapack_dgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* a_in = <double*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double))
  for i in range(b):
    dgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += a_stride
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgetrf", <void*>(lapack_dgetrf))
# Custom-call kernel: batched single-precision complex LU factorization
# (cgetrf). Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n
# (int32 scalars), 3 input matrices, stored as b consecutive m x n blocks
# with leading dimension m. Outputs in `out_tuple`: 0 factored matrices,
# 1 pivot indices (min(m, n) per matrix), 2 one LAPACK info code per matrix.
cdef void lapack_cgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* a_in = <float complex*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float complex))
  for i in range(b):
    cgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += a_stride
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgetrf", <void*>(lapack_cgetrf))
# Custom-call kernel: batched double-precision complex LU factorization
# (zgetrf). Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n
# (int32 scalars), 3 input matrices, stored as b consecutive m x n blocks
# with leading dimension m. Outputs in `out_tuple`: 0 factored matrices,
# 1 pivot indices (min(m, n) per matrix), 2 one LAPACK info code per matrix.
cdef void lapack_zgetrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* a_in = <double complex*>(data[3])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* ipiv = <int*>(out[1])
  cdef int* info = <int*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double complex))
  for i in range(b):
    zgetrf(&m, &n, a_out, &m, ipiv, info)
    a_out += a_stride
    ipiv += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgetrf", <void*>(lapack_zgetrf))
def getrf(c, a):
  """Builds an XLA custom call computing batched LU factorizations (?getrf).

  Args:
    c: the XLA builder (or a wrapper exposing it via `_builder`).
    a: operand of shape batch_dims + (m, n) with a float32, float64,
      complex64 or complex128 element type.

  Returns:
    A tuple (lu, pivots, info): the factored matrices (batch_dims + (m, n)),
    int32 pivot indices (batch_dims + (min(m, n),)), and int32 LAPACK info
    codes (batch_dims).

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = 1
  for extent in batch_dims:
    batch_count *= extent

  for candidate, target in ((np.float32, b"lapack_sgetrf"),
                            (np.float64, b"lapack_dgetrf"),
                            (np.complex64, b"lapack_cgetrf"),
                            (np.complex128, b"lapack_zgetrf")):
    if dtype == candidate:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # Minor-to-major orders: matrices column-major, batch dims outermost.
  i32 = np.dtype(np.int32)
  matrix_dims = batch_dims + (m, n)
  matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
          _constant_s32_scalar(c, batch_count),
          _constant_s32_scalar(c, m),
          _constant_s32_scalar(c, n),
          a),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, matrix_dims, matrix_layout),
          Shape.array_shape(i32, batch_dims + (min(m, n),),
                            tuple(range(num_bd, -1, -1))),
          Shape.array_shape(i32, batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
      )),
      operand_shapes_with_layout=tuple(
          [Shape.array_shape(i32, (), ()) for _ in range(3)]) + (
              Shape.array_shape(dtype, matrix_dims, matrix_layout),
          ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(3))
# ?geqrf: QR decomposition

# Workspace query (LWORK = -1) for sgeqrf on an m x n matrix with leading
# dimension m. Returns the optimal LWORK, or -1 if the query reports an error.
cdef int lapack_sgeqrf_workspace(int m, int n):
  cdef float work = 0
  cdef int lwork = -1
  cdef int info = 0
  sgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1

# Custom-call kernel: batched single-precision QR factorization (sgeqrf).
# Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n, 3 lwork
# (int32 scalars), 4 input matrices (b consecutive m x n blocks, leading
# dimension m). Outputs in `out_tuple`: 0 factored matrices, 1 reflector
# scalars tau (min(m, n) per matrix), 2 one LAPACK info code per matrix,
# 3 a scratch workspace of lwork elements reused for every batch member.
cdef void lapack_sgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const float* a_in = <float*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef float* tau = <float*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef float* work = <float*>(out[3])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float))
  for i in range(b):
    sgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_sgeqrf", <void*>(lapack_sgeqrf))
# Workspace query (LWORK = -1) for dgeqrf on an m x n matrix with leading
# dimension m. Returns the optimal LWORK, or -1 if the query reports an error.
cdef int lapack_dgeqrf_workspace(int m, int n):
  cdef double work = 0
  cdef int lwork = -1
  cdef int info = 0
  dgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1

# Custom-call kernel: batched double-precision QR factorization (dgeqrf).
# Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n, 3 lwork
# (int32 scalars), 4 input matrices (b consecutive m x n blocks, leading
# dimension m). Outputs in `out_tuple`: 0 factored matrices, 1 reflector
# scalars tau (min(m, n) per matrix), 2 one LAPACK info code per matrix,
# 3 a scratch workspace of lwork elements reused for every batch member.
cdef void lapack_dgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const double* a_in = <double*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef double* tau = <double*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef double* work = <double*>(out[3])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double))
  for i in range(b):
    dgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_dgeqrf", <void*>(lapack_dgeqrf))
# Workspace query (LWORK = -1) for cgeqrf on an m x n matrix with leading
# dimension m. Returns the optimal LWORK (real part of the returned work
# value), or -1 if the query reports an error.
cdef int lapack_cgeqrf_workspace(int m, int n):
  cdef float complex work = 0
  cdef int lwork = -1
  cdef int info = 0
  cgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1

# Custom-call kernel: batched single-precision complex QR factorization
# (cgeqrf). Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n,
# 3 lwork (int32 scalars), 4 input matrices (b consecutive m x n blocks,
# leading dimension m). Outputs in `out_tuple`: 0 factored matrices,
# 1 reflector scalars tau (min(m, n) per matrix), 2 one LAPACK info code per
# matrix, 3 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_cgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const float complex* a_in = <float complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef float complex* tau = <float complex*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef float complex* work = <float complex*>(out[3])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float complex))
  for i in range(b):
    cgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_cgeqrf", <void*>(lapack_cgeqrf))
# Workspace query (LWORK = -1) for zgeqrf on an m x n matrix with leading
# dimension m. Returns the optimal LWORK (real part of the returned work
# value), or -1 if the query reports an error.
cdef int lapack_zgeqrf_workspace(int m, int n):
  cdef double complex work = 0
  cdef int lwork = -1
  cdef int info = 0
  zgeqrf(&m, &n, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1

# Custom-call kernel: batched double-precision complex QR factorization
# (zgeqrf). Operand slots in `data`: 0 batch count b, 1 rows m, 2 cols n,
# 3 lwork (int32 scalars), 4 input matrices (b consecutive m x n blocks,
# leading dimension m). Outputs in `out_tuple`: 0 factored matrices,
# 1 reflector scalars tau (min(m, n) per matrix), 2 one LAPACK info code per
# matrix, 3 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_zgeqrf(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int lwork = (<int32_t*>(data[3]))[0]
  cdef const double complex* a_in = <double complex*>(data[4])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef double complex* tau = <double complex*>(out[1])
  cdef int* info = <int*>(out[2])
  cdef double complex* work = <double complex*>(out[3])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double complex))
  for i in range(b):
    zgeqrf(&m, &n, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += min(m, n)
    info += 1
register_cpu_custom_call_target(b"lapack_zgeqrf", <void*>(lapack_zgeqrf))
def geqrf(c, a):
  """Builds an XLA custom call computing batched QR factorizations (?geqrf).

  Args:
    c: the XLA builder (or a wrapper exposing it via `_builder`).
    a: operand of shape batch_dims + (m, n) with a float32, float64,
      complex64 or complex128 element type.

  Returns:
    A tuple (qr, tau, info): the factored matrices (batch_dims + (m, n)),
    the reflector scalars (batch_dims + (min(m, n),)), and int32 LAPACK info
    codes (batch_dims). The kernel's scratch workspace output is not
    returned.

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = 1
  for extent in batch_dims:
    batch_count *= extent

  # Select the kernel and ask LAPACK for its preferred workspace size.
  if dtype == np.float32:
    fn, lwork = b"lapack_sgeqrf", lapack_sgeqrf_workspace(m, n)
  elif dtype == np.float64:
    fn, lwork = b"lapack_dgeqrf", lapack_dgeqrf_workspace(m, n)
  elif dtype == np.complex64:
    fn, lwork = b"lapack_cgeqrf", lapack_cgeqrf_workspace(m, n)
  elif dtype == np.complex128:
    fn, lwork = b"lapack_zgeqrf", lapack_zgeqrf_workspace(m, n)
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # Minor-to-major orders: matrices column-major, batch dims outermost.
  i32 = np.dtype(np.int32)
  matrix_dims = batch_dims + (m, n)
  matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
          _constant_s32_scalar(c, batch_count),
          _constant_s32_scalar(c, m),
          _constant_s32_scalar(c, n),
          _constant_s32_scalar(c, lwork),
          a,
      ),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, matrix_dims, matrix_layout),
          Shape.array_shape(np.dtype(dtype), batch_dims + (min(m, n),),
                            tuple(range(num_bd, -1, -1))),
          Shape.array_shape(i32, batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(np.dtype(dtype), (lwork,), (0,)),
      )),
      operand_shapes_with_layout=tuple(
          [Shape.array_shape(i32, (), ()) for _ in range(4)]) + (
              Shape.array_shape(dtype, matrix_dims, matrix_layout),
          ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(3))
# ?orgqr: form Q as a product of elementary Householder reflectors

# Workspace query (LWORK = -1) for sorgqr producing an m x n matrix from k
# reflectors, leading dimension m. Returns the optimal LWORK, or -1 if the
# query reports an error.
cdef int lapack_sorgqr_workspace(int m, int n, int k):
  cdef float work = 0
  cdef int lwork = -1
  cdef int info = 0
  sorgqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1

# Custom-call kernel: batched single-precision sorgqr. Operand slots in
# `data`: 0 batch count b, 1 m, 2 n, 3 k, 4 lwork (int32 scalars), 5 input
# matrices holding the reflectors (b consecutive m x n blocks, leading
# dimension m), 6 reflector scalars tau (k per matrix). Outputs in
# `out_tuple`: 0 the resulting m x n matrices, 1 one LAPACK info code per
# matrix, 2 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_sorgqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const float* a_in = <float*>(data[5])
  cdef float* tau = <float*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef float* a_out = <float*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef float* work = <float*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float))
  for i in range(b):
    sorgqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_sorgqr", <void*>(lapack_sorgqr))
# Workspace query (LWORK = -1) for dorgqr producing an m x n matrix from k
# reflectors, leading dimension m. Returns the optimal LWORK, or -1 if the
# query reports an error.
cdef int lapack_dorgqr_workspace(int m, int n, int k):
  cdef double work = 0
  cdef int lwork = -1
  cdef int info = 0
  dorgqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work) if info == 0 else -1

# Custom-call kernel: batched double-precision dorgqr. Operand slots in
# `data`: 0 batch count b, 1 m, 2 n, 3 k, 4 lwork (int32 scalars), 5 input
# matrices holding the reflectors (b consecutive m x n blocks, leading
# dimension m), 6 reflector scalars tau (k per matrix). Outputs in
# `out_tuple`: 0 the resulting m x n matrices, 1 one LAPACK info code per
# matrix, 2 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_dorgqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const double* a_in = <double*>(data[5])
  cdef double* tau = <double*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef double* a_out = <double*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef double* work = <double*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double))
  for i in range(b):
    dorgqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_dorgqr", <void*>(lapack_dorgqr))
# Workspace query (LWORK = -1) for cungqr producing an m x n matrix from k
# reflectors, leading dimension m. Returns the optimal LWORK (real part of
# the returned work value), or -1 if the query reports an error.
cdef int lapack_cungqr_workspace(int m, int n, int k):
  cdef float complex work = 0
  cdef int lwork = -1
  cdef int info = 0
  cungqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1

# Custom-call kernel: batched single-precision complex cungqr. Operand slots
# in `data`: 0 batch count b, 1 m, 2 n, 3 k, 4 lwork (int32 scalars), 5 input
# matrices holding the reflectors (b consecutive m x n blocks, leading
# dimension m), 6 reflector scalars tau (k per matrix). Outputs in
# `out_tuple`: 0 the resulting m x n matrices, 1 one LAPACK info code per
# matrix, 2 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_cungqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const float complex* a_in = <float complex*>(data[5])
  cdef float complex* tau = <float complex*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef float complex* a_out = <float complex*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef float complex* work = <float complex*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(float complex))
  for i in range(b):
    cungqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_cungqr", <void*>(lapack_cungqr))
# Workspace query (LWORK = -1) for zungqr producing an m x n matrix from k
# reflectors, leading dimension m. Returns the optimal LWORK (real part of
# the returned work value), or -1 if the query reports an error.
cdef int lapack_zungqr_workspace(int m, int n, int k):
  cdef double complex work = 0
  cdef int lwork = -1
  cdef int info = 0
  zungqr(&m, &n, &k, NULL, &m, NULL, &work, &lwork, &info)
  return <int>(work.real) if info == 0 else -1

# Custom-call kernel: batched double-precision complex zungqr. Operand slots
# in `data`: 0 batch count b, 1 m, 2 n, 3 k, 4 lwork (int32 scalars), 5 input
# matrices holding the reflectors (b consecutive m x n blocks, leading
# dimension m), 6 reflector scalars tau (k per matrix). Outputs in
# `out_tuple`: 0 the resulting m x n matrices, 1 one LAPACK info code per
# matrix, 2 a scratch workspace of lwork elements reused for every batch
# member.
cdef void lapack_zungqr(void* out_tuple, void** data) nogil:
  cdef int b = (<int32_t*>(data[0]))[0]
  cdef int m = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef int k = (<int32_t*>(data[3]))[0]
  cdef int lwork = (<int32_t*>(data[4]))[0]
  cdef const double complex* a_in = <double complex*>(data[5])
  cdef double complex* tau = <double complex*>(data[6])
  cdef void** out = <void**>(out_tuple)
  cdef double complex* a_out = <double complex*>(out[0])
  cdef int* info = <int*>(out[1])
  cdef double complex* work = <double complex*>(out[2])
  # Per-matrix stride computed in 64 bits. `m * n` in C int overflows
  # (undefined behavior) once one matrix holds more than 2**31 - 1 elements.
  cdef int64_t a_stride = <int64_t>(m) * <int64_t>(n)
  # The output may alias the input; copy only when the buffers differ.
  if a_out != a_in:
    memcpy(a_out, a_in, <int64_t>(b) * a_stride * sizeof(double complex))
  for i in range(b):
    zungqr(&m, &n, &k, a_out, &m, tau, work, &lwork, info)
    a_out += a_stride
    tau += k
    info += 1
register_cpu_custom_call_target(b"lapack_zungqr", <void*>(lapack_zungqr))
def orgqr(c, a, tau):
  """Builds an XLA custom call to batched ?orgqr / ?ungqr.

  Args:
    c: the XLA builder (or a wrapper exposing it via `_builder`).
    a: matrices holding the reflectors, shape batch_dims + (m, n), with a
      float32, float64, complex64 or complex128 element type.
    tau: reflector scalars of shape batch_dims + (k,).

  Returns:
    A tuple (q, info): the resulting matrices (batch_dims + (m, n)) and
    int32 LAPACK info codes (batch_dims). The kernel's scratch workspace
    output is not returned.

  Raises:
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  assert len(dims) >= 2
  m, n = dims[-2:]
  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = 1
  for extent in batch_dims:
    batch_count *= extent

  tau_dims = c.get_shape(tau).dimensions()
  assert tau_dims[:-1] == dims[:-2]
  k = tau_dims[-1]

  # Select the kernel and ask LAPACK for its preferred workspace size.
  if dtype == np.float32:
    fn, lwork = b"lapack_sorgqr", lapack_sorgqr_workspace(m, n, k)
  elif dtype == np.float64:
    fn, lwork = b"lapack_dorgqr", lapack_dorgqr_workspace(m, n, k)
  elif dtype == np.complex64:
    fn, lwork = b"lapack_cungqr", lapack_cungqr_workspace(m, n, k)
  elif dtype == np.complex128:
    fn, lwork = b"lapack_zungqr", lapack_zungqr_workspace(m, n, k)
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  # Minor-to-major orders: matrices column-major, batch dims outermost.
  i32 = np.dtype(np.int32)
  matrix_dims = batch_dims + (m, n)
  matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(
          _constant_s32_scalar(c, batch_count),
          _constant_s32_scalar(c, m),
          _constant_s32_scalar(c, n),
          _constant_s32_scalar(c, k),
          _constant_s32_scalar(c, lwork),
          a,
          tau,
      ),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, matrix_dims, matrix_layout),
          Shape.array_shape(i32, batch_dims,
                            tuple(range(num_bd - 1, -1, -1))),
          Shape.array_shape(dtype, (lwork,), (0,)),
      )),
      operand_shapes_with_layout=tuple(
          [Shape.array_shape(i32, (), ()) for _ in range(5)]) + (
              Shape.array_shape(dtype, matrix_dims, matrix_layout),
              Shape.array_shape(dtype, batch_dims + (k,),
                                tuple(range(num_bd, -1, -1))),
          ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(2))
# ?potrf: Cholesky decomposition

# Custom-call kernel: batched single-precision Cholesky factorization via
# LAPACK spotrf. Operand slots in `data`: 0 lower flag, 1 batch count,
# 2 matrix order n (all int32 scalars), 3 the input matrices, stored as
# consecutive n x n blocks. Outputs in `out_tuple`: 0 the factored matrices,
# 1 one LAPACK info code per matrix.
cdef void lapack_spotrf(void* out_tuple, void** data) nogil:
  cdef int32_t want_lower = (<int32_t*>(data[0]))[0]
  cdef int batch = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float* src = <float*>(data[3])
  cdef void** outputs = <void**>(out_tuple)
  cdef float* dst = <float*>(outputs[0])
  cdef int* status = <int*>(outputs[1])
  cdef int64_t step = <int64_t>(n) * <int64_t>(n)
  cdef char uplo = 'U'
  cdef int j
  if want_lower:
    uplo = 'L'
  # The output may alias the input; copy only when the buffers differ.
  if dst != src:
    memcpy(dst, src, <int64_t>(batch) * step * sizeof(float))
  for j in range(batch):
    spotrf(&uplo, &n, dst, &n, status)
    dst += step
    status += 1
register_cpu_custom_call_target(b"lapack_spotrf", <void*>(lapack_spotrf))
# Custom-call kernel: batched double-precision Cholesky factorization via
# LAPACK dpotrf. Operand slots in `data`: 0 lower flag, 1 batch count,
# 2 matrix order n (all int32 scalars), 3 the input matrices, stored as
# consecutive n x n blocks. Outputs in `out_tuple`: 0 the factored matrices,
# 1 one LAPACK info code per matrix.
cdef void lapack_dpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t want_lower = (<int32_t*>(data[0]))[0]
  cdef int batch = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double* src = <double*>(data[3])
  cdef void** outputs = <void**>(out_tuple)
  cdef double* dst = <double*>(outputs[0])
  cdef int* status = <int*>(outputs[1])
  cdef int64_t step = <int64_t>(n) * <int64_t>(n)
  cdef char uplo = 'U'
  cdef int j
  if want_lower:
    uplo = 'L'
  # The output may alias the input; copy only when the buffers differ.
  if dst != src:
    memcpy(dst, src, <int64_t>(batch) * step * sizeof(double))
  for j in range(batch):
    dpotrf(&uplo, &n, dst, &n, status)
    dst += step
    status += 1
register_cpu_custom_call_target(b"lapack_dpotrf", <void*>(lapack_dpotrf))
# Custom-call kernel: batched single-precision complex Cholesky factorization
# via LAPACK cpotrf. Operand slots in `data`: 0 lower flag, 1 batch count,
# 2 matrix order n (all int32 scalars), 3 the input matrices, stored as
# consecutive n x n blocks. Outputs in `out_tuple`: 0 the factored matrices,
# 1 one LAPACK info code per matrix.
cdef void lapack_cpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t want_lower = (<int32_t*>(data[0]))[0]
  cdef int batch = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const float complex* src = <float complex*>(data[3])
  cdef void** outputs = <void**>(out_tuple)
  cdef float complex* dst = <float complex*>(outputs[0])
  cdef int* status = <int*>(outputs[1])
  cdef int64_t step = <int64_t>(n) * <int64_t>(n)
  cdef char uplo = 'U'
  cdef int j
  if want_lower:
    uplo = 'L'
  # The output may alias the input; copy only when the buffers differ.
  if dst != src:
    memcpy(dst, src, <int64_t>(batch) * step * sizeof(float complex))
  for j in range(batch):
    cpotrf(&uplo, &n, dst, &n, status)
    dst += step
    status += 1
register_cpu_custom_call_target(b"lapack_cpotrf", <void*>(lapack_cpotrf))
# Custom-call kernel: batched double-precision complex Cholesky factorization
# via LAPACK zpotrf. Operand slots in `data`: 0 lower flag, 1 batch count,
# 2 matrix order n (all int32 scalars), 3 the input matrices, stored as
# consecutive n x n blocks. Outputs in `out_tuple`: 0 the factored matrices,
# 1 one LAPACK info code per matrix.
cdef void lapack_zpotrf(void* out_tuple, void** data) nogil:
  cdef int32_t want_lower = (<int32_t*>(data[0]))[0]
  cdef int batch = (<int32_t*>(data[1]))[0]
  cdef int n = (<int32_t*>(data[2]))[0]
  cdef const double complex* src = <double complex*>(data[3])
  cdef void** outputs = <void**>(out_tuple)
  cdef double complex* dst = <double complex*>(outputs[0])
  cdef int* status = <int*>(outputs[1])
  cdef int64_t step = <int64_t>(n) * <int64_t>(n)
  cdef char uplo = 'U'
  cdef int j
  if want_lower:
    uplo = 'L'
  # The output may alias the input; copy only when the buffers differ.
  if dst != src:
    memcpy(dst, src, <int64_t>(batch) * step * sizeof(double complex))
  for j in range(batch):
    zpotrf(&uplo, &n, dst, &n, status)
    dst += step
    status += 1
register_cpu_custom_call_target(b"lapack_zpotrf", <void*>(lapack_zpotrf))
def potrf(c, a, lower=False):
  """Builds an XLA custom call computing batched Cholesky factorizations.

  Args:
    c: the XLA builder (or a wrapper exposing it via `_builder`).
    a: operand of shape batch_dims + (n, n) with a float32, float64,
      complex64 or complex128 element type.
    lower: factor using the lower (true) or upper (false) triangle.

  Returns:
    A tuple (factor, info): the factored matrices, with the same shape as
    `a`, and int32 LAPACK info codes of shape batch_dims.

  Raises:
    ValueError: if the trailing two dimensions of `a` differ.
    NotImplementedError: for unsupported element types.
  """
  c = _unpack_builder(c)
  assert sizeof(int32_t) == sizeof(int)
  a_shape = c.get_shape(a)
  dtype = a_shape.element_type()
  dims = a_shape.dimensions()
  m, n = dims[-2:]
  if m != n:
    raise ValueError("potrf expects a square matrix, got {}".format(a_shape))

  for candidate, target in ((np.float32, b"lapack_spotrf"),
                            (np.float64, b"lapack_dpotrf"),
                            (np.complex64, b"lapack_cpotrf"),
                            (np.complex128, b"lapack_zpotrf")):
    if dtype == candidate:
      fn = target
      break
  else:
    raise NotImplementedError("Unsupported dtype {}".format(dtype))

  batch_dims = tuple(dims[:-2])
  num_bd = len(batch_dims)
  batch_count = 1
  for extent in batch_dims:
    batch_count *= extent

  # Minor-to-major orders: matrices column-major, batch dims outermost.
  i32 = np.dtype(np.int32)
  layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
  batch_layout = tuple(range(num_bd - 1, -1, -1))
  out = _ops.CustomCallWithLayout(
      c, fn,
      operands=(_constant_s32_scalar(c, int(lower)),
                _constant_s32_scalar(c, batch_count),
                _constant_s32_scalar(c, n),
                a),
      shape_with_layout=Shape.tuple_shape((
          Shape.array_shape(dtype, dims, layout),
          Shape.array_shape(i32, batch_dims, batch_layout),
      )),
      operand_shapes_with_layout=tuple(
          [Shape.array_shape(i32, (), ()) for _ in range(3)]) + (
              Shape.array_shape(dtype, dims, layout),
          ))
  return tuple(_ops.GetTupleElement(out, i) for i in range(2))
# ?gesdd: Singular value decomposition

# Length of the integer workspace IWORK for ?gesdd: 8 * min(m, n), clamped so
# the size stays representable in LAPACK's 32-bit integer type.
cdef int gesdd_iwork_size(int64_t m, int64_t n) nogil:
  cdef int64_t needed = 8 * min(m, n)
  if needed > _int32_max:
    return _int32_max
  return needed
# Length of the real workspace RWORK for complex ?gesdd. Without singular
# vectors this is 7 * min(m, n); otherwise it is the larger of the two
# LAPACK-documented bounds, clamped so the size stays representable in
# LAPACK's 32-bit integer type.
cdef int cgesdd_rwork_size(int64_t m, int64_t n, int compute_uv) nogil:
  cdef int64_t short_side = min(m, n)
  cdef int64_t long_side = max(m, n)
  cdef int64_t needed
  if compute_uv == 0:
    return 7 * short_side
  needed = max(5 * short_side * short_side + 5 * short_side,
               2 * long_side * short_side + 2 * short_side * short_side +
               short_side)
  return min(_int32_max, needed)
cdef char gesdd_jobz(bool_t job_opt_compute_uv,
                     bool_t job_opt_full_matrices) nogil:
  # Maps the job options to the ?gesdd JOBZ code: 'N' computes singular
  # values only, 'S' the thin U/VT, and 'A' the full U/VT.
  cdef char jobz
  if not job_opt_compute_uv:
    jobz = 'N'
  elif job_opt_full_matrices:
    jobz = 'A'
  else:
    jobz = 'S'
  return jobz
cdef int sgesdd_work_size(int m, int n, bool_t job_opt_compute_uv,
bool_t job_opt_full_matrices):
cdef float work = 0
cdef int lwork = -1
cdef int info = 0
cdef int ldvt = min(m, n) if job_opt_full_matrices == 0 else n
cdef char jobz = gesdd_jobz(job_opt_compute_uv, job_opt_full_matrices)
sgesdd(&jobz, &m, &n, NULL, &m, NULL, NULL, &m, NULL, &ldvt, &work,