Skip to content

Commit

Permalink
Add min/max to symbolic::Expression class and their test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
soonho-tri committed Oct 22, 2016
1 parent 563623a commit 6708c2f
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 13 deletions.
75 changes: 75 additions & 0 deletions drake/common/symbolic_expression.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "drake/common/symbolic_expression.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iomanip>
Expand Down Expand Up @@ -742,6 +743,30 @@ ostream& ExpressionTanh::Display(ostream& os) const {

double ExpressionTanh::DoEvaluate(const double v) const { return std::tanh(v); }

ExpressionMin::ExpressionMin(const Expression& e1, const Expression& e2)
: BinaryExpressionCell{ExpressionKind::Min, e1, e2} {}

ostream& ExpressionMin::Display(ostream& os) const {
os << "min(" << get_1st_expression() << ", " << get_2nd_expression() << ")";
return os;
}

double ExpressionMin::DoEvaluate(const double v1, const double v2) const {
return std::min(v1, v2);
}

ExpressionMax::ExpressionMax(const Expression& e1, const Expression& e2)
: BinaryExpressionCell{ExpressionKind::Max, e1, e2} {}

ostream& ExpressionMax::Display(ostream& os) const {
os << "max(" << get_1st_expression() << ", " << get_2nd_expression() << ")";
return os;
}

double ExpressionMax::DoEvaluate(const double v1, const double v2) const {
return std::max(v1, v2);
}

ostream& operator<<(ostream& os, const Expression& e) {
DRAKE_ASSERT(e.ptr_ != nullptr);
return e.ptr_->Display(os);
Expand Down Expand Up @@ -919,5 +944,55 @@ Expression tanh(const Expression& e) {
}
return Expression{make_shared<ExpressionTanh>(e)};
}
Expression min(const Expression& e1, const Expression& e2) {
// simplification #1: min(x, x) -> x
if (e1.EqualTo(e2)) {
return e1;
}
// simplification #2: constant folding
if (e1.get_kind() == ExpressionKind::Constant &&
e2.get_kind() == ExpressionKind::Constant) {
const double v1 =
static_pointer_cast<ExpressionConstant>(e1.ptr_)->get_value();
const double v2 =
static_pointer_cast<ExpressionConstant>(e2.ptr_)->get_value();
return Expression{std::min(v1, v2)};
}
return Expression{make_shared<ExpressionMin>(e1, e2)};
}

Expression min(const Expression& e1, const double v2) {
return min(e1, Expression{v2});
}

Expression min(const double v1, const Expression& e2) {
return min(Expression{v1}, e2);
}

Expression max(const Expression& e1, const Expression& e2) {
// simplification #1: max(x, x) -> x
if (e1.EqualTo(e2)) {
return e1;
}
// simplification #2: constant folding
if (e1.get_kind() == ExpressionKind::Constant &&
e2.get_kind() == ExpressionKind::Constant) {
const double v1 =
static_pointer_cast<ExpressionConstant>(e1.ptr_)->get_value();
const double v2 =
static_pointer_cast<ExpressionConstant>(e2.ptr_)->get_value();
return Expression{std::max(v1, v2)};
}
return Expression{make_shared<ExpressionMax>(e1, e2)};
}

Expression max(const Expression& e1, const double v2) {
return max(e1, Expression{v2});
}

Expression max(const double v1, const Expression& e2) {
return max(Expression{v1}, e2);
}

} // namespace symbolic
} // namespace drake
31 changes: 31 additions & 0 deletions drake/common/symbolic_expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ enum class ExpressionKind {
Sinh, ///< hyperbolic sine
Cosh, ///< hyperbolic cosine
Tanh, ///< hyperbolic tangent
Min, ///< min
Max, ///< max
// TODO(soonho): add Integral
};

