Skip to content

Commit

Permalink
[symbolic] Add ExpressionCell::EvaluatePartial for performance (Robot…
Browse files Browse the repository at this point in the history
…Locomotion#17727)

* [symbolic] Add ExpressionCell::EvaluatePartial for performance
  • Loading branch information
jwnimmer-tri authored Aug 24, 2022
1 parent 11d8989 commit f9fc847
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 11 deletions.
15 changes: 4 additions & 11 deletions common/symbolic/expression/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,10 @@ Eigen::SparseMatrix<double> Evaluate(
}

Expression Expression::EvaluatePartial(const Environment& env) const {
if (env.empty()) {
if (is_constant(*this) || env.empty()) {
return *this;
}
Substitution subst;
for (const pair<const Variable, double>& p : env) {
subst.emplace(p.first, p.second);
}
return Substitute(subst);
return cell().EvaluatePartial(env);
}

Expression Expression::Expand() const {
Expand Down Expand Up @@ -213,13 +209,10 @@ Expression Expression::Substitute(const Variable& var,
}

Expression Expression::Substitute(const Substitution& s) const {
if (is_constant(*this)) {
if (is_constant(*this) || s.empty()) {
return *this;
}
if (!s.empty()) {
return cell().Substitute(s);
}
return *this;
return cell().Substitute(s);
}

Expression Expression::Differentiate(const Variable& x) const {
Expand Down
138 changes: 138 additions & 0 deletions common/symbolic/expression/expression_cell.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ double ExpressionVar::Evaluate(const Environment& env) const {

Expression ExpressionVar::Expand() const { return Expression{var_}; }

Expression ExpressionVar::EvaluatePartial(const Environment& env) const {
const Environment::const_iterator it{env.find(var_)};
if (it != env.end()) {
return it->second;
}
return Expression{var_};
}

Expression ExpressionVar::Substitute(const Substitution& s) const {
const Substitution::const_iterator it{s.find(var_)};
if (it != s.end()) {
Expand Down Expand Up @@ -411,6 +419,10 @@ Expression ExpressionNaN::Expand() const {
throw runtime_error("NaN is detected during expansion.");
}

Expression ExpressionNaN::EvaluatePartial(const Environment&) const {
throw runtime_error("NaN is detected during environment substitution.");
}

Expression ExpressionNaN::Substitute(const Substitution&) const {
throw runtime_error("NaN is detected during substitution.");
}
Expand Down Expand Up @@ -513,6 +525,15 @@ Expression ExpressionAdd::Expand() const {
return fac.GetExpression();
}

Expression ExpressionAdd::EvaluatePartial(const Environment& env) const {
return accumulate(
expr_to_coeff_map_.begin(), expr_to_coeff_map_.end(),
Expression{constant_},
[&env](const Expression& init, const pair<const Expression, double>& p) {
return init + p.first.EvaluatePartial(env) * p.second;
});
}

Expression ExpressionAdd::Substitute(const Substitution& s) const {
return accumulate(
expr_to_coeff_map_.begin(), expr_to_coeff_map_.end(),
Expand Down Expand Up @@ -790,6 +811,16 @@ Expression ExpressionMul::Expand() const {
});
}

Expression ExpressionMul::EvaluatePartial(const Environment& env) const {
return accumulate(base_to_exponent_map_.begin(), base_to_exponent_map_.end(),
Expression{constant_},
[&env](const Expression& init,
const pair<const Expression, Expression>& p) {
return init * pow(p.first.EvaluatePartial(env),
p.second.EvaluatePartial(env));
});
}

Expression ExpressionMul::Substitute(const Substitution& s) const {
return accumulate(base_to_exponent_map_.begin(), base_to_exponent_map_.end(),
Expression{constant_},
Expand Down Expand Up @@ -1204,6 +1235,11 @@ Expression ExpressionDiv::Expand() const {
}
}

Expression ExpressionDiv::EvaluatePartial(const Environment& env) const {
return get_first_argument().EvaluatePartial(env) /
get_second_argument().EvaluatePartial(env);
}

Expression ExpressionDiv::Substitute(const Substitution& s) const {
return get_first_argument().Substitute(s) /
get_second_argument().Substitute(s);
Expand Down Expand Up @@ -1248,6 +1284,10 @@ Expression ExpressionLog::Expand() const {
return log(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionLog::EvaluatePartial(const Environment& env) const {
return log(get_argument().EvaluatePartial(env));
}

Expression ExpressionLog::Substitute(const Substitution& s) const {
return log(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1275,6 +1315,10 @@ Expression ExpressionAbs::Expand() const {
return abs(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionAbs::EvaluatePartial(const Environment& env) const {
return abs(get_argument().EvaluatePartial(env));
}

Expression ExpressionAbs::Substitute(const Substitution& s) const {
return abs(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1304,6 +1348,10 @@ Expression ExpressionExp::Expand() const {
return exp(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionExp::EvaluatePartial(const Environment& env) const {
return exp(get_argument().EvaluatePartial(env));
}

Expression ExpressionExp::Substitute(const Substitution& s) const {
return exp(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1337,6 +1385,10 @@ Expression ExpressionSqrt::Expand() const {
return sqrt(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionSqrt::EvaluatePartial(const Environment& env) const {
return sqrt(get_argument().EvaluatePartial(env));
}

Expression ExpressionSqrt::Substitute(const Substitution& s) const {
return sqrt(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1379,6 +1431,11 @@ Expression ExpressionPow::Expand() const {
e2.is_expanded() ? e2 : e2.Expand());
}

Expression ExpressionPow::EvaluatePartial(const Environment& env) const {
return pow(get_first_argument().EvaluatePartial(env),
get_second_argument().EvaluatePartial(env));
}

Expression ExpressionPow::Substitute(const Substitution& s) const {
return pow(get_first_argument().Substitute(s),
get_second_argument().Substitute(s));
Expand Down Expand Up @@ -1406,6 +1463,10 @@ Expression ExpressionSin::Expand() const {
return sin(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionSin::EvaluatePartial(const Environment& env) const {
return sin(get_argument().EvaluatePartial(env));
}

Expression ExpressionSin::Substitute(const Substitution& s) const {
return sin(get_argument().Substitute(s));
}
Expand All @@ -1430,6 +1491,10 @@ Expression ExpressionCos::Expand() const {
return cos(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionCos::EvaluatePartial(const Environment& env) const {
return cos(get_argument().EvaluatePartial(env));
}

Expression ExpressionCos::Substitute(const Substitution& s) const {
return cos(get_argument().Substitute(s));
}
Expand All @@ -1454,6 +1519,10 @@ Expression ExpressionTan::Expand() const {
return tan(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionTan::EvaluatePartial(const Environment& env) const {
return tan(get_argument().EvaluatePartial(env));
}

Expression ExpressionTan::Substitute(const Substitution& s) const {
return tan(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1487,6 +1556,10 @@ Expression ExpressionAsin::Expand() const {
return asin(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionAsin::EvaluatePartial(const Environment& env) const {
return asin(get_argument().EvaluatePartial(env));
}

Expression ExpressionAsin::Substitute(const Substitution& s) const {
return asin(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1523,6 +1596,10 @@ Expression ExpressionAcos::Expand() const {
return acos(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionAcos::EvaluatePartial(const Environment& env) const {
return acos(get_argument().EvaluatePartial(env));
}

Expression ExpressionAcos::Substitute(const Substitution& s) const {
return acos(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1550,6 +1627,10 @@ Expression ExpressionAtan::Expand() const {
return atan(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionAtan::EvaluatePartial(const Environment& env) const {
return atan(get_argument().EvaluatePartial(env));
}

Expression ExpressionAtan::Substitute(const Substitution& s) const {
return atan(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1577,6 +1658,11 @@ Expression ExpressionAtan2::Expand() const {
e2.is_expanded() ? e2 : e2.Expand());
}

Expression ExpressionAtan2::EvaluatePartial(const Environment& env) const {
return atan2(get_first_argument().EvaluatePartial(env),
get_second_argument().EvaluatePartial(env));
}

Expression ExpressionAtan2::Substitute(const Substitution& s) const {
return atan2(get_first_argument().Substitute(s),
get_second_argument().Substitute(s));
Expand Down Expand Up @@ -1607,6 +1693,10 @@ Expression ExpressionSinh::Expand() const {
return sinh(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionSinh::EvaluatePartial(const Environment& env) const {
return sinh(get_argument().EvaluatePartial(env));
}

Expression ExpressionSinh::Substitute(const Substitution& s) const {
return sinh(get_argument().Substitute(s));
}
Expand All @@ -1631,6 +1721,10 @@ Expression ExpressionCosh::Expand() const {
return cosh(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionCosh::EvaluatePartial(const Environment& env) const {
return cosh(get_argument().EvaluatePartial(env));
}

Expression ExpressionCosh::Substitute(const Substitution& s) const {
return cosh(get_argument().Substitute(s));
}
Expand All @@ -1655,6 +1749,10 @@ Expression ExpressionTanh::Expand() const {
return tanh(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionTanh::EvaluatePartial(const Environment& env) const {
return tanh(get_argument().EvaluatePartial(env));
}

Expression ExpressionTanh::Substitute(const Substitution& s) const {
return tanh(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1682,6 +1780,11 @@ Expression ExpressionMin::Expand() const {
e2.is_expanded() ? e2 : e2.Expand());
}

Expression ExpressionMin::EvaluatePartial(const Environment& env) const {
return min(get_first_argument().EvaluatePartial(env),
get_second_argument().EvaluatePartial(env));
}

Expression ExpressionMin::Substitute(const Substitution& s) const {
return min(get_first_argument().Substitute(s),
get_second_argument().Substitute(s));
Expand Down Expand Up @@ -1720,6 +1823,11 @@ Expression ExpressionMax::Expand() const {
e2.is_expanded() ? e2 : e2.Expand());
}

Expression ExpressionMax::EvaluatePartial(const Environment& env) const {
return max(get_first_argument().EvaluatePartial(env),
get_second_argument().EvaluatePartial(env));
}

Expression ExpressionMax::Substitute(const Substitution& s) const {
return max(get_first_argument().Substitute(s),
get_second_argument().Substitute(s));
Expand Down Expand Up @@ -1754,6 +1862,10 @@ Expression ExpressionCeiling::Expand() const {
return ceil(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionCeiling::EvaluatePartial(const Environment& env) const {
return ceil(get_argument().EvaluatePartial(env));
}

Expression ExpressionCeiling::Substitute(const Substitution& s) const {
return ceil(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1786,6 +1898,10 @@ Expression ExpressionFloor::Expand() const {
return floor(arg.is_expanded() ? arg : arg.Expand());
}

Expression ExpressionFloor::EvaluatePartial(const Environment& env) const {
return floor(get_argument().EvaluatePartial(env));
}

Expression ExpressionFloor::Substitute(const Substitution& s) const {
return floor(get_argument().Substitute(s));
}
Expand Down Expand Up @@ -1875,6 +1991,18 @@ Expression ExpressionIfThenElse::Expand() const {
throw runtime_error("Not yet implemented.");
}

Expression ExpressionIfThenElse::EvaluatePartial(const Environment& env) const {
// TODO(jwnimmer-tri) We could define a Formula::EvaluatePartial for improved
// performance, if necessary.
Substitution subst;
for (const pair<const Variable, double>& p : env) {
subst.emplace(p.first, p.second);
}
return if_then_else(f_cond_.Substitute(subst),
e_then_.EvaluatePartial(env),
e_else_.EvaluatePartial(env));
}

Expression ExpressionIfThenElse::Substitute(const Substitution& s) const {
return if_then_else(f_cond_.Substitute(s), e_then_.Substitute(s),
e_else_.Substitute(s));
Expand Down Expand Up @@ -1984,6 +2112,16 @@ Expression ExpressionUninterpretedFunction::Expand() const {
return uninterpreted_function(name_, std::move(new_arguments));
}

Expression ExpressionUninterpretedFunction::EvaluatePartial(
const Environment& env) const {
vector<Expression> new_arguments;
new_arguments.reserve(arguments_.size());
for (const Expression& arg : arguments_) {
new_arguments.push_back(arg.EvaluatePartial(env));
}
return uninterpreted_function(name_, std::move(new_arguments));
}

Expression ExpressionUninterpretedFunction::Substitute(
const Substitution& s) const {
vector<Expression> new_arguments;
Expand Down
Loading

0 comments on commit f9fc847

Please sign in to comment.