Skip to content

Commit

Permalink
Implement a visit_overloaded polyfill for C++17 on Focal
Browse files Browse the repository at this point in the history
After we drop Focal, we can revert this commit.
  • Loading branch information
jwnimmer-tri committed Jan 30, 2024
1 parent 692a4a9 commit 6b32a81
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 25 deletions.
47 changes: 45 additions & 2 deletions common/overloaded.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

#include <utility>
#include <variant>

/** @file
The "overloaded" variant-visit pattern.
Expand All @@ -8,8 +11,8 @@ doesn't support it natively. There is a commonly used two-line boilerplate
that bridges this gap; see
https://en.cppreference.com/w/cpp/utility/variant/visit
This file should be included by classes that wish to
use the variant visit pattern, i.e.
This file should be included by classes that wish to use the variant visit
pattern, i.e.:
@code {.cpp}
using MyVariant = std::variant<int, std::string>;
Expand All @@ -21,6 +24,23 @@ use the variant visit pattern, i.e.
EXPECT_EQ(result, "found an int");
@endcode
However, note that the prior example DOES NOT WORK yet in Drake. In order to
support C++17 we must also (for now) use a polyfill for `std::visit<T>` which
we've named `visit_overloaded<T>`, so within Drake we must spell it like this:
@code {.cpp}
using MyVariant = std::variant<int, std::string>;
MyVariant v = 5;
std::string result = visit_overloaded<const char*>(overloaded{
[](const int arg) { return "found an int"; },
[](const std::string& arg) { return "found a string"; }
}, v);
EXPECT_EQ(result, "found an int");
@endcode
When we drop support for C++17, we'll be able to return back to the normal
pattern.
@warning This file must never be included by a header, only by a cc file.
This is enforced by a lint rule in `tools/lint/drakelint.py`.
*/
Expand All @@ -40,4 +60,27 @@ overloaded(Ts...) -> overloaded<Ts...>;

// NOTE: The second line above can be removed when we are compiling with
// >= C++20 on all platforms.

// This is a polyfill for C++20's std::visit<Return>(visitor, variant) that we
// need while we still support C++17. Once we drop C++17 (i.e., once we drop
// Ubuntu 20.04), we should switch back to the conventional spelling and remove
// this entire block of code.
#if __cplusplus > 201703L
// On reasonable platforms, we can just call std::visit.
template <typename Return, typename Visitor, typename Variant>
auto visit_overloaded(Visitor&& visitor, Variant&& variant) -> decltype(auto) {
return std::visit<Return>(std::forward<Visitor>(visitor),
std::forward<Variant>(variant));
}
#else
// On Focal, we need to do a polyfill.
template <typename Return, typename Visitor, typename Variant>
auto visit_overloaded(Visitor&& visitor, Variant&& variant) -> Return {
auto visitor_coerced = [&visitor]<typename Value>(Value&& value) -> Return {
return visitor(std::forward<Value>(value));
};
return std::visit(visitor_coerced, std::forward<Variant>(variant));
}
#endif

} // namespace
4 changes: 2 additions & 2 deletions common/schema/rotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Rotation::Rotation(const math::RollPitchYaw<double>& arg) {
}