Expand All @@ -56,6 +58,7 @@ Its syntax tree is as follows:
E := Var | Constant | - E | E + E | E - E | E * E | E / E | log(E) | abs(E)
| exp(E) | sqrt(E) | pow(E, E) | sin(E) | cos(E) | tan(E) | asin(E)
| acos(E) | atan(E) | atan2(E, E) | sinh(E) | cosh(E) | tanh(E)
| min(E, E) | max(E, E)
\endverbatim
In the implementation, Expression is a simple wrapper including a shared pointer
Expand Down Expand Up @@ -211,6 +214,14 @@ class DRAKE_EXPORT Expression {
friend DRAKE_EXPORT Expression sinh(const Expression& e);
friend DRAKE_EXPORT Expression cosh(const Expression& e);
friend DRAKE_EXPORT Expression tanh(const Expression& e);
friend DRAKE_EXPORT Expression max(double v1, const Expression& e2);
friend DRAKE_EXPORT Expression max(const Expression& e1, double v2);
friend DRAKE_EXPORT Expression max(const Expression& e1,
const Expression& e2);
friend DRAKE_EXPORT Expression min(double v1, const Expression& e2);
friend DRAKE_EXPORT Expression min(const Expression& e1, double v2);
friend DRAKE_EXPORT Expression min(const Expression& e1,
const Expression& e2);

friend DRAKE_EXPORT std::ostream& operator<<(std::ostream& os,
const Expression& e);
Expand Down Expand Up @@ -591,6 +602,26 @@ class ExpressionTanh : public UnaryExpressionCell {
double DoEvaluate(double v) const override;
};

/** Symbolic expression representing min function. */
class ExpressionMin : public BinaryExpressionCell {
public:
explicit ExpressionMin(const Expression& e1, const Expression& e2);
std::ostream& Display(std::ostream& os) const override;

private:
double DoEvaluate(double v1, double v2) const override;
};

/** Symbolic expression representing max function. */
class ExpressionMax : public BinaryExpressionCell {
public:
explicit ExpressionMax(const Expression& e1, const Expression& e2);
std::ostream& Display(std::ostream& os) const override;

private:
double DoEvaluate(double v1, double v2) const override;
};

std::ostream& operator<<(std::ostream& os, const Expression& e);

/** \relates Expression Return a copy of \p lhs updated to record component-wise
Expand Down
147 changes: 134 additions & 13 deletions drake/common/test/symbolic_expression_test.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "drake/common/symbolic_expression.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <memory>
Expand Down Expand Up @@ -37,6 +38,8 @@ class SymbolicExpressionTest : public ::testing::Test {

const Expression x_plus_y_{x_ + y_};

const Expression x_plus_z_{x_ + z_};

const Expression zero_{0.0};
const Expression one_{1.0};
const Expression two_{2.0};
Expand Down Expand Up @@ -90,20 +93,23 @@ TEST_F(SymbolicExpressionTest, Hash) {
}

TEST_F(SymbolicExpressionTest, HashBinary) {
const Expression e1{x_plus_y_ + x_plus_y_};
const Expression e2{x_plus_y_ - x_plus_y_};
const Expression e3{x_plus_y_ * x_plus_y_};
const Expression e4{x_plus_y_ / x_plus_y_};
const Expression e5{pow(x_plus_y_, x_plus_y_)};
const Expression e6{atan2(x_plus_y_, x_plus_y_)};

// e1, ..., e6 share the same sub-expressions, but their hash values should be
const Expression e1{x_plus_y_ + x_plus_z_};
const Expression e2{x_plus_y_ - x_plus_z_};
const Expression e3{x_plus_y_ * x_plus_z_};
const Expression e4{x_plus_y_ / x_plus_z_};
const Expression e5{pow(x_plus_y_, x_plus_z_)};
const Expression e6{atan2(x_plus_y_, x_plus_z_)};
const Expression e7{min(x_plus_y_, x_plus_z_)};
const Expression e8{max(x_plus_y_, x_plus_z_)};

// e1, ..., e8 share the same sub-expressions, but their hash values should be
// distinct.
unordered_set<size_t> hash_set;
for (auto const& e : {e1, e2, e3, e4, e5, e6}) {
const vector<Expression> exprs{e1, e2, e3, e4, e5, e6, e7, e8};
for (auto const& e : exprs) {
hash_set.insert(e.get_hash());
}
EXPECT_EQ(hash_set.size(), 6u);
EXPECT_EQ(hash_set.size(), exprs.size());
}

TEST_F(SymbolicExpressionTest, HashUnary) {
Expand All @@ -124,11 +130,12 @@ TEST_F(SymbolicExpressionTest, HashUnary) {
// e0, ..., e12 share the same sub-expression, but their hash values should be
// distinct.
unordered_set<size_t> hash_set;
for (auto const& e :
{e0, e1, e2, e3, e4, e5, e6, e7, e8, e9, e10, e11, e12}) {
const vector<Expression> exprs{e0, e1, e2, e3, e4, e5, e6,
e7, e8, e9, e10, e11, e12};
for (auto const& e : exprs) {
hash_set.insert(e.get_hash());
}
EXPECT_EQ(hash_set.size(), 13u);
EXPECT_EQ(hash_set.size(), exprs.size());
}

TEST_F(SymbolicExpressionTest, UnaryMinus) {
Expand Down Expand Up @@ -633,6 +640,120 @@ TEST_F(SymbolicExpressionTest, Tanh) {
std::tanh(2) + std::tanh(3.2));
}

TEST_F(SymbolicExpressionTest, Min1) {
// min(E, E) -> E
EXPECT_TRUE(min(x_plus_y_, x_plus_y_).EqualTo(x_plus_y_));
}

TEST_F(SymbolicExpressionTest, Min2) {
EXPECT_DOUBLE_EQ(min(pi_, pi_).Evaluate(), std::min(3.141592, 3.141592));
EXPECT_DOUBLE_EQ(min(pi_, one_).Evaluate(), std::min(3.141592, 1.0));
EXPECT_DOUBLE_EQ(min(pi_, two_).Evaluate(), std::min(3.141592, 2.0));
EXPECT_DOUBLE_EQ(min(pi_, zero_).Evaluate(), std::min(3.141592, 0.0));
EXPECT_DOUBLE_EQ(min(pi_, neg_one_).Evaluate(), std::min(3.141592, -1.0));
EXPECT_DOUBLE_EQ(min(pi_, neg_pi_).Evaluate(), std::min(3.141592, -3.141592));

EXPECT_DOUBLE_EQ(min(one_, pi_).Evaluate(), std::min(1.0, 3.141592));
EXPECT_DOUBLE_EQ(min(one_, one_).Evaluate(), std::min(1.0, 1.0));
EXPECT_DOUBLE_EQ(min(one_, two_).Evaluate(), std::min(1.0, 2.0));
EXPECT_DOUBLE_EQ(min(one_, zero_).Evaluate(), std::min(1.0, 0.0));
EXPECT_DOUBLE_EQ(min(one_, neg_one_).Evaluate(), std::min(1.0, -1.0));
EXPECT_DOUBLE_EQ(min(one_, neg_pi_).Evaluate(), std::min(1.0, -3.141592));

EXPECT_DOUBLE_EQ(min(two_, pi_).Evaluate(), std::min(2.0, 3.141592));
EXPECT_DOUBLE_EQ(min(two_, one_).Evaluate(), std::min(2.0, 1.0));
EXPECT_DOUBLE_EQ(min(two_, two_).Evaluate(), std::min(2.0, 2.0));
EXPECT_DOUBLE_EQ(min(two_, zero_).Evaluate(), std::min(2.0, 0.0));
EXPECT_DOUBLE_EQ(min(two_, neg_one_).Evaluate(), std::min(2.0, -1.0));
EXPECT_DOUBLE_EQ(min(two_, neg_pi_).Evaluate(), std::min(2.0, -3.141592));

EXPECT_DOUBLE_EQ(min(zero_, pi_).Evaluate(), std::min(0.0, 3.141592));
EXPECT_DOUBLE_EQ(min(zero_, one_).Evaluate(), std::min(0.0, 1.0));
EXPECT_DOUBLE_EQ(min(zero_, two_).Evaluate(), std::min(0.0, 2.0));
EXPECT_DOUBLE_EQ(min(zero_, zero_).Evaluate(), std::min(0.0, 0.0));
EXPECT_DOUBLE_EQ(min(zero_, neg_one_).Evaluate(), std::min(0.0, -1.0));
EXPECT_DOUBLE_EQ(min(zero_, neg_pi_).Evaluate(), std::min(0.0, -3.141592));

EXPECT_DOUBLE_EQ(min(neg_one_, pi_).Evaluate(), std::min(-1.0, 3.141592));
EXPECT_DOUBLE_EQ(min(neg_one_, one_).Evaluate(), std::min(-1.0, 1.0));
EXPECT_DOUBLE_EQ(min(neg_one_, two_).Evaluate(), std::min(-1.0, 2.0));
EXPECT_DOUBLE_EQ(min(neg_one_, zero_).Evaluate(), std::min(-1.0, 0.0));
EXPECT_DOUBLE_EQ(min(neg_one_, neg_one_).Evaluate(), std::min(-1.0, -1.0));
EXPECT_DOUBLE_EQ(min(neg_one_, neg_pi_).Evaluate(),
std::min(-1.0, -3.141592));

EXPECT_DOUBLE_EQ(min(neg_pi_, pi_).Evaluate(), std::min(-3.141592, 3.141592));
EXPECT_DOUBLE_EQ(min(neg_pi_, one_).Evaluate(), std::min(-3.141592, 1.0));
EXPECT_DOUBLE_EQ(min(neg_pi_, two_).Evaluate(), std::min(-3.141592, 2.0));
EXPECT_DOUBLE_EQ(min(neg_pi_, zero_).Evaluate(), std::min(-3.141592, 0.0));
EXPECT_DOUBLE_EQ(min(neg_pi_, neg_one_).Evaluate(),
std::min(-3.141592, -1.0));
EXPECT_DOUBLE_EQ(min(neg_pi_, neg_pi_).Evaluate(),
std::min(-3.141592, -3.141592));

const Expression e{min(x_ * y_ * pi_, sin(x_) + sin(y_))};
const Environment env{{var_x_, 2}, {var_y_, 3.2}};
EXPECT_DOUBLE_EQ(e.Evaluate(env),
std::min(2 * 3.2 * 3.141592, std::sin(2) + std::sin(3.2)));
}

TEST_F(SymbolicExpressionTest, Max1) {
// max(E, E) -> E
EXPECT_TRUE(max(x_plus_y_, x_plus_y_).EqualTo(x_plus_y_));
}

TEST_F(SymbolicExpressionTest, Max2) {
EXPECT_DOUBLE_EQ(max(pi_, pi_).Evaluate(), std::max(3.141592, 3.141592));
EXPECT_DOUBLE_EQ(max(pi_, one_).Evaluate(), std::max(3.141592, 1.0));
EXPECT_DOUBLE_EQ(max(pi_, two_).Evaluate(), std::max(3.141592, 2.0));
EXPECT_DOUBLE_EQ(max(pi_, zero_).Evaluate(), std::max(3.141592, 0.0));
EXPECT_DOUBLE_EQ(max(pi_, neg_one_).Evaluate(), std::max(3.141592, -1.0));
EXPECT_DOUBLE_EQ(max(pi_, neg_pi_).Evaluate(), std::max(3.141592, -3.141592));

EXPECT_DOUBLE_EQ(max(one_, pi_).Evaluate(), std::max(1.0, 3.141592));
EXPECT_DOUBLE_EQ(max(one_, one_).Evaluate(), std::max(1.0, 1.0));
EXPECT_DOUBLE_EQ(max(one_, two_).Evaluate(), std::max(1.0, 2.0));
EXPECT_DOUBLE_EQ(max(one_, zero_).Evaluate(), std::max(1.0, 0.0));
EXPECT_DOUBLE_EQ(max(one_, neg_one_).Evaluate(), std::max(1.0, -1.0));
EXPECT_DOUBLE_EQ(max(one_, neg_pi_).Evaluate(), std::max(1.0, -3.141592));

EXPECT_DOUBLE_EQ(max(two_, pi_).Evaluate(), std::max(2.0, 3.141592));
EXPECT_DOUBLE_EQ(max(two_, one_).Evaluate(), std::max(2.0, 1.0));
EXPECT_DOUBLE_EQ(max(two_, two_).Evaluate(), std::max(2.0, 2.0));
EXPECT_DOUBLE_EQ(max(two_, zero_).Evaluate(), std::max(2.0, 0.0));
EXPECT_DOUBLE_EQ(max(two_, neg_one_).Evaluate(), std::max(2.0, -1.0));
EXPECT_DOUBLE_EQ(max(two_, neg_pi_).Evaluate(), std::max(2.0, -3.141592));

EXPECT_DOUBLE_EQ(max(zero_, pi_).Evaluate(), std::max(0.0, 3.141592));
EXPECT_DOUBLE_EQ(max(zero_, one_).Evaluate(), std::max(0.0, 1.0));
EXPECT_DOUBLE_EQ(max(zero_, two_).Evaluate(), std::max(0.0, 2.0));
EXPECT_DOUBLE_EQ(max(zero_, zero_).Evaluate(), std::max(0.0, 0.0));
EXPECT_DOUBLE_EQ(max(zero_, neg_one_).Evaluate(), std::max(0.0, -1.0));
EXPECT_DOUBLE_EQ(max(zero_, neg_pi_).Evaluate(), std::max(0.0, -3.141592));

EXPECT_DOUBLE_EQ(max(neg_one_, pi_).Evaluate(), std::max(-1.0, 3.141592));
EXPECT_DOUBLE_EQ(max(neg_one_, one_).Evaluate(), std::max(-1.0, 1.0));
EXPECT_DOUBLE_EQ(max(neg_one_, two_).Evaluate(), std::max(-1.0, 2.0));
EXPECT_DOUBLE_EQ(max(neg_one_, zero_).Evaluate(), std::max(-1.0, 0.0));
EXPECT_DOUBLE_EQ(max(neg_one_, neg_one_).Evaluate(), std::max(-1.0, -1.0));
EXPECT_DOUBLE_EQ(max(neg_one_, neg_pi_).Evaluate(),
std::max(-1.0, -3.141592));

EXPECT_DOUBLE_EQ(max(neg_pi_, pi_).Evaluate(), std::max(-3.141592, 3.141592));
EXPECT_DOUBLE_EQ(max(neg_pi_, one_).Evaluate(), std::max(-3.141592, 1.0));
EXPECT_DOUBLE_EQ(max(neg_pi_, two_).Evaluate(), std::max(-3.141592, 2.0));
EXPECT_DOUBLE_EQ(max(neg_pi_, zero_).Evaluate(), std::max(-3.141592, 0.0));
EXPECT_DOUBLE_EQ(max(neg_pi_, neg_one_).Evaluate(),
std::max(-3.141592, -1.0));
EXPECT_DOUBLE_EQ(max(neg_pi_, neg_pi_).Evaluate(),
std::max(-3.141592, -3.141592));

const Expression e{max(x_ * y_ * pi_, sin(x_) + sin(y_))};
const Environment env{{var_x_, 2}, {var_y_, 3.2}};
EXPECT_DOUBLE_EQ(e.Evaluate(env),
std::max(2 * 3.2 * 3.141592, std::sin(2) + std::sin(3.2)));
}

TEST_F(SymbolicExpressionTest, GetVariables) {
const Variables vars1{(x_ + y_ * log(x_ + y_)).GetVariables()};
EXPECT_TRUE(vars1.include(var_x_));
Expand Down

0 comments on commit 6708c2f

Please sign in to comment.