forked from sammy-tri/drake
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsymbolic_expansion_test.cc
345 lines (292 loc) · 13.4 KB
/
symbolic_expansion_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
#include <algorithm>
#include <cmath>
#include <functional>
#include <stdexcept>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "drake/common/symbolic.h"
#include "drake/common/test_utilities/limit_malloc.h"
#include "drake/common/test_utilities/symbolic_test_util.h"
using std::function;
using std::pair;
using std::runtime_error;
using std::vector;
namespace drake {
using test::LimitMalloc;
namespace symbolic {
namespace {
using test::ExprEqual;
using test::ExprNotEqual;
class SymbolicExpansionTest : public ::testing::Test {
protected:
const Variable var_x_{"x"};
const Variable var_y_{"y"};
const Variable var_z_{"z"};
const Expression x_{var_x_};
const Expression y_{var_y_};
const Expression z_{var_z_};
vector<Environment> envs_;
void SetUp() override {
// Set up environments (envs_).
envs_.push_back({{var_x_, 1.7}, {var_y_, 2}, {var_z_, 2.3}}); // + + +
envs_.push_back({{var_x_, -0.3}, {var_y_, 1}, {var_z_, 0.2}}); // - + +
envs_.push_back({{var_x_, 1.4}, {var_y_, -2}, {var_z_, 3.1}}); // + - +
envs_.push_back({{var_x_, 2.2}, {var_y_, 4}, {var_z_, -2.3}}); // + + -
envs_.push_back({{var_x_, -4.7}, {var_y_, -3}, {var_z_, 3.4}}); // - - +
envs_.push_back({{var_x_, 3.1}, {var_y_, -3}, {var_z_, -2.5}}); // + - -
envs_.push_back({{var_x_, -2.8}, {var_y_, 2}, {var_z_, -2.6}}); // _ + -
envs_.push_back({{var_x_, -2.2}, {var_y_, -4}, {var_z_, -2.3}}); // - - -
}
// Check if both e and e.Expand() are evaluated to the close-enough (<eps)
// values under all symbolic environments in envs_.
bool CheckExpandPreserveEvaluation(const Expression& e, const double eps) {
return all_of(envs_.begin(), envs_.end(), [&](const Environment& env) {
return std::fabs(e.Evaluate(env) - e.Expand().Evaluate(env)) < eps;
});
}
// Checks if e == e.Expand().
bool CheckAlreadyExpanded(const Expression& e) {
return e.EqualTo(e.Expand());
}
// Checks if e.Expand() == e.Expand().Expand().
bool CheckExpandIsFixpoint(const Expression& e) {
return e.Expand().EqualTo(e.Expand().Expand());
}
};
// Constants, monomials, and sums of monomials are already in expanded form;
// Expand() must return a structurally equal expression for each of them.
TEST_F(SymbolicExpansionTest, ExpressionAlreadyExpandedPolynomial) {
  // The following are all already expanded.
  EXPECT_TRUE(CheckAlreadyExpanded(0));
  EXPECT_TRUE(CheckAlreadyExpanded(1));
  EXPECT_TRUE(CheckAlreadyExpanded(-1));
  EXPECT_TRUE(CheckAlreadyExpanded(42));
  EXPECT_TRUE(CheckAlreadyExpanded(-5));
  EXPECT_TRUE(CheckAlreadyExpanded(x_));
  EXPECT_TRUE(CheckAlreadyExpanded(-x_));
  EXPECT_TRUE(CheckAlreadyExpanded(3 * x_));
  EXPECT_TRUE(CheckAlreadyExpanded(-2 * x_));
  EXPECT_TRUE(CheckAlreadyExpanded(3 * x_ * y_));            // 3xy
  EXPECT_TRUE(CheckAlreadyExpanded(3 * pow(x_, 2) * y_));    // 3x^2y
  // The numerator must be floating-point: the integer expression `3 / 10`
  // evaluates to 0 and would reduce this case to the trivial constant 0.
  EXPECT_TRUE(CheckAlreadyExpanded(3.0 / 10 * pow(x_, 2) * y_));  // 3/10*x^2y
  EXPECT_TRUE(CheckAlreadyExpanded(-7 + x_ + y_));           // -7 + x + y
  EXPECT_TRUE(CheckAlreadyExpanded(1 + 3 * x_ - 4 * y_));    // 1 + 3x -4y
}
// Powers whose exponent is symbolic, negative, or non-integer are already in
// expanded form; Expand() must return a structurally equal expression.
TEST_F(SymbolicExpansionTest, ExpressionAlreadyExpandedPow) {
  // The following are all already expanded.
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_, y_)));              // x^y
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_, -1)));              // x^(-1)
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_ + y_, -1)));         // (x + y)^(-1)
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_ + y_, 0.5)));        // (x + y)^(0.5)
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_ + y_, 2.5)));        // (x + y)^(2.5)
  EXPECT_TRUE(CheckAlreadyExpanded(pow(x_ + y_, x_ - y_)));    // (x + y)^(x - y)
}
TEST_F(SymbolicExpansionTest, ExpressionExpansion) {
  // Each entry pairs an input expression with its expected expansion. For
  // every entry this test verifies that:
  //  1. input.Expand() is structurally equal to the expected expansion and is
  //     flagged as expanded;
  //  2. input and input.Expand() evaluate to the same value (within 1e-8)
  //     under every environment in envs_;
  //  3. the expansion is a fixpoint of Expand(): an expanded expression is
  //     not expanded any further.
  const vector<pair<Expression, Expression>> cases{
      // (2xy²)² = 4x²y⁴
      {pow(2 * x_ * y_ * y_, 2), 4 * pow(x_, 2) * pow(y_, 4)},
      // 5 * (3 + 2y) + 30 * (7 + x)
      //   = 15 + 10y + 210 + 30x
      //   = 225 + 30x + 10y
      {5 * (3 + 2 * y_) + 30 * (7 + x_), 225 + 30 * x_ + 10 * y_},
      // (x + 3y) * (2x + 5y) = 2x^2 + 11xy + 15y^2
      {(x_ + 3 * y_) * (2 * x_ + 5 * y_),
       2 * pow(x_, 2) + 11 * x_ * y_ + 15 * pow(y_, 2)},
      // (7 + x) * (5 + y) * (6 + z)
      //   = (35 + 5x + 7y + xy) * (6 + z)
      //   = (210 + 30x + 42y + 6xy) + (35z + 5xz + 7yz + xyz)
      {(7 + x_) * (5 + y_) * (6 + z_),
       210 + 30 * x_ + 42 * y_ + 6 * x_ * y_ + 35 * z_ + 5 * x_ * z_ +
           7 * y_ * z_ + x_ * y_ * z_},
      // (x + 3y) * (2x + 5y) * (x + 3y)
      //   = (2x^2 + 11xy + 15y^2) * (x + 3y)
      //   = 2x^3 + 11x^2y + 15xy^2 + 6x^2y + 33xy^2 + 45y^3
      //   = 2x^3 + 17x^2y + 48xy^2 + 45y^3
      {(x_ + 3 * y_) * (2 * x_ + 5 * y_) * (x_ + 3 * y_),
       2 * pow(x_, 3) + 17 * pow(x_, 2) * y_ + 48 * x_ * pow(y_, 2) +
           45 * pow(y_, 3)},
      // pow((x + y)^2 + 1, (x - y)^2)
      //   = pow(x^2 + 2xy + y^2 + 1, x^2 - 2xy + y^2)
      // Both the base and the exponent of pow are expanded.
      {pow(pow(x_ + y_, 2) + 1, pow(x_ - y_, 2)),
       pow(pow(x_, 2) + 2 * x_ * y_ + pow(y_, 2) + 1,
           pow(x_, 2) - 2 * x_ * y_ + pow(y_, 2))},
      // (x + y + 1)^3
      //   = x^3 + 3x^2y + 3x^2 + 3xy^2 + 6xy + 3x + y^3 + 3y^2 + 3y + 1
      {pow(x_ + y_ + 1, 3),
       pow(x_, 3) + 3 * pow(x_, 2) * y_ + 3 * pow(x_, 2) +
           3 * x_ * pow(y_, 2) + 6 * x_ * y_ + 3 * x_ + pow(y_, 3) +
           3 * pow(y_, 2) + 3 * y_ + 1},
      // (x + y + 1)^4
      //   = 1 + 4x + 4y + 12xy + 12xy^2 + 4xy^3 + 12x^2y + 6x^2y^2 + 4x^3y
      //     + 6x^2 + 4x^3 + x^4 + 6y^2 + 4y^3 + y^4
      {pow(x_ + y_ + 1, 4),
       1 + 4 * x_ + 4 * y_ + 12 * x_ * y_ + 12 * x_ * pow(y_, 2) +
           4 * x_ * pow(y_, 3) + 12 * pow(x_, 2) * y_ +
           6 * pow(x_, 2) * pow(y_, 2) + 4 * pow(x_, 3) * y_ +
           6 * pow(x_, 2) + 4 * pow(x_, 3) + pow(x_, 4) + 6 * pow(y_, 2) +
           4 * pow(y_, 3) + pow(y_, 4)},
  };

  for (const auto& test_case : cases) {
    const Expression& input = test_case.first;
    const Expression& expected = test_case.second;
    const Expression actual = input.Expand();
    EXPECT_PRED2(ExprEqual, actual, expected);
    EXPECT_TRUE(actual.is_expanded());
    EXPECT_TRUE(CheckExpandPreserveEvaluation(input, 1e-8));
    EXPECT_TRUE(CheckExpandIsFixpoint(input));
  }
}
TEST_F(SymbolicExpansionTest, MathFunctions) {
  // Expansion must commute with every math function f. For any expression
  // e we require
  //
  //   f(e).Expand() == f(e.Expand())
  //
  // where '==' denotes structural equality (Expression::EqualTo).
  using F = function<Expression(const Expression&)>;
  const vector<F> contexts{
      [](const Expression& x) { return log(x); },
      [](const Expression& x) { return abs(x); },
      [](const Expression& x) { return exp(x); },
      [](const Expression& x) { return sqrt(x); },
      [](const Expression& x) { return sin(x); },
      [](const Expression& x) { return cos(x); },
      [](const Expression& x) { return tan(x); },
      [](const Expression& x) { return asin(x); },
      [](const Expression& x) { return acos(x); },
      [](const Expression& x) { return atan(x); },
      [](const Expression& x) { return sinh(x); },
      [](const Expression& x) { return cosh(x); },
      [](const Expression& x) { return tanh(x); },
      [&](const Expression& x) { return min(x, y_); },
      [&](const Expression& x) { return min(y_, x); },
      [&](const Expression& x) { return max(x, z_); },
      [&](const Expression& x) { return max(z_, x); },
      [](const Expression& x) { return ceil(x); },
      [](const Expression& x) { return floor(x); },
      [&](const Expression& x) { return atan2(x, y_); },
      [&](const Expression& x) { return atan2(y_, x); },
  };

  // Arguments that are not yet expanded, so that expansion has real work to
  // do both inside and outside of f.
  const vector<Expression> arguments{
      5 * (3 + 2 * y_) + 30 * (7 + x_),
      (x_ + 3 * y_) * (2 * x_ + 5 * y_),
      (7 + x_) * (5 + y_) * (6 + z_),
      (x_ + 3 * y_) * (2 * x_ + 5 * y_) * (x_ + 3 * y_),
      pow(pow(x_ + y_, 2) + 1, pow(x_ - y_, 2)),
      pow(x_ + y_ + 1, 3),
      pow(x_ + y_ + 1, 4),
  };

  for (const F& context : contexts) {
    for (const Expression& arg : arguments) {
      const Expression expanded_outside{context(arg).Expand()};
      const Expression expanded_inside{context(arg.Expand())};
      EXPECT_PRED2(ExprEqual, expanded_outside, expanded_inside);
      EXPECT_TRUE(expanded_outside.is_expanded());
      EXPECT_TRUE(expanded_inside.is_expanded());
      EXPECT_TRUE(CheckAlreadyExpanded(expanded_outside));
    }
  }
}
TEST_F(SymbolicExpansionTest, NaN) {
  const Expression nan{Expression::NaN()};
  // NaN reports itself as not expanded, so Expand() does dispatch to
  // ExpressionNaN::Expand(), which throws.
  EXPECT_FALSE(nan.is_expanded());
  // Expanding NaN must be detected and reported as a runtime_error.
  Expression sink;
  EXPECT_THROW(sink = nan.Expand(), runtime_error);
}
TEST_F(SymbolicExpansionTest, IfThenElse) {
  const Expression ite{
      if_then_else(x_ > y_, pow(x_ + y_, 2), pow(x_ - y_, 2))};
  // Expanding an if-then-else expression is not supported and must throw.
  Expression sink;
  EXPECT_THROW(sink = ite.Expand(), runtime_error);
  // if-then-else reports itself as not expanded, which is what routes the
  // call above to ExpressionIfThenElse::Expand() and its exception.
  EXPECT_FALSE(ite.is_expanded());
}
TEST_F(SymbolicExpansionTest, UninterpretedFunction) {
  // With no arguments there is nothing to expand: Expand() is the identity.
  const Expression nullary{uninterpreted_function("uf1", {})};
  const Expression nullary_expanded{nullary.Expand()};
  EXPECT_PRED2(ExprEqual, nullary, nullary_expanded);
  EXPECT_TRUE(nullary_expanded.is_expanded());

  // With unexpanded arguments, Expand() expands each argument in place.
  const Expression arg1{3 * (x_ + y_)};
  const Expression arg2{pow(x_ + y_, 2)};
  const Expression binary{uninterpreted_function("uf2", {arg1, arg2})};
  const Expression binary_expanded{binary.Expand()};
  EXPECT_PRED2(ExprNotEqual, binary, binary_expanded);
  EXPECT_TRUE(binary_expanded.is_expanded());
  const Expression expected{
      uninterpreted_function("uf2", {arg1.Expand(), arg2.Expand()})};
  EXPECT_PRED2(ExprEqual, binary_expanded, expected);
}
// Expand() folds division by a constant into the numeric coefficients. A
// division by a symbolic expression is kept as is (see the last case: x is
// not cancelled because it may be zero).
TEST_F(SymbolicExpansionTest, DivideByConstant) {
  // (x) / 2 => 0.5 * x
  EXPECT_PRED2(ExprEqual, (x_ / 2).Expand(), 0.5 * x_);
  // 3 / 2 => 1.5 (both operands are constants)
  EXPECT_PRED2(ExprEqual, (Expression(3.0) / 2).Expand(), 3.0 / 2);
  // pow(x, y) / 2 => 0.5 * pow(x, y)
  EXPECT_PRED2(ExprEqual, (pow(x_, y_) / 2).Expand(), 0.5 * pow(x_, y_));
  // (2x) / 2 => x
  EXPECT_PRED2(ExprEqual, ((2 * x_) / 2).Expand(), x_);
  // (10x / 5) / 2 => x
  EXPECT_PRED2(ExprEqual, (10 * x_ / 5 / 2).Expand(), x_);
  // (10x²y³z⁴) / -5 => -2x²y³z⁴
  EXPECT_PRED2(ExprEqual,
               (10 * pow(x_, 2) * pow(y_, 3) * pow(z_, 4) / -5).Expand(),
               -2 * pow(x_, 2) * pow(y_, 3) * pow(z_, 4));
  // (36xy / 4 / -3) => -3xy
  EXPECT_PRED2(ExprEqual, (36 * x_ * y_ / 4 / -3).Expand(), -3 * x_ * y_);
  // (2x + 4xy + 6) / 2 => x + 2xy + 3
  EXPECT_PRED2(ExprEqual, ((2 * x_ + 4 * x_ * y_ + 6) / 2).Expand(),
               x_ + 2 * x_ * y_ + 3);
  // (4x / 3) * (6y / 2) => 4xy
  EXPECT_PRED2(ExprEqual, ((4 * x_ / 3) * (6 * y_ / 2)).Expand(), 4 * x_ * y_);
  // (6xy / z / 3) => 2xy / z
  EXPECT_PRED2(ExprEqual, (6 * x_ * y_ / z_ / 3).Expand(), 2 * x_ * y_ / z_);
  // (36xy / x / -3) => -12xy / x
  // Note that we do not cancel x out since it can be zero.
  EXPECT_PRED2(ExprEqual, (36 * x_ * y_ / x_ / -3).Expand(),
               -12 * x_ * y_ / x_);
}
// Expanding an already-expanded expression should not build a new expression.
// The test checks this by placing the second Expand() call under a LimitMalloc
// guard, so that any memory allocation is detected.
TEST_F(SymbolicExpansionTest, RepeatedExpandShouldBeNoop) {
const Expression e{(x_ + y_) * (x_ + y_)};
// New ExpressionCells are created here.
const Expression e_expanded{e.Expand()};
{
// This inner scope bounds the lifetime of `guard`. Presumably LimitMalloc only
// checks allocations while it is alive; TODO confirm against limit_malloc.h.
LimitMalloc guard;
// e_expanded is already expanded, so the following line should not create a
// new cell and no memory allocation should occur. We use LimitMalloc to
// check this claim.
const Expression e_expanded_expanded = e_expanded.Expand();
EXPECT_PRED2(ExprEqual, e_expanded, e_expanded_expanded);
}
}
TEST_F(SymbolicExpansionTest, ExpandMultiplicationsWithDivisions) {
  // Products whose factors contain divisions. For each one, Expand() must
  // reach a fixpoint and must not change the value of the expression under
  // any environment in envs_.
  const vector<Expression> products{
      ((x_ + 1) / y_) * (x_ + 3),
      (x_ + 3) * ((x_ + 1) / z_),
      ((x_ + 1) / (y_ + 6)) * ((x_ + 3) / (z_ + 7)),
      (x_ + y_ / ((z_ + x_) * (y_ + x_))) * (x_ - y_ / z_) * (x_ * y_ / z_),
  };
  for (const Expression& product : products) {
    EXPECT_TRUE(CheckExpandIsFixpoint(product));
    EXPECT_TRUE(CheckExpandPreserveEvaluation(product, 1e-8));
  }
}
} // namespace
} // namespace symbolic
} // namespace drake