Compute gtsv2 buffer size ahead of time and pass in to kernel.
A user reported that, with their Quadro M4000 GPU (driver 460.56), tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked the failure down to:

jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported

I had some trouble figuring out exactly when async malloc is supported (it appears to fail on cards with compute capability < 6), but found an alternative approach: compute the buffer size ahead of time and ask XLA to allocate it. This is preferable anyway, although it requires passing null pointers into cusparseSgtsv2_bufferSizeExt, which seems to work today but might change in future cuSPARSE releases.
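For reference, the pattern the patch adopts looks roughly like the following when written against the raw cuSPARSE API. This is a standalone sketch, not code from this change: the problem size, the CHECK macros, and the plain cudaMalloc standing in for XLA's allocator are all illustrative assumptions.

```cpp
// Sketch of the two-step pattern: query the gtsv2 workspace size up front,
// have the caller allocate it, then run the solve with that workspace.
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cuda_runtime.h>
#include <cusparse.h>

#define CHECK_CUDA(expr)                                                      \
  do {                                                                        \
    cudaError_t err_ = (expr);                                                \
    if (err_ != cudaSuccess) {                                                \
      std::fprintf(stderr, "CUDA error %d\n", err_); std::abort();            \
    }                                                                         \
  } while (0)

#define CHECK_CUSPARSE(expr)                                                  \
  do {                                                                        \
    cusparseStatus_t st_ = (expr);                                            \
    if (st_ != CUSPARSE_STATUS_SUCCESS) {                                     \
      std::fprintf(stderr, "cuSPARSE error %d\n", st_); std::abort();         \
    }                                                                         \
  } while (0)

int main() {
  const int m = 4, n = 1, ldb = m;
  cusparseHandle_t handle;
  CHECK_CUSPARSE(cusparseCreate(&handle));

  // Step 1: ask cuSPARSE for the workspace size before any device buffers
  // exist. This relies on the observed (not clearly documented) behaviour
  // that null data pointers are accepted here; the returned size appears to
  // depend only on the problem dimensions.
  size_t buffer_size = 0;
  CHECK_CUSPARSE(cusparseSgtsv2_bufferSizeExt(
      handle, m, n, /*dl=*/nullptr, /*d=*/nullptr, /*du=*/nullptr,
      /*B=*/nullptr, ldb, &buffer_size));
  std::printf("gtsv2 workspace: %zu bytes\n", buffer_size);

  // Step 2: the caller allocates the workspace. In the patch this caller is
  // XLA; here a plain cudaMalloc stands in for it.
  void* workspace = nullptr;
  CHECK_CUDA(cudaMalloc(&workspace, buffer_size));

  // Step 3: a small system A x = B with A tridiagonal (1, 2, 1) and B chosen
  // so that the solution is all ones. gtsv2 expects dl[0] == 0 and
  // du[m-1] == 0, and overwrites B with the solution.
  std::vector<float> h_dl = {0, 1, 1, 1}, h_d = {2, 2, 2, 2};
  std::vector<float> h_du = {1, 1, 1, 0}, h_B = {3, 4, 4, 3};
  float *dl, *d, *du, *B;
  CHECK_CUDA(cudaMalloc((void**)&dl, m * sizeof(float)));
  CHECK_CUDA(cudaMalloc((void**)&d, m * sizeof(float)));
  CHECK_CUDA(cudaMalloc((void**)&du, m * sizeof(float)));
  CHECK_CUDA(cudaMalloc((void**)&B, m * n * sizeof(float)));
  CHECK_CUDA(cudaMemcpy(dl, h_dl.data(), m * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d, h_d.data(), m * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(du, h_du.data(), m * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(B, h_B.data(), m * n * sizeof(float), cudaMemcpyHostToDevice));

  // Step 4: run the solve with the pre-allocated workspace; nothing is
  // allocated inside the hot path.
  CHECK_CUSPARSE(cusparseSgtsv2(handle, m, n, dl, d, du, B, ldb, workspace));

  CHECK_CUDA(cudaMemcpy(h_B.data(), B, m * n * sizeof(float), cudaMemcpyDeviceToHost));
  std::printf("solution: %g %g %g %g (expected all ones)\n",
              h_B[0], h_B[1], h_B[2], h_B[3]);

  cudaFree(dl); cudaFree(d); cudaFree(du); cudaFree(B); cudaFree(workspace);
  cusparseDestroy(handle);
  return 0;
}
```

In the change itself, step 2 is handled by declaring the workspace as a second, uint8 output of the XLA custom call (see the cusparse.py hunk below), so XLA's allocator owns the buffer and the kernel simply reads it from buffers[5].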
tomhennigan committed Jul 15, 2021
1 parent efd37e9 commit afa0d57
Showing 2 changed files with 38 additions and 30 deletions.
54 changes: 27 additions & 27 deletions jaxlib/cusparse.cc
@@ -841,9 +841,9 @@ py::bytes BuildGtsv2Descriptor(int m, int n, int ldb) {
return PackDescriptor(Gtsv2Descriptor{m, n, ldb});
}

-template <typename T, typename F1, typename F2>
-void gtsv2(F1 computeGtsv2BufSize, F2 computeGtsv2, cudaStream_t stream,
-           void** buffers, const char* opaque, std::size_t opaque_len) {
+template <typename T, typename F>
+void gtsv2(F computeGtsv2, cudaStream_t stream, void** buffers,
+           const char* opaque, std::size_t opaque_len) {
auto handle = SparseHandlePool::Borrow();

const Gtsv2Descriptor& descriptor =
@@ -857,6 +857,7 @@ void gtsv2(F1 computeGtsv2BufSize, F2 computeGtsv2, cudaStream_t stream,
const T* du = (const T*)(buffers[2]);
const T* B = (T*)(buffers[3]);
T* X = (T*)(buffers[4]);
+  void* buffer = buffers[5];

// The solution X is written in place to B. We need to therefore copy the
// contents of B into the output buffer X and pass that into the kernel as B.
@@ -870,39 +871,36 @@ void gtsv2(F1 computeGtsv2BufSize, F2 computeGtsv2, cudaStream_t stream,
cudaMemcpyAsync(X, B, B_bytes, cudaMemcpyDeviceToDevice, stream));
}

-  size_t bufferSize;
JAX_THROW_IF_ERROR(
-      computeGtsv2BufSize(handle.get(), m, n, dl, d, du, X, ldb, &bufferSize));
-
-  void* buffer;
-#if CUDA_VERSION >= 11020
-  JAX_THROW_IF_ERROR(cudaMallocAsync(&buffer, bufferSize, stream));
-#else
-  JAX_THROW_IF_ERROR(cudaMalloc(&buffer, bufferSize));
-#endif  // CUDA_VERSION >= 11020
-
-  auto computeStatus =
-      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer);
-
-#if CUDA_VERSION >= 11020
-  JAX_THROW_IF_ERROR(cudaFreeAsync(buffer, stream));
-#else
-  JAX_THROW_IF_ERROR(cudaFree(buffer));
-#endif  // CUDA_VERSION >= 11020
-
-  JAX_THROW_IF_ERROR(computeStatus);
+      computeGtsv2(handle.get(), m, n, dl, d, du, /*B=*/X, ldb, buffer));
}

void gtsv2_f32(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
-  gtsv2<float>(cusparseSgtsv2_bufferSizeExt, cusparseSgtsv2, stream, buffers,
-               opaque, opaque_len);
+  gtsv2<float>(cusparseSgtsv2, stream, buffers, opaque, opaque_len);
}


void gtsv2_f64(cudaStream_t stream, void** buffers, const char* opaque,
std::size_t opaque_len) {
-  gtsv2<double>(cusparseDgtsv2_bufferSizeExt, cusparseDgtsv2, stream, buffers,
-                opaque, opaque_len);
+  gtsv2<double>(cusparseDgtsv2, stream, buffers, opaque, opaque_len);
}

+template<typename F>
+size_t Gtsv2BufferSize(F f, int m, int n, int ldb) {
+  auto handle = SparseHandlePool::Borrow();
+  size_t size;
+  JAX_THROW_IF_ERROR(f(handle.get(), m, n, /*dl=*/nullptr, /*d=*/nullptr,
+                       /*du=*/nullptr, /*B=*/nullptr, ldb, &size));
+  return size;
+}
+
+size_t Gtsv2BufferSizeF32(int m, int n, int ldb) {
+  return Gtsv2BufferSize(cusparseSgtsv2_bufferSizeExt, m, n, ldb);
+}
+
+size_t Gtsv2BufferSizeF64(int m, int n, int ldb) {
+  return Gtsv2BufferSize(cusparseDgtsv2_bufferSizeExt, m, n, ldb);
+}

py::dict Registrations() {
@@ -936,6 +934,8 @@ PYBIND11_MODULE(cusparse_kernels, m) {
m.def("build_coo_matvec_descriptor", &BuildCooMatvecDescriptor);
m.def("build_coo_matmat_descriptor", &BuildCooMatmatDescriptor);
#endif
m.def("gtsv2_f32_buffer_size", &Gtsv2BufferSizeF32);
m.def("gtsv2_f64_buffer_size", &Gtsv2BufferSizeF64);
m.def("build_gtsv2_descriptor", &BuildGtsv2Descriptor);
}

14 changes: 11 additions & 3 deletions jaxlib/cusparse.py
@@ -291,12 +291,20 @@ def coo_matmat(c, data, row, col, B, *, shape, transpose=False, compute_dtype=No

def gtsv2(c, dl, d, du, B, *, m, n, ldb, t):
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
+  f32 = (t == np.float32)
dl_shape, d_shape, du_shape, B_shape = map(c.get_shape, (dl, d, du, B))
-  return xla_client.ops.CustomCallWithLayout(
+  if f32:
+    buffer_size = cusparse_kernels.gtsv2_f32_buffer_size(m, n, ldb)
+  else:
+    buffer_size = cusparse_kernels.gtsv2_f64_buffer_size(m, n, ldb)
+  out = xla_client.ops.CustomCallWithLayout(
c,
b"cusparse_gtsv2_" + (b"f32" if (t == np.float32) else b"f64"),
b"cusparse_gtsv2_" + (b"f32" if f32 else b"f64"),
operands=(dl, d, du, B),
operand_shapes_with_layout=(dl_shape, d_shape, du_shape, B_shape),
-      shape_with_layout=B_shape,
+      shape_with_layout=_Shape.tuple_shape(
+          (_Shape.array_shape(np.dtype(t), (ldb, n), (1, 0)),
+           _Shape.array_shape(np.dtype(np.uint8), (buffer_size,), (0,)))),
opaque=cusparse_kernels.build_gtsv2_descriptor(m, n, ldb),
has_side_effect=False)
+  return _ops.GetTupleElement(out, 0)
