Skip to content

Commit

Permalink
[ad] Add support for arithmetic operations (RobotLocomotion#17622)
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri authored Jul 28, 2022
1 parent dc7af96 commit d5ce331
Show file tree
Hide file tree
Showing 34 changed files with 502 additions and 36 deletions.
8 changes: 7 additions & 1 deletion common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -645,10 +645,16 @@ drake_cc_googletest(
"test/autodiffxd_subtraction_test.cc",
"test/autodiffxd_tan_test.cc",
"test/autodiffxd_tanh_test.cc",
"test/autodiffxd_test.h",
],
copts = [
# The test fixture at //common/ad:standard_operations_test_h requires
# some configuration for the specific autodiff class to be tested.
"-DDRAKE_AUTODIFFXD_DUT=drake::AutoDiffXd",
"-DStandardOperationsTest=AutoDiffXdTest",
],
deps = [
":autodiff",
"//common/ad:standard_operations_test_h",
"//common/test_utilities:eigen_matrix_compare",
],
)
Expand Down
26 changes: 26 additions & 0 deletions common/ad/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,40 @@ drake_cc_googletest(
],
)

drake_cc_library(
name = "standard_operations_test_h",
testonly = True,
hdrs = ["test/standard_operations_test.h"],
visibility = [
# TODO(jwnimmer-tri) Once drake/common/autodiffxd.h is finally deleted,
# we can also delete its unit test and thus withdraw this visibility.
"//common:__pkg__",
],
deps = [
"//common/test_utilities:eigen_matrix_compare",
],
)

drake_cc_googletest(
name = "standard_operations_test",
# The test is split into multiple source files to improve compilation time.
srcs = [
"test/standard_operations_add_test.cc",
"test/standard_operations_dec_test.cc",
"test/standard_operations_div_test.cc",
"test/standard_operations_inc_test.cc",
"test/standard_operations_mul_test.cc",
"test/standard_operations_stream_test.cc",
"test/standard_operations_sub_test.cc",
],
copts = [
# The test fixture at requires some configuration for the specific
# autodiff class to be tested.
"-DDRAKE_AUTODIFFXD_DUT=drake::ad::AutoDiff",
],
deps = [
":auto_diff",
":standard_operations_test_h",
],
)

Expand Down
13 changes: 13 additions & 0 deletions common/ad/auto_diff.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ class AutoDiff {
return partials_.get_raw_storage_mutable();
}

/// @name Internal use only
//@{

/** (Internal use only)
Users should call derivatives() instead. */
const internal::Partials& partials() const { return partials_; }

/** (Internal use only)
Users should call derivatives() instead. */
internal::Partials& partials() { return partials_; }

//@}

private:
double value_{0.0};
internal::Partials partials_;
Expand Down
205 changes: 205 additions & 0 deletions common/ad/internal/standard_operations.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,220 @@
#pragma once

#include <cmath>
#include <iosfwd>

/* This file contains free function operators for Drake's AutoDiff type.
The functions provide arithmetic (+,-,*,/) for now and more to come later.
NOTE: This file should never be included directly, rather only from
auto_diff.h in a very specific order. */

