forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Compute gtsv2 buffer size ahead of time and pass in to kernel.
A user reported that with their Quadro M4000 GPU (Driver: 460.56) tridiagonal_solve was throwing an "unsupported operation" error. I improved the logging (also included in this patch) and tracked it down to: jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: third_party/py/jax/jaxlib/cusparse.cc:902: CUDA operation cudaMallocAsync(&buffer, bufferSize, stream) failed: operation not supported I had some challenges trying to figure out when async malloc was supported (it seems that for cards with compute <6 it fails) but have found an alternative approach where we compute the buffer size ahead of time and ask XLA to allocate. This is preferred for sure (although requires passing null pointers into cusparseSgtsv2_bufferSizeExt which seems to work today but I guess might change in future cuSPARSE releases).
- Loading branch information
1 parent
efd37e9
commit afa0d57
Showing
2 changed files
with
38 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters