Skip to content

Commit

Permalink
Merge pull request #13806 from iyamazaki/tacho-variant3
Browse files Browse the repository at this point in the history
Tacho : add LevelSet scheduling with Serial build
  • Loading branch information
iyamazaki authored Feb 14, 2025
2 parents 51d04bb + 1b6f124 commit 9edbfdd
Show file tree
Hide file tree
Showing 9 changed files with 328 additions and 22 deletions.
4 changes: 3 additions & 1 deletion packages/amesos2/example/SimpleSolve_File.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ int main(int argc, char *argv[]) {
typedef Tpetra::CrsMatrix<>::scalar_type Scalar;
typedef Tpetra::Map<>::local_ordinal_type LO;
typedef Tpetra::Map<>::global_ordinal_type GO;
typedef Tpetra::Map<>::node_type NO;

#if defined(HAVE_AMESOS2_XPETRA) && defined(HAVE_AMESOS2_ZOLTAN2)
typedef Tpetra::Map<>::node_type NO;
typedef Tpetra::RowGraph<LO, GO, NO> Graph;
#endif
typedef Tpetra::CrsMatrix<Scalar,LO,GO> MAT;
typedef Tpetra::MultiVector<Scalar,LO,GO> MV;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ template <typename value_type> int driver(int argc, char *argv[]) {
int dofs_per_node = 1;
bool perturbPivot = false;
int nrhs = 1;
bool randomRHS = true;
bool randomRHS = false;
bool onesRHS = false;
std::string method_name = "chol";
int method = 1; // 1 - Chol, 2 - LDL, 3 - SymLU
Expand Down Expand Up @@ -61,6 +61,7 @@ template <typename value_type> int driver(int argc, char *argv[]) {
opts.set_option<int>("variant", "algorithm variant in levelset scheduling; 0, 1 and 2", &variant);
opts.set_option<int>("nstreams", "# of streams used in CUDA; on host, it is ignored", &nstreams);
opts.set_option<bool>("one-rhs", "Set RHS to be ones", &onesRHS);
opts.set_option<bool>("random-rhs", "Set RHS to be random", &randomRHS);
opts.set_option<bool>("no-warmup", "Flag to turn off warmup", &no_warmup);
opts.set_option<int>("nfacts", "# of factorizations to perform", &nfacts);
opts.set_option<int>("nsolves", "# of solves to perform", &nsolves);
Expand Down
23 changes: 20 additions & 3 deletions packages/shylu/shylu_node/tacho/src/impl/Tacho_Blas_Serial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ template <typename T> struct BlasSerial {
const T *x, int incx,
const T beta, /* */ T *y, int incy) {

typedef ArithTraits<T> arith_traits;
const T one(1), zero(0);

{
Expand All @@ -53,9 +54,9 @@ template <typename T> struct BlasSerial {
for (int j = 0; j < n; j++) {
T val = 0.0;
for (int i = 0; i < m; i++) {
val += (A[i + j*lda] * x[i*incx]);
val += (arith_traits::conj(A[i + j*lda]) * x[i*incx]);
}
y[j*incx] += alpha * val;
y[j*incy] += alpha * val;
}
}
}
Expand Down Expand Up @@ -103,7 +104,23 @@ template <typename T> struct BlasSerial {
Kokkos::abort("gemm: transb is not valid");
}
} else {
Kokkos::abort("gemm: transa is not valid");
if (transb == 'N' || transb == 'n') {
for (int j = 0; j < n; j++) {
if (beta == zero) {
for (int i = 0; i < m; i++) C[i + j*ldc] = zero;
} else if (beta != one) {
for (int i = 0; i < m; i++) C[i + j*ldc] *= beta;
}
for (int l = 0; l < k; l++) {
T val = alpha * B[l + j*ldb] ;
for (int i = 0; i < m; i++) {
C[i + j*ldc] += (A[l + i*lda] * val);
}
}
}
} else {
Kokkos::abort("gemm: transa is not valid");
}
}
}
}
Expand Down
18 changes: 15 additions & 3 deletions packages/shylu/shylu_node/tacho/src/impl/Tacho_Driver_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ Driver<VT, DT>::Driver()
_h_perm(), _peri(), _h_peri(), _m_graph(0), _nnz_graph(0), _h_ap_graph(), _h_aj_graph(), _h_perm_graph(),
_h_peri_graph(), _nnz_u(0), _nsupernodes(0), _N(nullptr), _verbose(0), _small_problem_thres(1024), _serial_thres_size(-1),
_mb(-1), _nb(-1), _front_update_mode(-1), _levelset(0), _device_level_cut(0), _device_factor_thres(128),
_device_solve_thres(128), _variant(2), _nstreams(16), _pivot_tol(0.0), _max_num_superblocks(-1) {}
_device_solve_thres(128),
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
_variant(2),
#else
_variant(-1), // sequential by default
#endif
_nstreams(16), _pivot_tol(0.0), _max_num_superblocks(-1) {}

///
/// duplicate the object
Expand Down Expand Up @@ -142,9 +148,15 @@ void Driver<VT, DT>::setLevelSetOptionDeviceFunctionThreshold(const ordinal_type
}

///
/// select the level-set algorithm variant
///   - builds with CUDA, HIP or SYCL enabled accept 0 to 3
///   - all other builds accept -1 to 3, where -1 requests the serial code
/// an out-of-range value is reported through TACHO_TEST_FOR_EXCEPTION as std::logic_error;
/// otherwise the value is stored in _variant
template <typename VT, typename DT> void Driver<VT, DT>::setLevelSetOptionAlgorithmVariant(const ordinal_type variant) {
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
  // device-enabled build: only the level-set variants are accepted
  const bool is_supported = (0 <= variant && variant <= 3);
  if (!is_supported) {
    TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "levelset algorithm variants range from 0 to 3");
  }
#else
  // host-only build: -1 is accepted in addition to the level-set variants
  const bool is_supported = (-1 <= variant && variant <= 3);
  if (!is_supported) {
    TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "levelset algorithm variants range from -1 to 3 (-1 for serial)");
  }
