Skip to content

Commit

Permalink
Refactor relational operator overloading for Eigen::Arrays
Browse files Browse the repository at this point in the history
Remove boilerplate template code by using RelationalOpTraits.
  • Loading branch information
soonho-tri committed Mar 16, 2017
1 parent 31f0367 commit fc03557
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 42 deletions.
74 changes: 32 additions & 42 deletions drake/common/symbolic_formula.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "drake/common/drake_assert.h"
#include "drake/common/drake_copyable.h"
#include "drake/common/eigen_types.h"
#include "drake/common/hash.h"
#include "drake/common/symbolic_environment.h"
#include "drake/common/symbolic_expression.h"
Expand Down Expand Up @@ -280,6 +281,31 @@ const Variables& get_quantified_variables(const Formula& f);
*/
const Formula& get_quantified_formula(const Formula& f);

namespace detail {
/// Provides a return type of relational operations (=, ≠, ≤, <, ≥, >) between
/// `Eigen::Array`s.
///
/// @tparam DerivedA A derived type of Eigen::ArrayBase.
/// @tparam DerivedB A derived type of Eigen::ArrayBase.
/// @pre The type of (DerivedA::Scalar() == DerivedB::Scalar()) is symbolic
/// formula.
template <typename DerivedA, typename DerivedB,
typename = std::enable_if<
std::is_base_of<Eigen::ArrayBase<DerivedA>, DerivedA>::value &&
std::is_base_of<Eigen::ArrayBase<DerivedB>, DerivedB>::value &&
std::is_same<decltype(typename DerivedA::Scalar() ==
typename DerivedB::Scalar()),
Formula>::value>>
struct RelationalOpTraits {
using ReturnType =
Eigen::Array<Formula,
EigenSizeMinPreferFixed<DerivedA::RowsAtCompileTime,
DerivedB::RowsAtCompileTime>::value,
EigenSizeMinPreferFixed<DerivedA::ColsAtCompileTime,
DerivedB::ColsAtCompileTime>::value>;
};
} // namespace detail

/// Returns an Eigen array of symbolic formula where each element includes
/// element-wise symbolic-equality of two arrays @p m1 and @p m2.
///
Expand Down Expand Up @@ -307,13 +333,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() ==
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator==(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand All @@ -331,13 +351,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() >=
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator<=(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand All @@ -355,13 +369,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() >
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator<(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand All @@ -379,13 +387,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() >=
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator>=(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand All @@ -403,13 +405,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() >
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator>(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand All @@ -427,13 +423,7 @@ typename std::enable_if<
std::is_same<decltype(typename DerivedA::Scalar() >
typename DerivedB::Scalar()),
Formula>::value,
Eigen::Array<Formula,
DerivedA::RowsAtCompileTime == Eigen::Dynamic
? DerivedB::RowsAtCompileTime
: DerivedA::RowsAtCompileTime,
DerivedA::ColsAtCompileTime == Eigen::Dynamic
? DerivedB::ColsAtCompileTime
: DerivedA::ColsAtCompileTime>>::type
typename detail::RelationalOpTraits<DerivedA, DerivedB>::ReturnType>::type
operator!=(const DerivedA& a1, const DerivedB& a2) {
EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(DerivedA, DerivedB);
DRAKE_DEMAND(a1.rows() == a2.rows() && a1.cols() == a2.cols());
Expand Down
17 changes: 17 additions & 0 deletions drake/common/test/symbolic_expression_matrix_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,23 @@ TEST_F(SymbolicExpressionMatrixTest, ArrayOperatorVarOpVar) {
EXPECT_TRUE(CheckArrayOperatorNeq(array_var_2_, array_var_1_));
}

TEST_F(SymbolicExpressionMatrixTest, ArrayOperatorReturnType) {
Eigen::Array<Variable, 2, Eigen::Dynamic> m1(2, 2);
Eigen::Array<Variable, Eigen::Dynamic, 2> m2(2, 2);
EXPECT_TRUE(
(std::is_same<decltype(m1 == m2), Eigen::Array<Formula, 2, 2>>::value));
EXPECT_TRUE(
(std::is_same<decltype(m1 != m2), Eigen::Array<Formula, 2, 2>>::value));
EXPECT_TRUE(
(std::is_same<decltype(m1 <= m2), Eigen::Array<Formula, 2, 2>>::value));
EXPECT_TRUE(
(std::is_same<decltype(m1 < m2), Eigen::Array<Formula, 2, 2>>::value));
EXPECT_TRUE(
(std::is_same<decltype(m1 >= m2), Eigen::Array<Formula, 2, 2>>::value));
EXPECT_TRUE(
(std::is_same<decltype(m1 > m2), Eigen::Array<Formula, 2, 2>>::value));
}

// Checks if m1 == m2 returns a formula which is a conjunction of
// m1(i, j) == m2(i, j) for all i and j.
bool CheckMatrixOperatorEq(const MatrixX<Expression>& m1,
Expand Down

0 comments on commit fc03557

Please sign in to comment.