Fix solver interfaces to use executor in cache (#809)
* Fix solver interfaces to use executor in cache
* Add recursive mutex around cache lookup
cliffburdick authored Dec 4, 2024
1 parent 0e5c634 commit 76623d6
Showing 11 changed files with 93 additions and 61 deletions.
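
At a high level, this commit threads the cudaExecutor through each CUDA solver's cached parameter struct and folds its stream into the cache key, so a plan built for one stream is not reused on another, while a recursive mutex added in cache.h serializes the lookup itself. A minimal sketch of the key pattern, using illustrative names (ExecLike, SolverParams) rather than real MatX types:

#include <cuda_runtime.h>
#include <cstdint>
#include <functional>

struct ExecLike {                        // stand-in for matx::cudaExecutor
  cudaStream_t stream{};
  cudaStream_t getStream() const { return stream; }
};

struct SolverParams {                    // hypothetical solver parameter struct
  uint64_t n;
  size_t batch_size;
  ExecLike exec;                         // the executor is now part of the key
};

struct SolverParamsHash {
  std::size_t operator()(const SolverParams &k) const noexcept {
    // Same scheme as the updated DnCholCUDAParamsKeyHash below: a sum of
    // hashes, with the stream handle folded in.
    return std::hash<uint64_t>()(k.n) +
           std::hash<uint64_t>()(k.batch_size) +
           std::hash<uint64_t>()((uint64_t)k.exec.getStream());
  }
};

struct SolverParamsEq {
  bool operator()(const SolverParams &l, const SolverParams &r) const noexcept {
    return l.n == r.n && l.batch_size == r.batch_size &&
           l.exec.getStream() == r.exec.getStream();
  }
};
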
2 changes: 1 addition & 1 deletion examples/fft_conv.cu
@@ -73,7 +73,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
{
MATX_ENTER_HANDLER();
using complex = cuda::std::complex<float>;
cudaExecutor exec{};

index_t signal_size = 1ULL << 16;
index_t filter_size = 16;
@@ -87,6 +86,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
cudaExecutor exec{stream};

// Create time domain buffers
auto sig_time = make_tensor<complex>({batches, signal_size});
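
The fft_conv.cu change defers constructing the executor until after the CUDA stream exists, so the executor (and anything cached through it) is bound to that stream instead of to a default-constructed executor. A sketch of the resulting pattern, assuming `using namespace matx;` as in the example (stream creation shown here for completeness; the ordering of the elided lines in the actual file may differ):

cudaStream_t stream;
cudaStreamCreate(&stream);

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

// Construct the executor from the explicit stream so that cached plans
// created through it are keyed to this stream.
cudaExecutor exec{stream};
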
7 changes: 7 additions & 0 deletions include/matx/core/cache.h
@@ -36,6 +36,7 @@
#include <functional>
#include <optional>
#include <any>
#include <shared_mutex>
#include <unordered_map>
#include <cuda/atomic>

@@ -50,6 +51,7 @@ using CacheId = uint64_t;
__attribute__ ((visibility ("default")))
#endif
inline cuda::std::atomic<CacheId> CacheIdCounter{0};
inline std::recursive_mutex cache_mtx; ///< Mutex protecting updates from map

template<typename CacheType>
__attribute__ ((visibility ("default")))
@@ -83,6 +85,8 @@ class matxCache_t {
*/
template <typename CacheType>
void Clear(const CacheId &id) {
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(cache_mtx);

auto el = cache.find(id);
MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");

@@ -91,6 +95,9 @@

template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun>
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun) {
// This mutex should eventually be finer-grained so each transform doesn't get blocked by others
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(cache_mtx);

// Create named cache if it doesn't exist
auto el = cache.find(id);
if (el == cache.end()) {
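
For context, a simplified sketch of the locking now done inside matxCache_t (the real cache keeps one typed map per transform behind std::any, keyed by CacheId; that indirection is omitted here). The mutex is recursive, presumably so that a make/exec callback that itself performs a cached lookup of another transform can re-enter the lock from the same thread without deadlocking:

#include <memory>
#include <mutex>
#include <unordered_map>

inline std::recursive_mutex cache_mtx;    // one coarse lock for all transform caches

template <typename Map, typename Params, typename MakeFun, typename ExecFun>
void LookupAndExecSketch(Map &typed_cache, const Params &params,
                         const MakeFun &mfun, const ExecFun &efun) {
  // Coarse-grained for now; the in-tree comment notes it should eventually be
  // finer-grained so one transform does not block the others.
  std::lock_guard<std::recursive_mutex> lock(cache_mtx);

  auto el = typed_cache.find(params);
  if (el == typed_cache.end()) {
    auto plan = mfun();                   // may re-enter the cache, hence recursive
    typed_cache[params] = plan;
    efun(plan);
  }
  else {
    efun(el->second);
  }
}
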
28 changes: 19 additions & 9 deletions include/matx/transforms/chol/chol_cuda.h
@@ -58,6 +58,7 @@ struct DnCholCUDAParams_t {
size_t batch_size;
cublasFillMode_t uplo;
MatXDataType_t dtype;
cudaExecutor exec;
};

template <typename OutputTensor, typename ATensor>
@@ -89,8 +90,9 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
* Use upper or lower triangle for computation
*
*/
matxDnCholCUDAPlan_t(const ATensor &a,
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
matxDnCholCUDAPlan_t( const ATensor &a,
const cudaExecutor &exec,
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

@@ -101,9 +103,10 @@
MATX_STATIC_ASSERT_STR(!is_half_v<T1>, matxInvalidType, "Cholesky solver does not support half precision");
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");

params = GetCholParams(a, uplo);
params = GetCholParams(a, uplo, exec);

this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size, false);
this->AllocateWorkspace(params.batch_size, false, exec);
}

void GetWorkspaceSize() override
@@ -117,13 +120,15 @@
}

static DnCholCUDAParams_t GetCholParams(const ATensor &a,
cublasFillMode_t uplo)
cublasFillMode_t uplo,
const cudaExecutor &exec)
{
DnCholCUDAParams_t params;
params.batch_size = GetNumBatches(a);
params.n = a.Size(RANK - 1);
params.A = a.Data();
params.uplo = uplo;
params.exec = exec;
params.dtype = TypeToInt<T1>();

return params;
@@ -201,7 +206,9 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
struct DnCholCUDAParamsKeyHash {
std::size_t operator()(const DnCholCUDAParams_t &k) const noexcept
{
return (std::hash<uint64_t>()(k.n)) + (std::hash<uint64_t>()(k.batch_size));
return (std::hash<uint64_t>()(k.n)) +
(std::hash<uint64_t>()(k.batch_size)) +
(std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
}
};

@@ -213,7 +220,10 @@ struct DnCholCUDAParamsKeyEq {
bool operator()(const DnCholCUDAParams_t &l, const DnCholCUDAParams_t &t) const
noexcept
{
return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype;
return l.n == t.n &&
l.batch_size == t.batch_size &&
l.dtype == t.dtype &&
l.exec.getStream() == t.exec.getStream();
}
};

@@ -290,14 +300,14 @@ void chol_impl(OutputTensor &&out, const ATensor &a,
cublasFillMode_t uplo_cusolver = (uplo == SolverFillMode::UPPER)? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;

// Get parameters required by these tensors
auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver);
auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver, exec);

using cache_val_type = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>;
detail::GetCache().LookupAndExec<detail::chol_cuda_cache_t>(
detail::GetCacheIdFromType<detail::chol_cuda_cache_t>(),
params,
[&]() {
return std::make_shared<cache_val_type>(tmp_out, uplo_cusolver);
return std::make_shared<cache_val_type>(tmp_out, exec, uplo_cusolver);
},
[&](std::shared_ptr<cache_val_type> ctype) {
ctype->Exec(tmp_out, tmp_out, exec, uplo_cusolver);
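
A user-visible consequence of the stream-aware key, sketched below (not from the MatX test suite; assumes `using namespace matx;`, the operator form of chol, and that A already holds a symmetric positive-definite matrix): the same factorization issued through executors on different streams now resolves to two distinct cached plans instead of reusing a plan bound to the first stream.

cudaStream_t s1, s2;
cudaStreamCreate(&s1);
cudaStreamCreate(&s2);
cudaExecutor exec1{s1}, exec2{s2};

auto A   = make_tensor<float>({8, 8});    // assumed SPD; fill omitted in this sketch
auto out = make_tensor<float>({8, 8});

(out = chol(A)).run(exec1);   // creates and caches a plan keyed to s1
(out = chol(A)).run(exec2);   // different stream, so a separate cache entry
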
23 changes: 14 additions & 9 deletions include/matx/transforms/eig/eig_cuda.h
@@ -61,6 +61,7 @@ struct DnEigCUDAParams_t {
void *W;
size_t batch_size;
MatXDataType_t dtype;
cudaExecutor exec;
};

template <typename OutputTensor, typename WTensor, typename ATensor>
@@ -98,6 +99,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
*/
matxDnEigCUDAPlan_t(WTensor &w,
const ATensor &a,
const cudaExecutor &exec,
cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR,
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
{
@@ -113,12 +115,12 @@
MATX_STATIC_ASSERT_STR(!is_complex_v<T2>, matxInvalidType, "W type must be real");
MATX_STATIC_ASSERT_STR((std::is_same_v<typename inner_op_type_t<T1>::type, T2>), matxInvalidType, "Out and W inner types must match");

params = GetEigParams(w, a, jobz, uplo);
params = GetEigParams(w, a, jobz, uplo, exec);
this->GetWorkspaceSize();
#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
this->AllocateWorkspace(params.batch_size, true);
#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >= 2)
this->AllocateWorkspace(params.batch_size, true, exec);
#else
this->AllocateWorkspace(params.batch_size, false);
this->AllocateWorkspace(params.batch_size, false, exec);
#endif
}

@@ -147,7 +149,8 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
static DnEigCUDAParams_t GetEigParams(WTensor &w,
const ATensor &a,
cusolverEigMode_t jobz,
cublasFillMode_t uplo)
cublasFillMode_t uplo,
const cudaExecutor &exec)
{
DnEigCUDAParams_t params;
params.batch_size = GetNumBatches(a);
@@ -156,6 +159,8 @@
params.W = w.Data();
params.jobz = jobz;
params.uplo = uplo;
params.exec = exec;

params.dtype = TypeToInt<T1>();

return params;
@@ -258,7 +263,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
struct DnEigCUDAParamsKeyHash {
std::size_t operator()(const DnEigCUDAParams_t &k) const noexcept
{
return (std::hash<uint64_t>()(k.n)) + (std::hash<uint64_t>()(k.batch_size));
return (std::hash<uint64_t>()(k.n)) + (std::hash<uint64_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
}
};

@@ -269,7 +274,7 @@
struct DnEigCUDAParamsKeyEq {
bool operator()(const DnEigCUDAParams_t &l, const DnEigCUDAParams_t &t) const noexcept
{
return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype;
return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
}
};

@@ -339,15 +344,15 @@ void eig_impl(OutputTensor &&out, WTensor &&w,

// Get parameters required by these tensors
auto params = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>::
GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver);
GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver, exec);

// Get cache or new eigen plan if it doesn't exist
using cache_val_type = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(a_new)>;
detail::GetCache().LookupAndExec<detail::eig_cuda_cache_t>(
detail::GetCacheIdFromType<detail::eig_cuda_cache_t>(),
params,
[&]() {
return std::make_shared<cache_val_type>(w_new, tv, jobz_cusolver, uplo_cusolver);
return std::make_shared<cache_val_type>(w_new, tv, exec, jobz_cusolver, uplo_cusolver);
},
[&](std::shared_ptr<cache_val_type> ctype) {
ctype->Exec(tv, w_new, tv, exec, jobz_cusolver, uplo_cusolver);
2 changes: 1 addition & 1 deletion include/matx/transforms/eig/eig_lapack.h
@@ -115,7 +115,7 @@ class matxDnEigHostPlan_t : matxDnHostSolver_t<typename ATensor::value_type> {

params = GetEigParams(w, a, jobz, uplo);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size, false);
this->AllocateWorkspace(params.batch_size);
}

void GetWorkspaceSize() override
21 changes: 12 additions & 9 deletions include/matx/transforms/lu/lu_cuda.h
@@ -59,6 +59,7 @@ struct DnLUCUDAParams_t {
void *piv;
size_t batch_size;
MatXDataType_t dtype;
cudaExecutor exec;
};

template <typename OutputTensor, typename PivotTensor, typename ATensor>
@@ -91,7 +92,8 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
*
*/
matxDnLUCUDAPlan_t(PivotTensor &piv,
const ATensor &a)
const ATensor &a,
const cudaExecutor &exec)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)

@@ -104,9 +106,9 @@
MATX_STATIC_ASSERT_STR((std::is_same_v<T1, typename OutTensor_t::value_type>), matxInavlidType, "Input and Output types must match");
MATX_STATIC_ASSERT_STR((std::is_same_v<T2, int64_t>), matxInavlidType, "Pivot tensor type must be int64_t");

params = GetLUParams(piv, a);
params = GetLUParams(piv, a, exec);
this->GetWorkspaceSize();
this->AllocateWorkspace(params.batch_size, false);
this->AllocateWorkspace(params.batch_size, false, exec);
}

void GetWorkspaceSize() override
@@ -120,7 +122,8 @@
}

static DnLUCUDAParams_t GetLUParams(PivotTensor &piv,
const ATensor &a) noexcept
const ATensor &a,
const cudaExecutor &exec) noexcept
{
DnLUCUDAParams_t params;
params.batch_size = GetNumBatches(a);
Expand All @@ -129,7 +132,7 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
params.A = a.Data();
params.piv = piv.Data();
params.dtype = TypeToInt<T1>();

params.exec = exec;
return params;
}

@@ -212,7 +215,7 @@ struct DnLUCUDAParamsKeyHash {
std::size_t operator()(const DnLUCUDAParams_t &k) const noexcept
{
return (std::hash<uint64_t>()(k.m)) + (std::hash<uint64_t>()(k.n)) +
(std::hash<uint64_t>()(k.batch_size));
(std::hash<uint64_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
}
};

@@ -223,7 +226,7 @@ struct DnLUCUDAParamsKeyEq {
bool operator()(const DnLUCUDAParams_t &l, const DnLUCUDAParams_t &t) const noexcept
{
return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size &&
l.dtype == t.dtype;
l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
}
};

@@ -284,15 +287,15 @@ void lu_impl(OutputTensor &&out, PivotTensor &&piv,
auto tvt = tv.PermuteMatrix();

// Get parameters required by these tensors
auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>::GetLUParams(piv_new, tvt);
auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>::GetLUParams(piv_new, tvt, exec);

// Get cache or new LU plan if it doesn't exist
using cache_val_type = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(a_new)>;
detail::GetCache().LookupAndExec<detail::lu_cuda_cache_t>(
detail::GetCacheIdFromType<detail::lu_cuda_cache_t>(),
params,
[&]() {
return std::make_shared<cache_val_type>(piv_new, tvt);
return std::make_shared<cache_val_type>(piv_new, tvt, exec);
},
[&](std::shared_ptr<cache_val_type> ctype) {
ctype->Exec(tvt, piv_new, tvt, exec);
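
Finally, a sketch of the concurrency pattern the new mutex is meant to support (illustrative only; assumes the operator-form fft API and that the cache previously had no synchronization of its own): several host threads can now perform first-time lookups concurrently, each on its own stream, without racing on cache insertion.

#include "matx.h"
#include <thread>
#include <vector>

using namespace matx;

void worker() {
  cudaStream_t s;
  cudaStreamCreate(&s);
  cudaExecutor exec{s};

  auto in  = make_tensor<cuda::std::complex<float>>({1024});
  auto out = make_tensor<cuda::std::complex<float>>({1024});

  // Plan creation and cache insertion are serialized by cache_mtx, so
  // concurrent first-time lookups from different threads do not race.
  (out = fft(in)).run(exec);
  cudaStreamSynchronize(s);
  cudaStreamDestroy(s);
}

int main() {
  std::vector<std::thread> pool;
  for (int i = 0; i < 4; i++) {
    pool.emplace_back(worker);
  }
  for (auto &t : pool) {
    t.join();
  }
  return 0;
}
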
