Skip to content

Commit

Permalink
Add lax.linalg.tridiagonal_solve(..), lowering to cusparse_gtsv2<T>()…
Browse files Browse the repository at this point in the history
… on GPU.

Fixes jax-ml#6830.
  • Loading branch information
tomhennigan committed Jun 2, 2021
1 parent 46cc654 commit ffac40a
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 37 deletions.
61 changes: 61 additions & 0 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

import numpy as np

from jax._src.numpy import lax_numpy as jnp
Expand All @@ -35,6 +37,7 @@

from jax.lib import cuda_linalg
from jax.lib import cusolver
from jax.lib import cusparse
from jax.lib import rocsolver

from jax.lib import xla_client
Expand Down Expand Up @@ -1350,3 +1353,61 @@ def svd_batching_rule(batched_args, batch_dims, full_matrices, compute_uv):
if rocsolver is not None:
xla.backend_specific_translations['gpu'][svd_p] = partial(
_svd_cpu_gpu_translation_rule, rocsolver.gesvd)


def _tridiagonal_solve_abstract_eval(dl, d, du, b, *, m, n, ldb, t):
  """Abstract evaluation rule: the result has the abstract value of ``b``."""
  return b


# Primitive backing `tridiagonal_solve`; it produces a single output.
tridiagonal_solve_p = Primitive('tridiagonal_solve')
tridiagonal_solve_p.multiple_results = False
tridiagonal_solve_p.def_impl(
    functools.partial(xla.apply_primitive, tridiagonal_solve_p))
tridiagonal_solve_p.def_abstract_eval(_tridiagonal_solve_abstract_eval)
# TODO(tomhennigan): Consider AD rules using lax.custom_linear_solve?
if cusparse is not None:
  # Only a GPU translation rule is registered here, and only when the
  # cusparse bindings are available.
  xla.backend_specific_translations['gpu'][tridiagonal_solve_p] = cusparse.gtsv2


def tridiagonal_solve(dl, d, du, b):
  r"""Computes the solution of a tridiagonal linear system.

  This function computes the solution of a tridiagonal linear system:

  .. math::
    A . X = B

  Args:
    dl: The lower diagonal of A: ``dl[i] := A[i, i-1]`` for i in ``[0,m)``.
      Note that ``dl[0] = 0``.
    d: The middle diagonal of A: ``d[i] := A[i, i]`` for i in ``[0,m)``.
    du: The upper diagonal of A: ``du[i] := A[i, i+1]`` for i in ``[0,m)``.
      Note that ``du[m - 1] = 0``.
    b: Right hand side matrix.

  Returns:
    Solution ``X`` of tridiagonal system.

  Raises:
    ValueError: If ``dl``, ``d`` and ``du`` are not vectors of one common
      length ``m >= 3``, if ``b`` is not a matrix whose leading dimension is
      at least ``m``, if the four dtypes differ, or if the dtype is neither
      float32 nor float64.
  """
  if dl.ndim != 1 or d.ndim != 1 or du.ndim != 1:
    raise ValueError('dl, d and du must be vectors')

  if dl.shape != d.shape or d.shape != du.shape:
    raise ValueError(
        f'dl={dl.shape}, d={d.shape} and du={du.shape} must all be `[m]`')

  if b.ndim != 2:
    raise ValueError(f'b={b.shape} must be a matrix')

  m, = dl.shape
  if m < 3:
    raise ValueError(f'm ({m}) must be >= 3')

  # The leading dimension of b is forwarded to the kernel as `ldb`.
  ldb, n = b.shape
  if ldb < max(1, m):
    raise ValueError(f'Leading dimension of b={ldb} must be ≥ max(1, {m})')

  if dl.dtype != d.dtype or d.dtype != du.dtype or du.dtype != b.dtype:
    raise ValueError(f'dl={dl.dtype}, d={d.dtype}, du={du.dtype} and '
                     f'b={b.dtype} must be the same dtype')

  t = dl.dtype
  if t not in (np.float32, np.float64):
    raise ValueError(f'Only f32/f64 are supported, got {t}')

  return tridiagonal_solve_p.bind(dl, d, du, b, m=m, n=n, ldb=ldb, t=t)