#endif
  _variant = variant;
}

Expand Down Expand Up @@ -350,8 +362,8 @@ template <typename VT, typename DT> int Driver<VT, DT>::initialize() {
NumericToolsFactory<VT, DT> factory;
factory.setBaseMember(_method, _m, _ap, _aj, _perm, _peri, _nsupernodes, _supernodes, _gid_super_panel_ptr,
_gid_super_panel_colidx, _sid_super_panel_ptr, _sid_super_panel_colidx,
_blk_super_panel_colidx, _stree_parent, _stree_ptr, _stree_children, _stree_level,
_stree_roots, _verbose);
_blk_super_panel_colidx, _stree_parent, _stree_ptr, _stree_children, _stree_level, _stree_roots,
_verbose);

factory.setLevelSetMember(_variant, _device_level_cut, _device_factor_thres, _device_solve_thres, _nstreams);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ template <typename ValueType> class NumericToolsFactory<ValueType, typename UseT

TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING;
TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER;
// TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER;
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER;

void setBaseMember(const ordinal_type method,
// input matrix A
Expand All @@ -151,11 +151,37 @@ template <typename ValueType> class NumericToolsFactory<ValueType, typename UseT
void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut,
const ordinal_type device_factor_thres, const ordinal_type device_solve_thres,
const ordinal_type nstreams) {
// TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER;
TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER;
}

void createObject(numeric_tools_base_type *&object) {
KOKKOS_IF_ON_HOST((TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY;))
KOKKOS_IF_ON_HOST((
switch (_variant) {
case -1: {
// sequential code
TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY;
break;
}
case 0: {
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type);
break;
}
case 1: {
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var1_type);
break;
}
case 2: {
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var2_type);
break;
}
case 3: {
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var3_type);
break;
}
default: {
TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input");
}
}))
}
};
#endif
Expand All @@ -174,10 +200,7 @@ template <typename ValueType> class NumericToolsFactory<ValueType, typename UseT

