Skip to content

Commit

Permalink
[ad] Add support for trig and hyperbolic operations (RobotLocomotion#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri authored Aug 2, 2022
1 parent ff2554b commit 6fe3aba
Show file tree
Hide file tree
Showing 14 changed files with 394 additions and 7 deletions.
10 changes: 10 additions & 0 deletions common/ad/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,27 @@ drake_cc_googletest(
srcs = [
"test/standard_operations_abs2_test.cc",
"test/standard_operations_abs_test.cc",
"test/standard_operations_acos_test.cc",
"test/standard_operations_add_test.cc",
"test/standard_operations_asin_test.cc",
"test/standard_operations_atan2_test.cc",
"test/standard_operations_atan_test.cc",
"test/standard_operations_cmp_test.cc",
"test/standard_operations_cos_test.cc",
"test/standard_operations_cosh_test.cc",
"test/standard_operations_dec_test.cc",
"test/standard_operations_div_test.cc",
"test/standard_operations_inc_test.cc",
"test/standard_operations_integer_test.cc",
"test/standard_operations_max_test.cc",
"test/standard_operations_min_test.cc",
"test/standard_operations_mul_test.cc",
"test/standard_operations_sin_test.cc",
"test/standard_operations_sinh_test.cc",
"test/standard_operations_stream_test.cc",
"test/standard_operations_sub_test.cc",
"test/standard_operations_tan_test.cc",
"test/standard_operations_tanh_test.cc",
],
copts = [
# The test fixture at requires some configuration for the specific
Expand Down
131 changes: 131 additions & 0 deletions common/ad/internal/standard_operations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,137 @@
namespace drake {
namespace ad {

AutoDiff sin(AutoDiff x) {
const double new_value = std::sin(x.value());
// ∂/∂x sin(x) = cos(x)
// The domain of sin & cos are identical; no need for special cases here.
x.partials().Mul(std::cos(x.value()));
x.value() = new_value;
return x;
}

AutoDiff cos(AutoDiff x) {
const double new_value = std::cos(x.value());
// ∂/∂x cos(x) = -sin(x)
// The domain of cos & sin are identical; no need for special cases here.
x.partials().Mul(-std::sin(x.value()));
x.value() = new_value;
return x;
}

AutoDiff tan(AutoDiff x) {
const double new_value = std::tan(x.value());
// ∂/∂x tan(x) = sec²(x) = cos⁻²(x)
// The mathematical tan() function has poles at (½ + n)π; however no common
// floating-point representation is able to represent π/2 exactly, so there is
// no argument value for which a pole error occurs, and so the domain of
// std::tan & std::cos are identical (only non-finite values are disallowed).
const double cos_x = std::cos(x.value());
x.partials().Div(cos_x * cos_x);
x.value() = new_value;
return x;
}

AutoDiff asin(AutoDiff x) {
const double new_value = std::asin(x.value());
// ∂/∂x asin(x) = 1 / sqrt(1 - x²)
// The domain of asin is [-1, 1], which is the same as sqrt(1-x²); no need
// for special cases here.
const double x2 = x.value() * x.value();
x.partials().Div(std::sqrt(1 - x2));
x.value() = new_value;
return x;
}

AutoDiff acos(AutoDiff x) {
const double new_value = std::acos(x.value());
// ∂/∂x acos(x) = -1 / sqrt(1 - x²)
// The domain of acos is [-1, 1], which is the same as sqrt(1-x²); no need
// for special cases here.
const double x2 = x.value() * x.value();
x.partials().Div(-std::sqrt(1 - x2));
x.value() = new_value;
return x;
}

AutoDiff atan(AutoDiff x) {
const double new_value = std::atan(x.value());
// ∂/∂x atan(x) = 1 / (1 + x²)
// The domain of atan includes everything except NaN, which will propagate
// automatically via 1 + x²; no need for special cases here.
const double x2 = x.value() * x.value();
x.partials().Div(1 + x2);
x.value() = new_value;
return x;
}

AutoDiff atan2(AutoDiff a, const AutoDiff& b) {
const double new_value = std::atan2(a.value(), b.value());
// ∂/∂x atan2(a, b) = (ba' - ab')/(a² + b²)
// The domain of atan2 includes everything except NaN, which will propagate
// automatically via `norm`.
// TODO(jwnimmer-tri) Handle the IEEE special cases for ±∞ and ±0 as input(s).
// Figure out the proper gradients in that case.
const double norm = a.value() * a.value() + b.value() * b.value();
a.partials().Mul(b.value());
a.partials().AddScaled(-a.value(), b.partials());
a.partials().Div(norm);
a.value() = new_value;
return a;
}

AutoDiff atan2(AutoDiff a, double b) {
const double new_value = std::atan2(a.value(), b);
// ∂/∂x atan2(a, b) = (ba' - ab')/(a² + b²) = ba'/(a² + b²)
// The domain of atan2 includes everything except NaN, which will propagate
// automatically via `norm`.
// TODO(jwnimmer-tri) Handle the IEEE special cases for ±∞ and ±0 as input(s).
// Figure out the proper gradients in that case.
const double norm = a.value() * a.value() + b * b;
a.partials().Mul(b / norm);
a.value() = new_value;
return a;
}

AutoDiff atan2(double a, AutoDiff b) {
const double new_value = std::atan2(a, b.value());
// ∂/∂x atan2(a, b) = (ba' - ab')/(a² + b²) = -ab'/(a² + b²)
// The domain of atan2 includes everything except NaN, which will propagate
// automatically via `norm`.
// TODO(jwnimmer-tri) Handle the IEEE special cases for ±∞ and ±0 as input(s).
// Figure out the proper gradients in that case.
const double norm = a * a + b.value() * b.value();
b.partials().Mul(-a / norm);
b.value() = new_value;
return b;
}

AutoDiff sinh(AutoDiff x) {
const double new_value = std::sinh(x.value());
// ∂/∂x sinh(x) = cosh(x)
// The domain of sinh & cosh are identical; no need for special cases here.
x.partials().Mul(std::cosh(x.value()));
x.value() = new_value;
return x;
}

AutoDiff cosh(AutoDiff x) {
const double new_value = std::cosh(x.value());
// ∂/∂x cosh(x) = sinh(x)
// The domain of cosh & sinh are identical; no need for special cases here.
x.partials().Mul(std::sinh(x.value()));
x.value() = new_value;
return x;
}

AutoDiff tanh(AutoDiff x) {
const double new_value = std::tanh(x.value());
// ∂/∂x tanh(x) = 1 - tanh²(x)
x.partials().Mul(1 - (new_value * new_value));
x.value() = new_value;
return x;
}

AutoDiff ceil(AutoDiff x) {
x.value() = std::ceil(x.value());
x.partials().SetZero();
Expand Down
50 changes: 50 additions & 0 deletions common/ad/internal/standard_operations.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,56 @@ inline AutoDiff abs2(AutoDiff x) {

//@}

/// @name Math functions: Trigonometric functions
///
/// https://en.cppreference.com/w/cpp/numeric/math#Trigonometric_functions
//@{

/** ADL overload to mimic std::sin from <cmath>. */
AutoDiff sin(AutoDiff x);

/** ADL overload to mimic std::cos from <cmath>. */
AutoDiff cos(AutoDiff x);

/** ADL overload to mimic std::tan from <cmath>. */
AutoDiff tan(AutoDiff x);

/** ADL overload to mimic std::asin from <cmath>. */
AutoDiff asin(AutoDiff x);

/** ADL overload to mimic std::acos from <cmath>. */
AutoDiff acos(AutoDiff x);

/** ADL overload to mimic std::atan from <cmath>. */
AutoDiff atan(AutoDiff x);

/** ADL overload to mimic std::atan2 from <cmath>. */
AutoDiff atan2(AutoDiff a, const AutoDiff& b);

/** ADL overload to mimic std::atan2 from <cmath>. */
AutoDiff atan2(AutoDiff a, double b);

/** ADL overload to mimic std::atan2 from <cmath>. */
AutoDiff atan2(double a, AutoDiff b);

//@}

/// @name Math functions: Hyperbolic functions
///
/// https://en.cppreference.com/w/cpp/numeric/math#Hyperbolic_functions
//@{

/** ADL overload to mimic std::sinh from <cmath>. */
AutoDiff sinh(AutoDiff x);

/** ADL overload to mimic std::cosh from <cmath>. */
AutoDiff cosh(AutoDiff x);

/** ADL overload to mimic std::tanh from <cmath>. */
AutoDiff tanh(AutoDiff x);

//@}

/// @name Math functions: Nearest integer floating point operations
///
/// https://en.cppreference.com/w/cpp/numeric/math#Nearest_integer_floating_point_operations
Expand Down
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_acos_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Acos) {
CHECK_UNARY_FUNCTION(acos, x, y, 0.1);
CHECK_UNARY_FUNCTION(acos, x, y, -0.1);
CHECK_UNARY_FUNCTION(acos, y, x, 0.1);
CHECK_UNARY_FUNCTION(acos, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_asin_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Asin) {
CHECK_UNARY_FUNCTION(asin, x, y, 0.1);
CHECK_UNARY_FUNCTION(asin, x, y, -0.1);
CHECK_UNARY_FUNCTION(asin, y, x, 0.1);
CHECK_UNARY_FUNCTION(asin, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
40 changes: 40 additions & 0 deletions common/ad/test/standard_operations_atan2_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

// Eigen doesn't provide mixed-scalar atan2() overloads, so we need to do it.
// This assumes the standard_operations_test.h is implemented using AutoDiff3.
AutoDiff3 atan2(const AutoDiff3& y, double x) {
return atan2(y, AutoDiff3{x});
}
AutoDiff3 atan2(double y, const AutoDiff3& x) {
return atan2(AutoDiff3{y}, x);
}

TEST_F(StandardOperationsTest, Atan2AdsAds) {
CHECK_BINARY_FUNCTION_ADS_ADS(atan2, x, y, 0.1);
CHECK_BINARY_FUNCTION_ADS_ADS(atan2, x, y, -0.1);
CHECK_BINARY_FUNCTION_ADS_ADS(atan2, y, x, 0.4);
CHECK_BINARY_FUNCTION_ADS_ADS(atan2, y, x, -0.4);
}

TEST_F(StandardOperationsTest, Atan2AdsDouble) {
CHECK_BINARY_FUNCTION_ADS_SCALAR(atan2, x, y, 0.1);
CHECK_BINARY_FUNCTION_ADS_SCALAR(atan2, x, y, -0.1);
CHECK_BINARY_FUNCTION_ADS_SCALAR(atan2, y, x, 0.4);
CHECK_BINARY_FUNCTION_ADS_SCALAR(atan2, y, x, -0.4);
}

TEST_F(StandardOperationsTest, Atan2DoubleAds) {
CHECK_BINARY_FUNCTION_SCALAR_ADS(atan2, x, y, 0.1);
CHECK_BINARY_FUNCTION_SCALAR_ADS(atan2, x, y, -0.1);
CHECK_BINARY_FUNCTION_SCALAR_ADS(atan2, y, x, 0.4);
CHECK_BINARY_FUNCTION_SCALAR_ADS(atan2, y, x, -0.4);
}

} // namespace
} // namespace test
} // namespace drake
26 changes: 26 additions & 0 deletions common/ad/test/standard_operations_atan_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

// Eigen doesn't provide an atan() overload, so we need to do it.
// This assumes the standard_operations_test.h is implemented using AutoDiff3.
AutoDiff3 atan(const AutoDiff3& x) {
// ∂/∂x atan(x) = 1 / (1 + x²)
return AutoDiff3{
std::atan(x.value()),
x.derivatives() / (1 + x.value() * x.value())};
}

TEST_F(StandardOperationsTest, Atan) {
CHECK_UNARY_FUNCTION(atan, x, y, 0.1);
CHECK_UNARY_FUNCTION(atan, x, y, -0.1);
CHECK_UNARY_FUNCTION(atan, y, x, 0.1);
CHECK_UNARY_FUNCTION(atan, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_cos_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Cos) {
CHECK_UNARY_FUNCTION(cos, x, y, 0.1);
CHECK_UNARY_FUNCTION(cos, x, y, -0.1);
CHECK_UNARY_FUNCTION(cos, y, x, 0.1);
CHECK_UNARY_FUNCTION(cos, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_cosh_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Cosh) {
CHECK_UNARY_FUNCTION(cosh, x, y, 0.1);
CHECK_UNARY_FUNCTION(cosh, x, y, -0.1);
CHECK_UNARY_FUNCTION(cosh, y, x, 0.1);
CHECK_UNARY_FUNCTION(cosh, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_sin_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Sin) {
CHECK_UNARY_FUNCTION(sin, x, y, 0.1);
CHECK_UNARY_FUNCTION(sin, x, y, -0.1);
CHECK_UNARY_FUNCTION(sin, y, x, 0.1);
CHECK_UNARY_FUNCTION(sin, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
17 changes: 17 additions & 0 deletions common/ad/test/standard_operations_sinh_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "drake/common/ad/auto_diff.h"
#include "drake/common/ad/test/standard_operations_test.h"

namespace drake {
namespace test {
namespace {

TEST_F(StandardOperationsTest, Sinh) {
CHECK_UNARY_FUNCTION(sinh, x, y, 0.1);
CHECK_UNARY_FUNCTION(sinh, x, y, -0.1);
CHECK_UNARY_FUNCTION(sinh, y, x, 0.1);
CHECK_UNARY_FUNCTION(sinh, y, x, -0.1);
}

} // namespace
} // namespace test
} // namespace drake
Loading

0 comments on commit 6fe3aba

Please sign in to comment.