Skip to content

Commit

Permalink
[ad] Add support for comparison operations (RobotLocomotion#17636)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri authored Aug 1, 2022
1 parent 39e74ce commit 8246466
Show file tree
Hide file tree
Showing 8 changed files with 418 additions and 10 deletions.
6 changes: 6 additions & 0 deletions common/ad/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,16 @@ drake_cc_googletest(
name = "standard_operations_test",
# The test is split into multiple source files to improve compilation time.
srcs = [
"test/standard_operations_abs2_test.cc",
"test/standard_operations_abs_test.cc",
"test/standard_operations_add_test.cc",
"test/standard_operations_cmp_test.cc",
"test/standard_operations_dec_test.cc",
"test/standard_operations_div_test.cc",
"test/standard_operations_inc_test.cc",
"test/standard_operations_integer_test.cc",
"test/standard_operations_max_test.cc",
"test/standard_operations_min_test.cc",
"test/standard_operations_mul_test.cc",
"test/standard_operations_stream_test.cc",
"test/standard_operations_sub_test.cc",
Expand All @@ -82,6 +87,7 @@ drake_cc_googletest(
# autodiff class to be tested.
"-DDRAKE_AUTODIFFXD_DUT=drake::ad::AutoDiff",
],
shard_count = 16,
deps = [
":auto_diff",
":standard_operations_test_h",
Expand Down
201 changes: 200 additions & 1 deletion common/ad/internal/standard_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@

/* This file contains free function operators for Drake's AutoDiff type.
The functions provide arithmetic (+,-,*,/) for now and more to come later.
The functions provide not only arithmetic (+,-,*,/) and boolean comparison
(<,<=,>,>=,==,!=) but also argument-dependent lookup ("ADL") compatiblity with
the standard library's mathematical functions (abs, etc.)
(See https://en.cppreference.com/w/cpp/language/adl for details about ADL.)
A few functions for Eigen::numext are also added to argument-dependent lookup.
Functions that cannot preserve gradients will return a primitive type (`bool`
or `double`) instead of an AutoDiff.
NOTE: This file should never be included directly, rather only from
auto_diff.h in a very specific order. */
Expand Down Expand Up @@ -227,6 +236,196 @@ inline AutoDiff operator/(double a, const AutoDiff& b) {

//@}

/// @name Comparison operators
///
/// https://en.cppreference.com/w/cpp/language/operators#Comparison_operators
//@{

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<(const AutoDiff& a, const AutoDiff& b) {
return a.value() < b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<=(const AutoDiff& a, const AutoDiff& b) {
return a.value() <= b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>(const AutoDiff& a, const AutoDiff& b) {
return a.value() > b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>=(const AutoDiff& a, const AutoDiff& b) {
return a.value() >= b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator==(const AutoDiff& a, const AutoDiff& b) {
return a.value() == b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator!=(const AutoDiff& a, const AutoDiff& b) {
return a.value() != b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<(const AutoDiff& a, double b) {
return a.value() < b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<=(const AutoDiff& a, double b) {
return a.value() <= b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>(const AutoDiff& a, double b) {
return a.value() > b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>=(const AutoDiff& a, double b) {
return a.value() >= b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator==(const AutoDiff& a, double b) {
return a.value() == b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator!=(const AutoDiff& a, double b) {
return a.value() != b;
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<(double a, const AutoDiff& b) {
return a < b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator<=(double a, const AutoDiff& b) {
return a <= b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>(double a, const AutoDiff& b) {
return a > b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator>=(double a, const AutoDiff& b) {
return a >= b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator==(double a, const AutoDiff& b) {
return a == b.value();
}

/** Standard comparison operator. Discards the derivatives. */
inline bool operator!=(double a, const AutoDiff& b) {
return a != b.value();
}

//@}

/// @name Minimum/maximum operations
///
/// https://en.cppreference.com/w/cpp/algorithm#Minimum.2Fmaximum_operations
//@{

/** ADL overload to mimic std::max from <algorithm>.
Note that like std::max, this function returns a reference to whichever
argument was chosen; it does not make a copy. When `a` and `b` are equal,
retains the derivatives of `a` (by returning `a`) unless `a` has empty
derivatives, in which case `b` is returned. */
inline const AutoDiff& max(const AutoDiff& a, const AutoDiff& b) {
if (a.value() == b.value()) {
return a.derivatives().size() > 0 ? a : b;
}
return a.value() < b.value() ? b : a;
}

/** ADL overload to mimic std::max from <algorithm>.
When `a` and `b` are equal, retains the derivatives of `a`. */
inline AutoDiff max(AutoDiff a, double b) {
if (a.value() < b) {
a = b;
}
return a;
}

/** ADL overload to mimic std::max from <algorithm>.
When `a` and `b` are equal, retains the derivatives of `b`. */
inline AutoDiff max(double a, AutoDiff b) {
if (a < b.value()) {
return b;
}
b = a;
return b;
}

/** ADL overload to mimic std::min from <algorithm>.
Note that like std::min, this function returns a reference to whichever
argument was chosen; it does not make a copy. When `a` and `b` are equal,
retains the derivatives of `a` (by returning `a`) unless `a` has empty
derivatives, in which case `b` is returned. */
inline const AutoDiff& min(const AutoDiff& a, const AutoDiff& b) {
if (a.value() == b.value()) {
return a.derivatives().size() > 0 ? a : b;
}
return b.value() < a.value() ? b : a;
}

/** ADL overload to mimic std::min from <algorithm>.
When `a` and `b` are equal, retains the derivatives of `a`. */
inline AutoDiff min(AutoDiff a, double b) {
if (b < a.value()) {
a = b;
}
return a;
}

/** ADL overload to mimic std::min from <algorithm>.
When `a` and `b` are equal, retains the derivatives of `b`. */
// NOLINTNEXTLINE(build/include_what_you_use) false positive.
inline AutoDiff min(double a, AutoDiff b) {
if (a < b.value()) {
b = a;
}
return b;
}

//@}

/// @name Math functions: Basic operations
///
/// https://en.cppreference.com/w/cpp/numeric/math#Basic_operations
//@{

/** ADL overload to mimic std::abs from <cmath>. */
inline AutoDiff abs(AutoDiff x) {
// Conditionally negate negative numbers.
if (x.value() < 0) {
x *= -1;
}
return x;
}

/** ADL overload to mimic Eigen::numext::abs2. */
inline AutoDiff abs2(AutoDiff x) {
// ∂/∂x x² = 2x
x.partials().Mul(2 * x.value());
x.value() *= x.value();
return x;
}

//@}

/// @name Math functions: Nearest integer floating point operations
///
/// https://en.cppreference.com/w/cpp/numeric/math#Nearest_integer_floating_point_operations
Expand Down
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_abs2_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Abs2) {
CHECK_UNARY_FUNCTION(abs2, x, y, 0.1);
CHECK_UNARY_FUNCTION(abs2, x, y, -0.1);
CHECK_UNARY_FUNCTION(abs2, y, x, 0.1);
CHECK_UNARY_FUNCTION(abs2, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_abs_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Abs) {
CHECK_UNARY_FUNCTION(abs, x, y, 1.1);
CHECK_UNARY_FUNCTION(abs, x, y, -1.1);
CHECK_UNARY_FUNCTION(abs, y, x, 1.1);
CHECK_UNARY_FUNCTION(abs, y, x, -1.1);
}

} // namespace
} // namespace test
} // namespace drake
49 changes: 49 additions & 0 deletions common/ad/test/standard_operations_cmp_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include <sstream>

#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

#define DRAKE_CHECK_CMP(cmp) \
EXPECT_EQ(0 cmp 1, AutoDiffDut{0} cmp AutoDiffDut{1}); \
EXPECT_EQ(1 cmp 0, AutoDiffDut{1} cmp AutoDiffDut{0}); \
EXPECT_EQ(1 cmp 1, AutoDiffDut{1} cmp AutoDiffDut{1}); \
EXPECT_EQ(0 cmp 1, AutoDiffDut{0} cmp 1); /* NOLINT */ \
EXPECT_EQ(1 cmp 0, AutoDiffDut{1} cmp 0); /* NOLINT */ \
EXPECT_EQ(1 cmp 1, AutoDiffDut{1} cmp 1); /* NOLINT */ \
EXPECT_EQ(0 cmp 1, 0 cmp AutoDiffDut{1}); \
EXPECT_EQ(1 cmp 0, 1 cmp AutoDiffDut{0}); \
EXPECT_EQ(1 cmp 1, 1 cmp AutoDiffDut{1})

TEST_F(StandardOperationsTest, CmpLt) {
DRAKE_CHECK_CMP(<); // NOLINT
}

TEST_F(StandardOperationsTest, CmpLe) {
DRAKE_CHECK_CMP(<=);
}

TEST_F(StandardOperationsTest, CmpGt) {
DRAKE_CHECK_CMP(>); // NOLINT
}

TEST_F(StandardOperationsTest, CmpGe) {
DRAKE_CHECK_CMP(>=);
}

TEST_F(StandardOperationsTest, CmpEq) {
DRAKE_CHECK_CMP(==);
}

TEST_F(StandardOperationsTest, CmpNe) {
DRAKE_CHECK_CMP(!=);
}

#undef DRAKE_CHECK_CMP

} // namespace
} // namespace test
} // namespace drake
53 changes: 53 additions & 0 deletions common/ad/test/standard_operations_max_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, MaxBothAds) {
CHECK_BINARY_FUNCTION_ADS_ADS(max, x, y, 0.3);
CHECK_BINARY_FUNCTION_ADS_ADS(max, x, y, -0.3);
CHECK_BINARY_FUNCTION_ADS_ADS(max, y, x, 0.4);
CHECK_BINARY_FUNCTION_ADS_ADS(max, y, x, -0.4);
}

TEST_F(StandardOperationsTest, MaxLhsAds) {
CHECK_BINARY_FUNCTION_ADS_SCALAR(max, x, y, 0.3);
CHECK_BINARY_FUNCTION_ADS_SCALAR(max, x, y, -0.3);
CHECK_BINARY_FUNCTION_ADS_SCALAR(max, y, x, 0.4);
CHECK_BINARY_FUNCTION_ADS_SCALAR(max, y, x, -0.4);
}

TEST_F(StandardOperationsTest, MaxRhsAds) {
CHECK_BINARY_FUNCTION_SCALAR_ADS(max, x, y, 0.3);
CHECK_BINARY_FUNCTION_SCALAR_ADS(max, x, y, -0.3);
CHECK_BINARY_FUNCTION_SCALAR_ADS(max, y, x, 0.4);
CHECK_BINARY_FUNCTION_SCALAR_ADS(max, y, x, -0.4);
}

TEST_F(StandardOperationsTest, TieBreakingCheckMaxBothNonEmpty) {
// Given `max(v1, v2)`, our overload returns the first argument `v1` when
// `v1 == v2` holds if both `v1` and `v2` have non-empty derivatives. In
// Drake, we rely on this implementation-detail. This test checks if the
// property holds so that we can detect a possible change in future.
const AutoDiffDut v1{1.0, Vector1<double>(3.)};
const AutoDiffDut v2{1.0, Vector1<double>(2.)};
EXPECT_EQ(max(v1, v2).derivatives()[0], 3.0); // Returns v1, not v2.
}

TEST_F(StandardOperationsTest, TieBreakingCheckMaxOneNonEmpty) {
// Given `max(v1, v2)`, our overload returns whichever argument has non-empty
// derivatives in the case where only one has non-empty derivatives. In
// Drake, we rely on this implementation-detail. This test checks if the
// property holds so that we can detect a possible change in future.
const AutoDiffDut v1{1.0};
const AutoDiffDut v2{1.0, Vector1<double>(2.)};
EXPECT_TRUE(CompareMatrices(min(v1, v2).derivatives(),
Vector1<double>(2.))); // Returns v2, not v1.
EXPECT_TRUE(CompareMatrices(min(v2, v1).derivatives(),
Vector1<double>(2.))); // Returns v2, not v1.
}
} // namespace
} // namespace test
} // namespace drake
Loading

0 comments on commit 8246466

Please sign in to comment.