2 changes: 1 addition & 1 deletion jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params)
ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
*xla_args, **params)
else:
raise NotImplementedError(f"XLA translation rule for {prim} not found")
raise NotImplementedError(f"XLA translation rule for {prim!r} on platform {platform!r} not found")
assert isinstance(ans, xe.XlaOp)
c.clear_op_metadata()
try:
Expand Down
2 changes: 2 additions & 0 deletions jax/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@
svd_p,
triangular_solve,
triangular_solve_p,
tridiagonal_solve,
tridiagonal_solve_p,
)
131 changes: 96 additions & 35 deletions jaxlib/cusparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,40 +36,26 @@ limitations under the License.
#include "include/pybind11/stl.h"

// Some functionality defined here is only available in CUSPARSE 11.3 or newer.
#define JAX_ENABLE_CUSPARSE (CUSPARSE_VERSION >= 11300)
#define JAX_CUSPARSE_11030 (CUSPARSE_VERSION >= 11300)

namespace jax {
namespace {

namespace py = pybind11;

void ThrowIfErrorStatus(cusparseStatus_t status) {
switch (status) {
case CUSPARSE_STATUS_SUCCESS:
return;
case CUSPARSE_STATUS_NOT_INITIALIZED:
throw std::runtime_error("cuSparse has not been initialized");
case CUSPARSE_STATUS_ALLOC_FAILED:
throw std::runtime_error("cuSparse allocation failure");
case CUSPARSE_STATUS_INVALID_VALUE:
throw std::runtime_error("cuSparse invalid value error");
case CUSPARSE_STATUS_ARCH_MISMATCH:
throw std::runtime_error("cuSparse architecture mismatch");
case CUSPARSE_STATUS_MAPPING_ERROR:
throw std::runtime_error("cuSparse mapping error");
case CUSPARSE_STATUS_EXECUTION_FAILED:
throw std::runtime_error("cuSparse execution failed");
case CUSPARSE_STATUS_INTERNAL_ERROR:
throw std::runtime_error("cuSparse internal error");
case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
throw std::runtime_error("cuSparse matrix type not supported error");
case CUSPARSE_STATUS_ZERO_PIVOT:
throw std::runtime_error("cuSparse zero pivot error");
default:
throw std::runtime_error("Unknown cuSparse error");
if (status != CUSPARSE_STATUS_SUCCESS) {
throw std::runtime_error(cusparseGetErrorString(status));
}
}

// Throws std::runtime_error carrying the CUDA runtime's description of
// `error`; does nothing when `error` is cudaSuccess.
void ThrowIfErrorStatus(cudaError_t error) {
  if (error == cudaSuccess) {
    return;
  }
  throw std::runtime_error(cudaGetErrorString(error));
}


union CudaConst {
int8_t i8[2];
int16_t i16[2];
Expand All @@ -93,7 +79,7 @@ CudaConst CudaOne(cudaDataType type) {
CudaConst c;
std::memset(&c, 0, sizeof(c));
switch (type) {
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
// TODO(jakevdp): 4I/4U here might break on big endian platforms.
case CUDA_R_4I:
case CUDA_C_4I:
Expand All @@ -102,15 +88,15 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_8I:
c.i8[0] = 1;
break;
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
case CUDA_R_4U:
case CUDA_C_4U:
#endif
case CUDA_R_8U:
case CUDA_C_8U:
c.u8[0] = 1;
break;
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
case CUDA_R_16I:
case CUDA_C_16I:
c.i16[0] = 1;
Expand All @@ -128,7 +114,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_32U:
c.u32[0] = 1;
break;
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
case CUDA_R_64I:
case CUDA_C_64I:
c.i64[0] = 1;
Expand All @@ -143,7 +129,7 @@ CudaConst CudaOne(cudaDataType type) {
case CUDA_C_16F:
c.u16[0] = 0b11110000000000; // 1.0 in little-endian float16
break;
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
case CUDA_R_16BF:
case CUDA_C_16BF:
c.u16[0] = 0b11111110000000; // 1.0 in little-endian bfloat16
Expand Down Expand Up @@ -204,7 +190,7 @@ cudaDataType DtypeToCudaDataType(const py::dtype& np_type) {
{{'c', 16}, CUDA_C_64F}, {{'i', 1}, CUDA_R_8I},
{{'u', 1}, CUDA_R_8U}, {{'i', 4}, CUDA_R_32I},
{{'u', 4}, CUDA_R_32U},
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
{{'V', 2}, CUDA_R_16BF},
#endif
});
Expand Down Expand Up @@ -255,7 +241,7 @@ DenseVecDescriptor BuildDenseVecDescriptor(const py::dtype& data_dtype,
return DenseVecDescriptor{value_type, size};
}

#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
// CsrToDense: Convert CSR matrix to dense matrix

// Returns the descriptor for a Sparse matrix.
Expand Down Expand Up @@ -858,12 +844,83 @@ void CooMatmat(cudaStream_t stream, void** buffers, const char* opaque,
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_b));
ThrowIfErrorStatus(cusparseDestroyDnMat(mat_c));
}
#endif // if JAX_CUSPARSE_11030

#endif
// Problem dimensions for a gtsv2 tridiagonal solve, passed to the kernel
// through the custom call's opaque descriptor.
struct Gtsv2Descriptor {
  int m;    // Length of each of the three diagonals.
  int n;    // Number of columns of the right-hand side B.
  int ldb;  // Leading dimension of B.
};

// Packs (m, n, ldb) into the byte string that is handed to the gtsv2 kernels
// as their opaque descriptor.
py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
  const Gtsv2Descriptor descriptor{m, n, ldb};
  return PackDescriptor(descriptor);
}

template <typename T, typename F1, typename F2>
void gtsv2(F1 computeGtsv2BufSize, F2 computeGtsv2, cudaStream_t stream,
void** buffers, const char* opaque, std::size_t opaque_len) {
auto handle = SparseHandlePool::Borrow();

const Gtsv2Descriptor& descriptor =
*UnpackDescriptor<Gtsv2Descriptor>(opaque, opaque_len);
int m = descriptor.m;
int n = descriptor.n;
int ldb = descriptor.ldb;

const T* dl = (const T*)(buffers[0]);
const T* d = (const T*)(buffers[1]);
const T* du = (const T*)(buffers[2]);
const T* B = (T*)(buffers[3]);
T* X = (T*)(buffers[4]);

// The solution X is written in place to B. We need to therefore copy the
// contents of B into the output buffer X and pass that into the kernel as B.
// Once copy insertion is supported for custom call aliasing, we could alias B
// with X and avoid the copy, the code below is written defensively assuming B
// and X might alias, but today we know they will not.
// TODO(b/182906199): Update the comment here once copy insertion is WAI.
if (X != B) {
size_t B_bytes = ldb * n * sizeof(T);
ThrowIfErrorStatus(
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream));
}

size_t bufferSize;
ThrowIfErrorStatus(
computeGtsv2BufSize(handle.get(), m, n, dl, d, du, X, ldb, &bufferSize));

void* buffer;
#if CUDA_VERSION >= 11020
ThrowIfErrorStatus(cudaMallocAsync(&buffer, bufferSize, stream));
#else
ThrowIfErrorStatus(cudaMalloc(&buffer, bufferSize));
#endif // CUDA_VERSION >= 11020

auto computeStatus =
computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer);

#if CUDA_VERSION >= 11020
ThrowIfErrorStatus(cudaFreeAsync(buffer, stream));
#else
ThrowIfErrorStatus(cudaFree(buffer));
#endif // CUDA_VERSION >= 11020

ThrowIfErrorStatus(computeStatus);
}

// Custom-call entry point for the single-precision (float) gtsv2 solve.
void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len) {
  auto buffer_size_fn = &cusparseSgtsv2_bufferSizeExt;
  auto solve_fn = &cusparseSgtsv2;
  gtsv2<float>(buffer_size_fn, solve_fn, stream, buffers, opaque, opaque_len);
}

