forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpu_solver.py
556 lines (494 loc) · 18.7 KB
/
gpu_solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from functools import partial
import importlib
import math
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from .gpu_common_utils import GpuLibNotLinkedError
from jaxlib import xla_client
from .hlo_helpers import (
DimensionSize, ShapeTypePair, mk_result_types_and_shapes,
custom_call, ensure_hlo_s32, hlo_s32, dense_int_array)
try:
from .cuda import _blas as _cublas # pytype: disable=import-error
except ImportError:
for cuda_module_name in ["jax_cuda12_plugin"]:
try:
_cublas = importlib.import_module(f"{cuda_module_name}._blas")
except ImportError:
_cublas = None
else:
break
if _cublas:
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
try:
_cusolver = importlib.import_module(
f"{cuda_module_name}._solver", package="jaxlib"
)
except ImportError:
_cusolver = None
else:
break
if _cusolver:
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
try:
from .rocm import _blas as _hipblas # pytype: disable=import-error
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
_hipblas = None
try:
from .rocm import _solver as _hipsolver # pytype: disable=import-error
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
_hipsolver = None
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
"""LU decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
if not gpu_blas:
raise GpuLibNotLinkedError()
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
np.dtype(dtype), batch, m)
workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8))
kernel = f"{platform}blas_getrf_batched"
else:
lwork, opaque = gpu_solver.build_getrf_descriptor(
np.dtype(dtype), batch, m, n)
workspace = ir.RankedTensorType.get([lwork], a_type.element_type)
kernel = f"{platform}solver_getrf"
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
out = custom_call(
kernel,
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
ir.RankedTensorType.get(batch_dims, i32_type),
workspace,
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0}).results
return out[:3]
cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver)
rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver)
def _geqrf_hlo(platform, gpu_solver, dtype, a):
"""QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
lwork, opaque = gpu_solver.build_geqrf_descriptor(
np.dtype(dtype), batch, m, n)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
out = custom_call(
f"{platform}solver_geqrf",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0}).results
return out[:3]
cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver)
def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
"""Batched QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
if not gpu_blas:
raise GpuLibNotLinkedError()
lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
out = custom_call(
f"{platform}blas_geqrf_batched",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), a_type.element_type),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)),
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
[0],
[0],
],
operand_output_aliases={0: 0}
).results
return out[:2]
cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas)
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
data_type = ir.RankedTensorType(data.type)
n = b_type.shape[0]
nnz = data_type.shape[0]
opaque = gpu_solver.build_csrlsvqr_descriptor(
np.dtype(dtype), n, nnz, reorder, tol
)
out = custom_call(
f"{platform}solver_csrlsvqr", # call_target_name
result_types=[b.type],
operands=[data, indptr, indices, b],
backend_config=opaque, # backend_config
operand_layouts=[(0,), (0,), (0,), (0,)], # operand_layouts
result_layouts=[(0,)] # result_layouts
).results
return out
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
tau_dims = ir.RankedTensorType(tau.type).shape
assert tau_dims[:-1] == dims[:-2]
k = tau_dims[-1]
lwork, opaque = gpu_solver.build_orgqr_descriptor(
np.dtype(dtype), batch, m, n, k)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
out = custom_call(
f"{platform}solver_orgqr",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a, tau],
backend_config=opaque,
operand_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
],
result_layouts=[
layout,
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0}).results
return out[:2]
cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver)
def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
a_shape_vals: tuple[DimensionSize, ...], lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
a_type = ir.RankedTensorType(a.type)
assert len(a_shape_vals) >= 2
m, n = a_shape_vals[-2:]
assert type(m) is int and type(n) is int and m == n, a_shape_vals
batch_dims_vals = a_shape_vals[:-2]
num_bd = len(batch_dims_vals)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
dynamic_batch_dims = any(type(d) != int for d in batch_dims_vals)
if dynamic_batch_dims:
batch_int = -1 # Signals to the kernel that the batch is an operand.
else:
batch_int = math.prod(batch_dims_vals)
if have_jacobi_solver and n <= 32 and not dynamic_batch_dims:
# We cannot use syevj for dynamic shapes because the workspace size
# depends on the batch size.
kernel = f"{platform}solver_syevj"
lwork, opaque = gpu_solver.build_syevj_descriptor(
np.dtype(dtype), lower, batch_int, n)
else:
kernel = f"{platform}solver_syevd"
lwork, opaque = gpu_solver.build_syevd_descriptor(
np.dtype(dtype), lower, batch_int, n)
assert lwork > 0
if ir.ComplexType.isinstance(a_type.element_type):
eigvals_type = ir.ComplexType(a_type.element_type).element_type
else:
eigvals_type = a_type.element_type
i32_type = ir.IntegerType.get_signless(32)
operands = [a]
operand_layouts = [layout]
if dynamic_batch_dims:
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
operands.append(batch_size_val)
operand_layouts.append(())
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (n,), eigvals_type),
(batch_dims_vals, i32_type),
([lwork], a_type.element_type)]
result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs)
out = custom_call(
kernel,
result_types=result_types,
operands=operands,
backend_config=opaque,
operand_layouts=operand_layouts,
result_layouts=[
layout,
tuple(range(num_bd, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0},
result_shapes=result_shapes).results
return out[:3]
cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True)
def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = math.prod(batch_dims)
if ir.ComplexType.isinstance(a_type.element_type):
singular_vals_type = ir.ComplexType(a_type.element_type).element_type
else:
singular_vals_type = a_type.element_type
scalar_layout = tuple(range(num_bd - 1, -1, -1))
vector_layout = (num_bd,) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
# NVIDIA's batched Jacobi solver supports a maximum matrix size of 32x32, but
# the unbatched solver has no such limit. The unbatched solver appears to
# outperform gesvd for small-moderate matrices, e.g., see:
# https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9226-fast-singular-value-decomposition-on-gpus-v2.pdf
# slide 5.
if have_jacobi_solver and m <= 1024 and n <= 1024:
# The gesvdjbatched kernel doesn't support "econ" mode. We will use that
# kernel only if b > 1 and m <= 32 and n <= 32.
econ = not full_matrices and (b <= 1 or m > 32 or n > 32)
lwork, opaque = gpu_solver.build_gesvdj_descriptor(
np.dtype(dtype), b, m, n, compute_uv, 1 if econ else 0)
k = min(m, n)
matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
_, s, u, v, info, _ = custom_call(
f"{platform}solver_gesvdj",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (m, k if econ else m),
a_type.element_type),
ir.RankedTensorType.get(batch_dims + (n, k if econ else n),
a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
vt = hlo.transpose(
v,
dense_int_array(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
if np.issubdtype(dtype, np.complexfloating):
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
if not full_matrices and not econ:
u = hlo.slice(
u,
dense_int_array(np.zeros([len(dims)], np.int64)),
dense_int_array(np.array(batch_dims + (m, min(m, n)))),
dense_int_array(np.ones([len(dims)], np.int64)))
vt = hlo.slice(
vt,
dense_int_array(np.zeros([len(dims)], np.int64)),
dense_int_array(np.array(batch_dims + (min(m, n), n))),
dense_int_array(np.ones([len(dims)], np.int64)))
elif m < n:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
k = n if full_matrices else m
matrix_layout = (num_bd + 1, num_bd) + tuple(range(num_bd - 1, -1, -1))
_, s, vt, u, info, _ = custom_call(
f"{platform}solver_gesvd",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (k, n), a_type.element_type),
ir.RankedTensorType.get(batch_dims + (m, m), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
else:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, m, n, compute_uv, full_matrices)
k = m if full_matrices else n
matrix_layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
_, s, u, vt, info, _ = custom_call(
f"{platform}solver_gesvd",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (min(m, n),), singular_vals_type),
ir.RankedTensorType.get(batch_dims + (m, k), a_type.element_type),
ir.RankedTensorType.get(batch_dims + (n, n), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[matrix_layout],
result_layouts=[
matrix_layout,
vector_layout,
matrix_layout,
matrix_layout,
scalar_layout,
[0],
],
operand_output_aliases={0: 0}).results
return s, u, vt, info
cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False)
def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
"""sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n, (m, n)
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
b = 1
for d in batch_dims:
b *= d
lwork, opaque = gpu_solver.build_sytrd_descriptor(dtype, lower, b, n)
if np.issubdtype(dtype, np.floating):
diag_type = a_type.element_type
elif dtype == np.complex64:
diag_type = ir.F32Type.get()
elif dtype == np.complex128:
diag_type = ir.F64Type.get()
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
i32_type = ir.IntegerType.get_signless(32)
a, d, e, taus, info, _ = custom_call(
f"{platform}solver_sytrd",
result_types=[
a.type,
ir.RankedTensorType.get(batch_dims + (n,), diag_type),
ir.RankedTensorType.get(batch_dims + (n - 1,), diag_type),
ir.RankedTensorType.get(batch_dims + (n - 1,), a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
operands=[a],
backend_config=opaque,
operand_layouts=[layout],
result_layouts=[
layout,
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
(num_bd,) + tuple(range(num_bd - 1, -1, -1)),
tuple(range(num_bd - 1, -1, -1)),
[0],
],
operand_output_aliases={0: 0},
).results
# Workaround for NVIDIA partners bug #3865118: sytrd returns an incorrect "1"
# in the first element of the superdiagonal in the `a` matrix in the
# lower=False case. The correct result is returned in the `e` vector so we can
# simply copy it back to where it needs to be:
intattr = lambda xs: ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64))
intarrattr = lambda xs: dense_int_array(np.asarray(xs, np.int64))
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = hlo.slice(
e, intarrattr(start), intarrattr(end), intarrattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = hlo.broadcast_in_dim(s_type, s, intarrattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = hlo.convert(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(hlo.constant(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
a = hlo.dynamic_update_slice(a, s, offsets)
return a, d, e, taus, info
cuda_sytrd = partial(_sytrd_hlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_hlo, "hip", _hipsolver)