Skip to content

Commit

Permalink
Avoid sharing handles across streams.
Browse files Browse the repository at this point in the history
When running across 8xV100 GPUs we observed the following error:

    libc++abi: terminating with uncaught exception of type std::runtime_error: third_party/py/jax/jaxlib/cusolver.cc:171: operation cusolverDnSpotrf(handle.get(), d.uplo, d.n, a, d.n, static_cast<float*>(workspace), d.lwork, info) failed: cuSolver execution failed

I cannot find documentation to this effect, but I believe that it is unsafe to share cuSolver handles across streams, since keeping the handle pool stream-local solves the issue.
  • Loading branch information
tomhennigan committed Jul 16, 2021
1 parent b744a84 commit afbd831
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 22 deletions.
8 changes: 4 additions & 4 deletions jaxlib/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ template <>
BlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cublasHandle_t handle;
if (pool->handles_.empty()) {
if (pool->handles_[stream].empty()) {
JAX_THROW_IF_ERROR(cublasCreate(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_THROW_IF_ERROR(cublasSetStream(handle, stream));
}
return Handle(pool, handle);
return Handle(pool, handle, stream);
}

// Set of types known to Cusolver.
Expand Down
8 changes: 4 additions & 4 deletions jaxlib/cusolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,16 @@ template <>
SolverHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusolverDnHandle_t handle;
if (pool->handles_.empty()) {
if (pool->handles_[stream].empty()) {
JAX_THROW_IF_ERROR(cusolverDnCreate(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_THROW_IF_ERROR(cusolverDnSetStream(handle, stream));
}
return Handle(pool, handle);
return Handle(pool, handle, stream);
}

// Set of types known to Cusolver.
Expand Down
8 changes: 4 additions & 4 deletions jaxlib/cusparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,16 @@ template <>
SparseHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
cusparseHandle_t handle;
if (pool->handles_.empty()) {
if (pool->handles_[stream].empty()) {
JAX_THROW_IF_ERROR(cusparseCreate(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_THROW_IF_ERROR(cusparseSetStream(handle, stream));
}
return Handle(pool, handle);
return Handle(pool, handle, stream);
}

cusparseIndexType_t DtypeToCuSparseIndexType(const py::dtype& np_type) {
Expand Down
17 changes: 10 additions & 7 deletions jaxlib/handle_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class HandlePool {
Handle() = default;
~Handle() {
if (pool_) {
pool_->Return(handle_);
pool_->Return(handle_, stream_);
}
}

Expand All @@ -62,10 +62,12 @@ class HandlePool {

private:
friend class HandlePool<HandleType, StreamType>;
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle)
: pool_(pool), handle_(handle) {}
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle,
StreamType stream)
: pool_(pool), handle_(handle), stream_(stream) {}
HandlePool<HandleType, StreamType>* pool_ = nullptr;
HandleType handle_ = nullptr;
StreamType stream_ = nullptr;
};

// Borrows a handle from the pool. If 'stream' is non-null, sets the stream
Expand All @@ -75,10 +77,10 @@ class HandlePool {
private:
static HandlePool<HandleType, StreamType>* Instance();

void Return(HandleType handle);
void Return(HandleType handle, StreamType stream);

absl::Mutex mu_;
std::vector<HandleType> handles_ ABSL_GUARDED_BY(mu_);
std::map<StreamType, std::vector<HandleType>> handles_ ABSL_GUARDED_BY(mu_);
};

template <typename HandleType, typename StreamType>
Expand All @@ -89,9 +91,10 @@ HandlePool<HandleType, StreamType>::Instance() {
}

template <typename HandleType, typename StreamType>
void HandlePool<HandleType, StreamType>::Return(HandleType handle) {
void HandlePool<HandleType, StreamType>::Return(HandleType handle,
StreamType stream) {
absl::MutexLock lock(&mu_);
handles_.push_back(handle);
handles_[stream].push_back(handle);
}

// template <typename HandleType, typename StreamType>
Expand Down
6 changes: 3 additions & 3 deletions jaxlib/rocblas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ template <>
rocBlasHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
rocblas_handle handle;
if (pool->handles_.empty()) {
if (pool->handles_[stream].empty()) {
ThrowIfErrorStatus(rocblas_create_handle(&handle));
} else {
handle = pool->handles_.back();
pool->handles_.pop_back();
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
ThrowIfErrorStatus(rocblas_set_stream(handle, stream));
Expand Down

0 comments on commit afbd831

Please sign in to comment.