// Custom-call entry point for the double-precision (double) gtsv2 solve.
void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
               std::size_t opaque_len) {
  auto buffer_size_fn = &cusparseDgtsv2_bufferSizeExt;
  auto solve_fn = &cusparseDgtsv2;
  gtsv2<double>(buffer_size_fn, solve_fn, stream, buffers, opaque, opaque_len);
}

py::dict Registrations() {
py::dict dict;
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
dict["cusparse_csr_todense"] = EncapsulateFunction(CsrToDense);
dict["cusparse_csr_fromdense"] = EncapsulateFunction(CsrFromDense);
dict["cusparse_csr_matvec"] = EncapsulateFunction(CsrMatvec);
Expand All @@ -873,13 +930,16 @@ py::dict Registrations() {
dict["cusparse_coo_matvec"] = EncapsulateFunction(CooMatvec);
dict["cusparse_coo_matmat"] = EncapsulateFunction(CooMatmat);
#endif
dict["cusparse_gtsv2_f32"] = EncapsulateFunction(gtsv2_f32);
dict["cusparse_gtsv2_f64"] = EncapsulateFunction(gtsv2_f64);
// TODO(tomhennigan): Add support for gtsv2 complex 32/64.
return dict;
}

PYBIND11_MODULE(cusparse_kernels, m) {
m.attr("cusparse_supported") = py::bool_(JAX_ENABLE_CUSPARSE);
m.attr("cusparse_supported") = py::bool_(JAX_CUSPARSE_11030);
m.def("registrations", &Registrations);
#if JAX_ENABLE_CUSPARSE
#if JAX_CUSPARSE_11030
m.def("build_csr_todense_descriptor", &BuildCsrToDenseDescriptor);
m.def("build_csr_fromdense_descriptor", &BuildCsrFromDenseDescriptor);
m.def("build_csr_matvec_descriptor", &BuildCsrMatvecDescriptor);
Expand All @@ -889,6 +949,7 @@ PYBIND11_MODULE(cusparse_kernels, m) {
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}

} // namespace
Expand Down
13 changes: 13 additions & 0 deletions jaxlib/cusparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,16 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No
opaque=opaque,
)
return _ops.GetTupleElement(out, 0)


def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
  """Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
  operands = (dl, d, du, B)
  operand_shapes = tuple(c.get_shape(operand) for operand in operands)
  # The custom call's result uses the same shape (with layout) as B.
  result_shape = operand_shapes[-1]
  # Any dtype other than float32 selects the f64 kernel.
  dtype_suffix = b"f32" if t == np.float32 else b"f64"
  descriptor = cusparse_kernels.build_gtsv2_descriptor(m, n, ldb)
  return xla_client.ops.CustomCallWithLayout(
      c,
      b"cusparse_gtsv2_" + dtype_suffix,
      operands=operands,
      operand_shapes_with_layout=operand_shapes,
      shape_with_layout=result_shape,
      opaque=descriptor,
      has_side_effect=False)
17 changes: 16 additions & 1 deletion tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,7 +1404,7 @@ def expm(x):
jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol,
rtol=tol)

class EighTridiagonalTest(jtu.JaxTestCase):
class LaxLinalgTest(jtu.JaxTestCase):

def run_test(self, alpha, beta):
n = alpha.shape[-1]
Expand Down Expand Up @@ -1467,6 +1467,21 @@ def testSelect(self, dtype):
self.assertAllClose(
eigvals_all[first:(last + 1)], eigvals_index, atol=atol)

@parameterized.parameters(np.float32, np.float64)
def test_tridiagonal_solve(self, dtype):
  """Solves a 3x3 tridiagonal system and checks that ``A @ X`` equals ``B``."""
  if jtu.device_under_test() != "gpu":
    self.skipTest("Only supported on GPU")

  m = 3
  dl = np.array([0.0, 1.0, 2.0], dtype=dtype)
  d = np.ones(m, dtype=dtype)
  du = np.array([1.0, 2.0, 0.0], dtype=dtype)
  B = np.ones([m, 1], dtype=dtype)
  X = lax.linalg.tridiagonal_solve(dl, d, du, B)

  # Dense reference matrix: unit main diagonal plus the sub- and
  # super-diagonals, dropping the padding entries dl[0] and du[-1].
  A = np.diag(d) + np.diag(dl[1:], -1) + np.diag(du[:-1], 1)
  np.testing.assert_allclose(A @ X, B)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit ffac40a

Please sign in to comment.