TACHO_NUMERIC_TOOLS_FACTORY_BASE_USING;
TACHO_NUMERIC_TOOLS_FACTORY_BASE_MEMBER;
#define TACHO_LEVELSET_ON_HOST
#if defined TACHO_LEVELSET_ON_HOST
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_MEMBER;
#endif

void setBaseMember(const ordinal_type method,
// input matrix A
Expand All @@ -198,15 +221,17 @@ template <typename ValueType> class NumericToolsFactory<ValueType, typename UseT
void setLevelSetMember(const ordinal_type variant, const ordinal_type device_level_cut,
const ordinal_type device_factor_thres, const ordinal_type device_solve_thres,
const ordinal_type nstreams) {
#if defined TACHO_LEVELSET_ON_HOST
TACHO_NUMERIC_TOOLS_FACTORY_SET_LEVELSET_MEMBER;
#endif
}

void createObject(numeric_tools_base_type *&object) {
#if defined TACHO_LEVELSET_ON_HOST
KOKKOS_IF_ON_HOST((
switch (_variant) {
case -1: {
// sequential code
TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY;
break;
}
case 0: {
TACHO_NUMERIC_TOOLS_FACTORY_LEVELSET_BODY(numeric_tools_levelset_var0_type);
break;
Expand All @@ -227,9 +252,6 @@ template <typename ValueType> class NumericToolsFactory<ValueType, typename UseT
TACHO_TEST_FOR_EXCEPTION(true, std::logic_error, "Invalid variant input");
}
}))
#else
KOKKOS_IF_ON_HOST((TACHO_NUMERIC_TOOLS_FACTORY_SERIAL_BODY;))
#endif
}
};
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3312,7 +3312,7 @@ class NumericToolsLevelSet : public NumericToolsBase<ValueType, DeviceType> {

const ordinal_type offm = s.row_begin;
const auto tT = Kokkos::subview(t, range_type(offm, offm + m), Kokkos::ALL());
const UnmanagedViewType<value_type_matrix> bT(bptr, m, nrhs);
const auto bT = Kokkos::subview(b, range_type(0, m), Kokkos::ALL());

ConstUnmanagedViewType<ordinal_type_array> P(_piv.data() + offm * 4, m * 4);
ConstUnmanagedViewType<value_type_matrix> D(_diag.data() + offm * 2, m, 2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ template <typename SupernodeInfoType> struct TeamFunctor_SolveUpperLDL {

const ordinal_type offm = s.row_begin;
const auto tT = Kokkos::subview(_t, range_type(offm, offm + m), Kokkos::ALL());
const UnmanagedViewType<value_type_matrix> bT(bptr, m, nrhs);
const auto bT = Kokkos::subview(b, range_type(0, m), Kokkos::ALL());

ConstUnmanagedViewType<ordinal_type_array> P(_piv.data() + offm * 4, m * 4);
ConstUnmanagedViewType<value_type_matrix> D(_diag.data() + offm * 2, m, 2);
Expand Down
10 changes: 10 additions & 0 deletions packages/shylu/shylu_node/tacho/unit-test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ TRIBITS_ADD_LIBRARY(
NO_INSTALL_LIB_OR_HEADERS
)

# Build the Tacho_TestSolver.x executable from Tacho_TestSolver.cpp and register
# it as a test. It links against the test-only library tacho-gtest, is run with
# a single MPI process, and is reported as failed when its output matches the
# regular expression " FAILED ".
TRIBITS_ADD_EXECUTABLE_AND_TEST(
Tacho_TestSolver.x
NOEXESUFFIX
NOEXEPREFIX
SOURCES Tacho_TestSolver.cpp
TESTONLYLIBS tacho-gtest
NUM_MPI_PROCS 1
FAIL_REGULAR_EXPRESSION " FAILED "
)

TRIBITS_ADD_EXECUTABLE_AND_TEST(
Tacho_Test_Util.x
NOEXESUFFIX
Expand Down
Loading

0 comments on commit 9edbfdd

Please sign in to comment.