bool Rotation::IsDeterministic() const {
return std::visit<bool>(overloaded{
return visit_overloaded<bool>(overloaded{
[](const Identity&) {
return true;
},
Expand Down Expand Up @@ -65,7 +65,7 @@ Vector<Expression, Size> deg2rad(

math::RotationMatrix<Expression> Rotation::ToSymbolic() const {
using Result = math::RotationMatrix<Expression>;
return std::visit<Result>(overloaded{
return visit_overloaded<Result>(overloaded{
[](const Identity&) {
return Result{};
},
Expand Down
6 changes: 3 additions & 3 deletions common/schema/stochastic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ drake::VectorX<Expression> ToSymbolic(
}

bool IsDeterministic(const DistributionVariant& var) {
return std::visit<bool>(overloaded{
return visit_overloaded<bool>(overloaded{
[](const double&) {
return true;
},
Expand Down Expand Up @@ -350,7 +350,7 @@ drake::VectorX<Expression> UniformVector<Size>::ToSymbolic() const {
template <int Size>
unique_ptr<DistributionVector> ToDistributionVector(
const DistributionVectorVariant<Size>& vec) {
return std::visit<unique_ptr<DistributionVector>>(overloaded{
return visit_overloaded<unique_ptr<DistributionVector>>(overloaded{
// NOLINTNEXTLINE(whitespace/line_length)
[](const drake::Vector<double, Size>& arg) {
return std::make_unique<DeterministicVector<Size>>(arg);
Expand Down Expand Up @@ -387,7 +387,7 @@ unique_ptr<DistributionVector> ToDistributionVector(

template <int Size>
bool IsDeterministic(const DistributionVectorVariant<Size>& vec) {
return std::visit<bool>(overloaded{
return visit_overloaded<bool>(overloaded{
[](const drake::Vector<double, Size>&) {
return true;
},
Expand Down
6 changes: 3 additions & 3 deletions common/test/overloaded_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ GTEST_TEST(OverloadedTest, CommentExampleTest) {
// Test the exact text of the example in the file comment.
using MyVariant = std::variant<int, std::string>;
MyVariant v = 5;
std::string result = std::visit<const char*>(overloaded{
std::string result = visit_overloaded<const char*>(overloaded{
[](const int arg) { return "found an int"; },
[](const std::string& arg) { return "found a string"; }
}, v);
Expand All @@ -24,7 +24,7 @@ GTEST_TEST(OverloadedTest, AutoTest) {

// An 'auto' arm doesn't match if there's any explicit match,
// no matter if it's earlier or later in the list.
std::string result = std::visit<const char*>(overloaded{
std::string result = visit_overloaded<const char*>(overloaded{
[](const auto arg) { return "found an auto"; },
[](const int arg) { return "found an int"; },
[](const std::string& arg) { return "found a string"; },
Expand All @@ -33,7 +33,7 @@ GTEST_TEST(OverloadedTest, AutoTest) {
EXPECT_EQ(result, "found an int");

// An 'auto' arm matches if there's no explicit match.
result = std::visit<const char*>(overloaded{
result = visit_overloaded<const char*>(overloaded{
[](const auto arg) { return "found an auto"; },
[](const std::string& arg) { return "found a string"; },
}, v);
Expand Down
18 changes: 9 additions & 9 deletions common/yaml/yaml_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Node Node::MakeNull() {
}

NodeType Node::GetType() const {
return std::visit<NodeType>( // BR
return visit_overloaded<NodeType>( // BR
overloaded{
[](const ScalarData&) {
return NodeType::kScalar;
Expand Down Expand Up @@ -130,7 +130,7 @@ bool operator==(const Node::MappingData& a, const Node::MappingData& b) {
}

std::string_view Node::GetTag() const {
return std::visit<std::string_view>( // BR
return visit_overloaded<std::string_view>( // BR
overloaded{
[](const std::string& tag) {
return std::string_view{tag};
Expand Down Expand Up @@ -181,7 +181,7 @@ const std::optional<Node::Mark>& Node::GetMark() const {
}

const std::string& Node::GetScalar() const {
return *std::visit<const std::string*>(
return *visit_overloaded<const std::string*>(
overloaded{
[](const ScalarData& data) {
return &data.scalar;
Expand All @@ -195,7 +195,7 @@ const std::string& Node::GetScalar() const {
}

const std::vector<Node>& Node::GetSequence() const {
return *std::visit<const std::vector<Node>*>(
return *visit_overloaded<const std::vector<Node>*>(
overloaded{
[](const SequenceData& data) {
return &data.sequence;
Expand All @@ -209,7 +209,7 @@ const std::vector<Node>& Node::GetSequence() const {
}

void Node::Add(Node value) {
return std::visit<void>(
return visit_overloaded<void>(
overloaded{
[&value](SequenceData& data) {
data.sequence.push_back(std::move(value));
Expand All @@ -223,7 +223,7 @@ void Node::Add(Node value) {
}

const std::map<std::string, Node>& Node::GetMapping() const {
return *std::visit<const std::map<std::string, Node>*>(
return *visit_overloaded<const std::map<std::string, Node>*>(
overloaded{
[](const MappingData& data) {
return &data.mapping;
Expand All @@ -237,7 +237,7 @@ const std::map<std::string, Node>& Node::GetMapping() const {
}

void Node::Add(std::string key, Node value) {
return std::visit<void>(
return visit_overloaded<void>(
overloaded{
[&key, &value](MappingData& data) {
const auto result =
Expand All @@ -263,7 +263,7 @@ void Node::Add(std::string key, Node value) {
}

Node& Node::At(std::string_view key) {
return *std::visit<Node*>(
return *visit_overloaded<Node*>(
overloaded{
[key](MappingData& data) {
return &data.mapping.at(std::string{key});
Expand All @@ -277,7 +277,7 @@ Node& Node::At(std::string_view key) {
}

void Node::Remove(std::string_view key) {
return std::visit<void>(
return visit_overloaded<void>(
overloaded{
[key](MappingData& data) {
auto erased = data.mapping.erase(std::string{key});
Expand Down
2 changes: 1 addition & 1 deletion geometry/meshcat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ class MeshcatShapeReifier : public ShapeReifier {
}

// Set the scale.
std::visit<void>(
visit_overloaded<void>(
overloaded{[](std::monostate) {},
[scale](auto& lumped_object) {
Eigen::Map<Eigen::Matrix4d> matrix(lumped_object.matrix);
Expand Down
2 changes: 1 addition & 1 deletion geometry/meshcat_visualizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void MeshcatVisualizer<T>::SetObjects(
if constexpr (std::is_same_v<T, double>) {
if (params_.show_hydroelastic) {
auto maybe_mesh = inspector.maybe_get_hydroelastic_mesh(geom_id);
std::visit<void>(
visit_overloaded<void>(
overloaded{[](std::monostate) {},
[&](const TriangleSurfaceMesh<double>* mesh) {
DRAKE_DEMAND(mesh != nullptr);
Expand Down
2 changes: 1 addition & 1 deletion geometry/render_gl/test/multithread_safety_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void AddShape(const Shape& shape, const Vector3d& p_WS,
std::variant<Rgba, std::string> diffuse, RenderEngine* engine) {
PerceptionProperties material;
material.AddProperty("label", "id", render::RenderLabel::kDontCare);
std::visit<void>( // BR
visit_overloaded<void>( // BR
overloaded{[&material](const Rgba& rgba) {
material.AddProperty("phong", "diffuse", rgba);
},
Expand Down
2 changes: 1 addition & 1 deletion multibody/meshcat/joint_sliders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ std::map<int, std::string> GetPositionNames(
VectorXd Broadcast(
const char* diagnostic_name, double default_value, int num_positions,
std::variant<std::monostate, double, VectorXd> value) {
return std::visit<VectorXd>(overloaded{
return visit_overloaded<VectorXd>(overloaded{
[num_positions, default_value](std::monostate) {
return VectorXd::Constant(num_positions, default_value);
},
Expand Down
4 changes: 2 additions & 2 deletions systems/sensors/camera_config_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void ValidateEngineAndMaybeAdd(const CameraConfig& config,
bool already_exists = !type_name.empty();

if (already_exists) {
std::visit<void>(
visit_overloaded<void>(
overloaded{
[&type_name, &config](const std::string& class_name) {
if (!class_name.empty() &&
Expand All @@ -152,7 +152,7 @@ void ValidateEngineAndMaybeAdd(const CameraConfig& config,

if (already_exists) return;

std::visit<void>(
visit_overloaded<void>(
overloaded{
[&config, scene_graph](const std::string& class_name) {
MakeEngineByClassName(class_name, config, scene_graph);
Expand Down

0 comments on commit 6b32a81

Please sign in to comment.