Skip to content

Commit

Permalink
linear_solve: Reduce use of SFINAE (RobotLocomotion#15853)
Browse files Browse the repository at this point in the history
* linear_solve: Reduce use of SFINAE

Relevant (maybe?) to: RobotLocomotion#15685

Motivated by discussions in RobotLocomotion#15818, reduce some of the syntax noise by
replacing SFINAE with a combination of type aliases, if constexpr, and
static_assert. No functionality is changed, but perhaps things are
easier to read.
  • Loading branch information
rpoyner-tri authored Sep 30, 2021
1 parent 613a185 commit 62a6763
Showing 1 changed file with 75 additions and 100 deletions.
175 changes: 75 additions & 100 deletions math/linear_solve.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
namespace drake {
namespace math {
namespace internal {
template <typename T>
struct is_double_or_symbolic : std::false_type {};

template <>
struct is_double_or_symbolic<double> : std::true_type {};
template <typename T>
using is_symbolic = std::is_same<T, symbolic::Expression>;

template <>
struct is_double_or_symbolic<symbolic::Expression> : std::true_type {};
template <typename T>
inline constexpr bool is_symbolic_v = is_symbolic<T>::value;

template <typename T>
inline constexpr bool is_double_or_symbolic_v = is_double_or_symbolic<T>::value;
inline constexpr bool is_double_or_symbolic_v =
std::disjunction<std::is_same<T, double>, is_symbolic<T>>::value;

template <typename T>
struct is_autodiff : std::false_type {};
Expand All @@ -30,7 +29,6 @@ struct is_autodiff<drake::AutoDiffd<N>> : std::true_type {};

template <typename T>
inline constexpr bool is_autodiff_v = is_autodiff<T>::value;

} // namespace internal

/**
Expand Down Expand Up @@ -398,7 +396,7 @@ typename std::enable_if<
//@{

namespace internal {
/**
/*
* The return type of GetLinearSolver function. It is the type of the linear
* solver. For example
* LinearSolver<Eigen::LLT, Eigen::Matrix3d>::type is
Expand All @@ -416,6 +414,26 @@ using EigenLinearSolver = LinearSolverType<Eigen::Matrix<
typename DerivedA::Scalar, double>,
DerivedA::RowsAtCompileTime, DerivedA::ColsAtCompileTime, Eigen::ColMajor,
DerivedA::MaxRowsAtCompileTime, DerivedA::MaxColsAtCompileTime>>;

/*
* The most "promoted" of two scalar types; if they are the same, then that
* type, if they are different and one is 'double', then whichever one is not
* 'double'. Otherwise, the resulting type is 'void', indicating an error.
*/
template <typename ScalarA, typename ScalarB>
using Promoted = typename std::conditional_t<
std::is_same_v<ScalarA, ScalarB>, ScalarA,
std::conditional_t<
std::is_same_v<double, ScalarA>, ScalarB,
std::conditional_t<std::is_same_v<double, ScalarB>, ScalarA, void>>>;

/*
* The type of the solution vector 'x' given DerivedA and DerivedB.
*/
template <typename DerivedA, typename DerivedB>
using Solution = typename Eigen::Matrix<
Promoted<typename DerivedA::Scalar, typename DerivedB::Scalar>,
DerivedA::RowsAtCompileTime, DerivedB::ColsAtCompileTime>;
} // namespace internal

/**
Expand Down Expand Up @@ -457,12 +475,11 @@ internal::EigenLinearSolver<LinearSolverType, DerivedA> GetLinearSolver(
* The following table indicate the scalar type of x with A/b containing the
* specified scalar type. The entries with NA are not supported.
*
* | \A | | | |
* | b \ | double | ADS | Expr |
* |------|--------|-----|----- |
* |double| double | ADS | NA |
* | ADS | ADS | ADS | NA |
* | Expr | NA | NA | Expr |
* | b \ A | double | ADS | Expr |
* |--------|--------|-----|----- |
* | double | double | ADS | NA |
* | ADS | ADS | ADS | NA |
* | Expr | NA | NA | Expr |
*
* where ADS stands for Eigen::AutoDiffScalar, and Expr stands for
* symbolic::Expression.
Expand Down Expand Up @@ -516,19 +533,21 @@ internal::EigenLinearSolver<LinearSolverType, DerivedA> GetLinearSolver(

//@{
/**
* Template specialization for both A and b being double-valued matrices.
* See @ref linear_solve for more details.
* Solves system A*x=b. The supported combinations of scalar types are
* summarized in the table above. See @ref linear_solve for more details.
*/
template <template <typename, int...> typename LinearSolverType,
typename DerivedA, typename DerivedB>
typename std::enable_if<
internal::is_double_or_symbolic_v<typename DerivedA::Scalar> &&
internal::is_double_or_symbolic_v<typename DerivedB::Scalar> &&
std::is_same_v<typename DerivedA::Scalar, typename DerivedB::Scalar>,
Eigen::Matrix<typename DerivedA::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
SolveLinearSystem(const Eigen::MatrixBase<DerivedA>& A,
const Eigen::MatrixBase<DerivedB>& b) {
internal::Solution<DerivedA, DerivedB> SolveLinearSystem(
const Eigen::MatrixBase<DerivedA>& A,
const Eigen::MatrixBase<DerivedB>& b) {
using ScalarA = typename DerivedA::Scalar;
using ScalarB = typename DerivedB::Scalar;
static_assert(
std::is_same_v<ScalarA, ScalarB> || (!internal::is_symbolic_v<ScalarA> &&
!internal::is_symbolic_v<ScalarB>),
"Mixing symbolic and other types is not supported.");

const auto linear_solver = GetLinearSolver<LinearSolverType>(A);
return SolveLinearSystem(linear_solver, A, b);
}
Expand All @@ -548,23 +567,6 @@ typename std::enable_if<
return SolveLinearSystem<LinearSolverType>(A, b);
}

/**
* Template specialization for A being double-valued matrix, and b being
* AutoDiffScalar-valued matrix. See @ref linear_solve for more details.
*/
template <template <typename, int...> typename LinearSolverType,
typename DerivedA, typename DerivedB>
typename std::enable_if<
std::is_same_v<typename DerivedA::Scalar, double> &&
internal::is_autodiff_v<typename DerivedB::Scalar>,
Eigen::Matrix<typename DerivedB::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
SolveLinearSystem(const Eigen::MatrixBase<DerivedA>& A,
const Eigen::MatrixBase<DerivedB>& b) {
const auto linear_solver = GetLinearSolver<LinearSolverType>(A);
return SolveLinearSystem(linear_solver, A, b);
}

template <template <typename, int...> typename LinearSolverType,
typename DerivedA, typename DerivedB>
DRAKE_DEPRECATED("2022-01-01",
Expand All @@ -579,22 +581,6 @@ typename std::enable_if<
return SolveLinearSystem<LinearSolverType>(A, b);
}

/**
* Template specialization when A is a matrix of AutoDiffScalar.
* See @ref linear_solve for more details.
*/
template <template <typename, int...> typename LinearSolverType,
typename DerivedA, typename DerivedB>
typename std::enable_if<
internal::is_autodiff_v<typename DerivedA::Scalar>,
Eigen::Matrix<typename DerivedA::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
SolveLinearSystem(const Eigen::MatrixBase<DerivedA>& A,
const Eigen::MatrixBase<DerivedB>& b) {
const auto linear_solver = GetLinearSolver<LinearSolverType>(A);
return SolveLinearSystem(linear_solver, A, b);
}

template <template <typename, int...> typename LinearSolverType,
typename DerivedA, typename DerivedB>
DRAKE_DEPRECATED("2022-01-01",
Expand All @@ -613,12 +599,11 @@ typename std::enable_if<
* Solves a linear system of equations A*x=b.
* Depending on the scalar types of A and b, the scalar type of x is summarized
* in this table.
* | \A | | | |
* | b \ | double | ADS | Expr |
* |------|--------|-----|----- |
* |double| double | ADS | NA |
* | ADS | ADS | ADS | NA |
* | Expr | NA | NA | Expr |
* | b \ A | double | ADS | Expr |
* |--------|--------|-----|----- |
* | double | double | ADS | NA |
* | ADS | ADS | ADS | NA |
* | Expr | NA | NA | Expr |
*
* where ADS stands for Eigen::AutoDiffScalar, and Expr stands for
* symbolic::Expression.
Expand Down Expand Up @@ -654,7 +639,9 @@ template <template <typename, int...> typename LinearSolverType,
typename DerivedA>
class LinearSolver {
public:
using SolverType = internal::EigenLinearSolver<LinearSolverType, DerivedA>;
template <typename DerivedB>
using SolutionType = internal::Solution<DerivedA, DerivedB>;

explicit LinearSolver(const Eigen::MatrixBase<DerivedA>& A)
: linear_solver_{GetLinearSolver<LinearSolverType>(A)} {
if constexpr (internal::is_autodiff_v<typename DerivedA::Scalar>) {
Expand All @@ -663,46 +650,34 @@ class LinearSolver {
}

/**
* Template specialization for both A and b being double- or symbolic-valued
* matrices.
*/
template <typename DerivedB>
typename std::enable_if<
internal::is_double_or_symbolic_v<typename DerivedA::Scalar> &&
std::is_same_v<typename DerivedA::Scalar, typename DerivedB::Scalar>,
Eigen::Matrix<typename DerivedA::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
Solve(const Eigen::MatrixBase<DerivedB>& b) const {
return linear_solver_.solve(b);
}

/**
* Template specialization for A being double-valued matrix, and b being
* AutoDiffScalar-valued matrix.
* Solves system A*x = b.
* Return type is as described in the table above.
*/
template <typename DerivedB>
typename std::enable_if<
std::is_same_v<typename DerivedA::Scalar, double> &&
internal::is_autodiff_v<typename DerivedB::Scalar>,
Eigen::Matrix<typename DerivedB::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
Solve(const Eigen::MatrixBase<DerivedB>& b) const {
return SolveLinearSystem(linear_solver_, b);
}

/**
* Template specialization when A is a matrix of AutoDiffScalar.
*/
template <typename DerivedB>
typename std::enable_if<
internal::is_autodiff_v<typename DerivedA::Scalar>,
Eigen::Matrix<typename DerivedA::Scalar, DerivedA::RowsAtCompileTime,
DerivedB::ColsAtCompileTime>>::type
Solve(const Eigen::MatrixBase<DerivedB>& b) const {
return SolveLinearSystem(linear_solver_, *A_, b);
SolutionType<DerivedB> Solve(const Eigen::MatrixBase<DerivedB>& b) const {
using ScalarA = typename DerivedA::Scalar;
using ScalarB = typename DerivedB::Scalar;
static_assert(std::is_same_v<ScalarA, ScalarB> ||
(!internal::is_symbolic_v<ScalarA> &&
!internal::is_symbolic_v<ScalarB>),
"Mixing symbolic and other types is not supported.");

if constexpr (std::is_same_v<ScalarA, ScalarB> &&
!internal::is_autodiff_v<ScalarA>) {
return linear_solver_.solve(b);
// NOLINTNEXTLINE(readability/braces)
} else if constexpr (std::is_same_v<ScalarA, double> &&
internal::is_autodiff_v<ScalarB>) {
return SolveLinearSystem(linear_solver_, b);
// NOLINTNEXTLINE(readability/braces)
} else if constexpr (internal::is_autodiff_v<ScalarA>) {
return SolveLinearSystem(linear_solver_, *A_, b);
}
DRAKE_UNREACHABLE();
}

private:
using SolverType = internal::EigenLinearSolver<LinearSolverType, DerivedA>;
SolverType linear_solver_;
std::optional<Eigen::Matrix<
typename DerivedA::Scalar, DerivedA::RowsAtCompileTime,
Expand Down

0 comments on commit 62a6763

Please sign in to comment.