namespace drake {
namespace ad {

/// @name Increment and decrement
///
/// https://en.cppreference.com/w/cpp/language/operators#Increment_and_decrement
//@{

/** Standard prefix increment operator (i.e., `++x`). */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator++(AutoDiff& x) {
++x.value();
return x;
}

/** Standard postfix increment operator (i.e., `x++`). */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff operator++(AutoDiff& x, int) {
AutoDiff result = x;
++x.value();
return result;
}

/** Standard prefix decrement operator (i.e., `--x`). */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator--(AutoDiff& x) {
--x.value();
return x;
}

/** Standard postfix decrement operator (i.e., `x--`). */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff operator--(AutoDiff& x, int) {
AutoDiff result = x;
--x.value();
return result;
}

//@}

/// @name Arithmetic operators
///
/// https://en.cppreference.com/w/cpp/language/operators#Binary_arithmetic_operators
//@{

/** Standard compound addition and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator+=(AutoDiff& a, const AutoDiff& b) {
// ∂/∂x a + b = a' + b'
a.partials().Add(b.partials());
a.value() += b.value();
return a;
}

/** Standard compound addition and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator+=(AutoDiff& a, double b) {
// ∂/∂x a + b = a' + b' == a'
a.value() += b;
return a;
}

/** Standard compound subtraction and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator-=(AutoDiff& a, const AutoDiff& b) {
// ∂/∂x a - b = a' - b'
a.partials().AddScaled(-1, b.partials());
a.value() -= b.value();
return a;
}

/** Standard compound subtraction and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator-=(AutoDiff& a, double b) {
// ∂/∂x a - b = a' - b' == a'
a.value() -= b;
return a;
}

/** Standard compound multiplication and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator*=(AutoDiff& a, const AutoDiff& b) {
// ∂/∂x a * b = ba' + ab'
a.partials().Mul(b.value());
a.partials().AddScaled(a.value(), b.partials());
a.value() *= b.value();
return a;
}

/** Standard compound multiplication and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator*=(AutoDiff& a, double b) {
// ∂/∂x a * b = ba' + ab' = ba'
a.partials().Mul(b);
a.value() *= b;
return a;
}

/** Standard compound division and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator/=(AutoDiff& a, const AutoDiff& b) {
// ∂/∂x a / b = (ba' - ab') / b²
a.partials().Mul(b.value());
a.partials().AddScaled(-a.value(), b.partials());
a.partials().Div(b.value() * b.value());
a.value() /= b.value();
return a;
}

/** Standard compound division and assignment operator. */
// NOLINTNEXTLINE(runtime/references) to match the required signature.
inline AutoDiff& operator/=(AutoDiff& a, double b) {
// ∂/∂x a / b = (ba' - ab') / b² = a'/b
a.partials().Div(b);
a.value() /= b;
return a;
}

/** Standard addition operator. */
inline AutoDiff operator+(AutoDiff a, const AutoDiff& b) {
a += b;
return a;
}

/** Standard addition operator. */
inline AutoDiff operator+(AutoDiff a, double b) {
a += b;
return a;
}

/** Standard addition operator. */
inline AutoDiff operator+(double a, AutoDiff b) {
b += a;
return b;
}

/** Standard unary plus operator. */
inline AutoDiff operator+(AutoDiff x) {
return x;
}

/** Standard subtraction operator. */
inline AutoDiff operator-(AutoDiff a, const AutoDiff& b) {
a -= b;
return a;
}

/** Standard subtraction operator. */
inline AutoDiff operator-(AutoDiff a, double b) {
a -= b;
return a;
}

/** Standard subtraction operator. */
inline AutoDiff operator-(double a, AutoDiff b) {
b *= -1;
b += a;
return b;
}

/** Standard unary minus operator. */
inline AutoDiff operator-(AutoDiff x) {
x *= -1;
return x;
}

/** Standard multiplication operator. */
inline AutoDiff operator*(AutoDiff a, const AutoDiff& b) {
a *= b;
return a;
}

/** Standard multiplication operator. */
inline AutoDiff operator*(AutoDiff a, double b) {
a *= b;
return a;
}

/** Standard multiplication operator. */
inline AutoDiff operator*(double a, AutoDiff b) {
b *= a;
return b;
}

/** Standard division operator. */
inline AutoDiff operator/(AutoDiff a, const AutoDiff& b) {
a /= b;
return a;
}

/** Standard division operator. */
inline AutoDiff operator/(AutoDiff a, double b) {
a /= b;
return a;
}

/** Standard division operator. */
inline AutoDiff operator/(double a, const AutoDiff& b) {
AutoDiff result{a};
result /= b;
return result;
}

//@}

/// @name Miscellaneous functions
//@{

Expand Down
33 changes: 33 additions & 0 deletions common/ad/test/standard_operations_add_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Addition) {
CHECK_BINARY_OP(+, x, y, 1.0);
CHECK_BINARY_OP(+, x, y, -1.0);
CHECK_BINARY_OP(+, y, x, 1.0);
CHECK_BINARY_OP(+, y, x, -1.0);
}

namespace {

// We need to wrap the operator under test, to give it a name.
// Eigen doesn't provide unary operator+, so we'll no-op instead.
AutoDiffDut unary_add(const AutoDiffDut& x) { return +x; }
AutoDiff3 unary_add(const AutoDiff3& x) { return x; }

} // namespace

TEST_F(StandardOperationsTest, UnaryAddition) {
CHECK_UNARY_FUNCTION(unary_add, x, y, 1.0);
CHECK_UNARY_FUNCTION(unary_add, x, y, -1.0);
CHECK_UNARY_FUNCTION(unary_add, y, x, 1.0);
CHECK_UNARY_FUNCTION(unary_add, y, x, -1.0);
}

} // namespace
} // namespace test
} // namespace drake
26 changes: 26 additions & 0 deletions common/ad/test/standard_operations_dec_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <sstream>

#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Postdecrement) {
AutoDiffDut x{0.25, 3, 0};
EXPECT_EQ((x--).value(), 0.25);
EXPECT_EQ(x.value(), -0.75);
EXPECT_EQ(x.derivatives()[0], 1.0);
}

TEST_F(StandardOperationsTest, Predecrement) {
AutoDiffDut x{0.25, 3, 0};
EXPECT_EQ((--x).value(), -0.75);
EXPECT_EQ(x.value(), -0.75);
EXPECT_EQ(x.derivatives()[0], 1.0);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_div_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Division) {
CHECK_BINARY_OP(/, x, y, 1.0);
CHECK_BINARY_OP(/, x, y, -1.0);
CHECK_BINARY_OP(/, y, x, 1.0);
CHECK_BINARY_OP(/, y, x, -1.0);
}

} // namespace
} // namespace test
} // namespace drake
Loading

0 comments on commit d5ce331

Please sign in to comment.