Skip to content

Commit

Permalink
Merge pull request RobotLocomotion#3151 from jwnimmer-tri/basic-state…
Browse files Browse the repository at this point in the history
…-and-output-vector

Add vector base type that serves as both State and Output
  • Loading branch information
david-german-tri authored Aug 16, 2016
2 parents 103c984 + 15720b5 commit 9bee5df
Show file tree
Hide file tree
Showing 14 changed files with 227 additions and 78 deletions.
30 changes: 6 additions & 24 deletions drake/examples/spring_mass/spring_mass_system.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ constexpr int kStateSize = 3; // position, velocity, power integral

SpringMassStateVector::SpringMassStateVector(double initial_position,
double initial_velocity)
: BasicStateVector<double>(kStateSize) {
: BasicStateAndOutputVector<double>(kStateSize) {
set_position(initial_position);
set_velocity(initial_velocity);
set_conservative_work(0);
}

SpringMassStateVector::SpringMassStateVector()
: SpringMassStateVector(0.0, 0.0) {}

SpringMassStateVector::~SpringMassStateVector() {}

// Order matters: Position (q) precedes velocity (v) precedes misc. (z) in
Expand All @@ -47,27 +50,6 @@ SpringMassStateVector* SpringMassStateVector::DoClone() const {
return state;
}

SpringMassOutputVector::SpringMassOutputVector()
: BasicVector<double>(kStateSize - 1) {
} // don't output conservative energy

double SpringMassOutputVector::get_position() const { return get_value()[0]; }
double SpringMassOutputVector::get_velocity() const { return get_value()[1]; }

void SpringMassOutputVector::set_position(double q) {
get_mutable_value()[0] = q;
}

void SpringMassOutputVector::set_velocity(double v) {
get_mutable_value()[1] = v;
}

SpringMassOutputVector* SpringMassOutputVector::DoClone() const {
SpringMassOutputVector* clone(new SpringMassOutputVector());
clone->get_mutable_value() = get_value();
return clone;
}

SpringMassSystem::SpringMassSystem(double spring_constant_N_per_m,
double mass_kg, bool system_is_forced)
: spring_constant_N_per_m_(spring_constant_N_per_m),
Expand Down Expand Up @@ -120,7 +102,7 @@ std::unique_ptr<SystemOutput<double>> SpringMassSystem::AllocateOutput(
std::unique_ptr<LeafSystemOutput<double>> output(
new LeafSystemOutput<double>);
{
std::unique_ptr<VectorInterface<double>> data(new SpringMassOutputVector());
std::unique_ptr<VectorInterface<double>> data(new SpringMassStateVector());
std::unique_ptr<OutputPort<double>> port(
new OutputPort<double>(std::move(data)));
output->get_mutable_ports()->push_back(std::move(port));
Expand All @@ -142,7 +124,7 @@ void SpringMassSystem::EvalOutput(const ContextBase<double>& context,
SystemOutput<double>* output) const {
// TODO(david-german-tri): Cache the output of this function.
const SpringMassStateVector& state = get_state(context);
SpringMassOutputVector* output_vector = get_mutable_output(output);
SpringMassStateVector* output_vector = get_mutable_output(output);
output_vector->set_position(state.get_position());
output_vector->set_velocity(state.get_velocity());
}
Expand Down
41 changes: 8 additions & 33 deletions drake/examples/spring_mass/spring_mass_system.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
#include <string>

#include "drake/drakeSpringMassSystem_export.h"
#include "drake/systems/framework/basic_state_vector.h"
#include "drake/systems/framework/basic_vector.h"
#include "drake/systems/framework/basic_state_and_output_vector.h"
#include "drake/systems/framework/context.h"
#include "drake/systems/framework/state_vector.h"
#include "drake/systems/framework/system.h"
Expand All @@ -17,11 +16,13 @@ namespace examples {
/// The state of a one-dimensional spring-mass system, consisting of the
/// position and velocity of the mass, in meters and meters/s.
class DRAKESPRINGMASSSYSTEM_EXPORT SpringMassStateVector
: public systems::BasicStateVector<double> {
: public systems::BasicStateAndOutputVector<double> {
public:
/// @param initial_position The position of the mass in meters.
/// @param initial_velocity The velocity of the mass in meters / second.
SpringMassStateVector(double initial_position, double initial_velocity);
/// Creates a state with position and velocity set to zero.
SpringMassStateVector();
~SpringMassStateVector() override;

/// Returns the position of the mass in meters, where zero is the point
Expand All @@ -47,32 +48,6 @@ class DRAKESPRINGMASSSYSTEM_EXPORT SpringMassStateVector
SpringMassStateVector* DoClone() const override;
};

/// The output of a one-dimensional spring-mass system, consisting of the
/// position and velocity of the mass, in meters. Note that although this
/// system tracks work done as a state variable, we are not reporting that
/// as an Output.
class DRAKESPRINGMASSSYSTEM_EXPORT SpringMassOutputVector
: public systems::BasicVector<double> {
public:
SpringMassOutputVector();

/// Returns the position of the mass in meters, where zero is the point
/// where the spring exerts no force.
double get_position() const;

/// Sets the position of the mass in meters.
void set_position(double q);

/// Returns the velocity of the mass in meters per second.
double get_velocity() const;

/// Sets the velocity of the mass in meters per second.
void set_velocity(double v);

private:
SpringMassOutputVector* DoClone() const override;
};

/// A model of a one-dimensional spring-mass system.
///
/// @verbatim
Expand Down Expand Up @@ -243,13 +218,13 @@ class DRAKESPRINGMASSSYSTEM_EXPORT SpringMassSystem
return dynamic_cast<SpringMassStateVector*>(cstate->get_mutable_state());
}

static const SpringMassOutputVector& get_output(const MyOutput& output) {
return dynamic_cast<const SpringMassOutputVector&>(
static const SpringMassStateVector& get_output(const MyOutput& output) {
return dynamic_cast<const SpringMassStateVector&>(
*output.get_port(0).get_vector_data());
}

static SpringMassOutputVector* get_mutable_output(MyOutput* output) {
return dynamic_cast<SpringMassOutputVector*>(
static SpringMassStateVector* get_mutable_output(MyOutput* output) {
return dynamic_cast<SpringMassStateVector*>(
output->get_mutable_port(0)->GetMutableVectorData());
}

Expand Down
12 changes: 6 additions & 6 deletions drake/examples/spring_mass/test/spring_mass_system_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class SpringMassSystemTest : public ::testing::Test {
// Set up some convenience pointers.
state_ = dynamic_cast<SpringMassStateVector*>(
context_->get_mutable_state()->continuous_state->get_mutable_state());
output_ = dynamic_cast<const SpringMassOutputVector*>(
output_ = dynamic_cast<const SpringMassStateVector*>(
system_output_->get_port(0).get_vector_data());
derivatives_ = dynamic_cast<SpringMassStateVector*>(
system_derivatives_->get_mutable_state());
Expand Down Expand Up @@ -97,7 +97,7 @@ class SpringMassSystemTest : public ::testing::Test {
std::unique_ptr<BasicStateVector<double>> configuration_derivatives_;

SpringMassStateVector* state_;
const SpringMassOutputVector* output_;
const SpringMassStateVector* output_;
SpringMassStateVector* derivatives_;

private:
Expand Down Expand Up @@ -127,10 +127,10 @@ TEST_F(SpringMassSystemTest, CloneState) {
TEST_F(SpringMassSystemTest, CloneOutput) {
InitializeState(1.0, 2.0);
system_->EvalOutput(*context_, system_output_.get());
std::unique_ptr<VectorInterface<double>> clone = output_->Clone();
std::unique_ptr<VectorInterface<double>> clone = output_->CloneVector();

SpringMassOutputVector* typed_clone =
dynamic_cast<SpringMassOutputVector*>(clone.get());
SpringMassStateVector* typed_clone =
dynamic_cast<SpringMassStateVector*>(clone.get());
EXPECT_EQ(1.0, typed_clone->get_position());
EXPECT_EQ(2.0, typed_clone->get_velocity());
}
Expand All @@ -147,7 +147,7 @@ TEST_F(SpringMassSystemTest, Output) {
EXPECT_EQ(0.25, output_->get_velocity());

// Check the output through the VectorInterface API.
ASSERT_EQ(2, output_->size());
ASSERT_EQ(3, output_->size());
EXPECT_NEAR(0.1, output_->get_value()[0], 1e-14);
EXPECT_EQ(0.25, output_->get_value()[1]);
}
Expand Down
2 changes: 2 additions & 0 deletions drake/systems/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Source files used to build drakeSystemFramework.
set(sources
basic_state_and_output_vector.cc
basic_state_vector.cc
basic_vector.cc
cache.cc
Expand Down Expand Up @@ -32,6 +33,7 @@ set(sources
# System2 framework template implementations have been moved into .cc files
# using explicit instantiation, tighten this list.
set(installed_headers
basic_state_and_output_vector.h
basic_state_vector.h
basic_vector.h
cache.h
Expand Down
4 changes: 4 additions & 0 deletions drake/systems/framework/basic_state_and_output_vector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "drake/systems/framework/basic_state_and_output_vector.h"

// Catch compile errors early, by forcing an instantiation.
template class drake::systems::BasicStateAndOutputVector<double>;
73 changes: 73 additions & 0 deletions drake/systems/framework/basic_state_and_output_vector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#pragma once

#include <memory>
#include <vector>

#include <Eigen/Dense>

#include "drake/common/eigen_types.h"
#include "drake/systems/framework/basic_state_vector.h"
#include "drake/systems/framework/vector_interface.h"

namespace drake {
namespace systems {

/// BasicStateAndOutputVector is a concrete class template that implements
/// StateVector in a convenient manner for LeafSystem blocks, and implements
/// VectorInterface so that it may also be used as an output.
///
/// @tparam T A mathematical type compatible with Eigen's Scalar.
template <typename T>
class BasicStateAndOutputVector : public BasicStateVector<T>,
public VectorInterface<T> {
public:
/// Constructs a BasicStateAndOutputVector of the specified @p size.
explicit BasicStateAndOutputVector(int size) : BasicStateVector<T>(size) {}

/// Constructs a BasicStateAndOutputVector with the specified @p data.
explicit BasicStateAndOutputVector(const std::vector<T>& data)
: BasicStateVector<T>(data) {}

/// Constructs a BasicStateAndOutputVector that owns an arbitrary @p vector,
/// which must not be nullptr.
explicit BasicStateAndOutputVector(std::unique_ptr<VectorInterface<T>> vector)
: BasicStateVector<T>(std::move(vector)) {}

// The size() method overrides both BasicStateVector and VectorInterface.
int size() const override { return this->get_wrapped_vector().size(); }

// These VectorInterface overrides merely delegate to the wrapped object.
void set_value(const Eigen::Ref<const VectorX<T>>& value) override {
this->get_wrapped_vector().set_value(value);
}
Eigen::VectorBlock<const VectorX<T>> get_value() const override {
return this->get_wrapped_vector().get_value();
}
Eigen::VectorBlock<VectorX<T>> get_mutable_value() override {
return this->get_wrapped_vector().get_mutable_value();
}

// This VectorInterface override must not delegate, because we need to
// maintain our class type (BasicStateAndOutputVector) during cloning.
std::unique_ptr<VectorInterface<T>> CloneVector() const override {
return std::unique_ptr<VectorInterface<T>>(DoClone());
}

protected:
BasicStateAndOutputVector(const BasicStateAndOutputVector& other)
: BasicStateVector<T>(other) {}

BasicStateAndOutputVector<T>* DoClone() const override {
return new BasicStateAndOutputVector<T>(*this);
}

private:
// Disable these, for consistency with parent class.
BasicStateAndOutputVector& operator=(const BasicStateAndOutputVector&) =
delete;
BasicStateAndOutputVector(BasicStateAndOutputVector&&) = delete;
BasicStateAndOutputVector& operator=(BasicStateAndOutputVector&&) = delete;
};

} // namespace systems
} // namespace drake
17 changes: 10 additions & 7 deletions drake/systems/framework/basic_state_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
#include "drake/systems/framework/basic_vector.h"
#include "drake/systems/framework/leaf_state_vector.h"
#include "drake/systems/framework/vector_interface.h"
#include "leaf_state_vector.h"

namespace drake {
namespace systems {

/// BasicStateVector is a concrete class template that implements
/// StateVector in a convenient manner for leaf Systems,
/// StateVector in a convenient manner for LeafSystem blocks,
/// by owning and wrapping a VectorInterface<T>.
///
/// It will often be convenient to inherit from BasicStateVector, and add
/// additional semantics specific to the leaf System. Such child classes must
/// additional semantics specific to the LeafSystem. Such child classes must
/// override DoClone with an implementation that returns their concrete type.
///
/// @tparam T A mathematical type compatible with Eigen's Scalar.
Expand Down Expand Up @@ -88,16 +87,20 @@ class BasicStateVector : public LeafStateVector<T> {
}

protected:
// Clone other's wrapped vector, in case is it not a BasicVector.
BasicStateVector(const BasicStateVector& other)
: BasicStateVector(other.size()) {
SetFromVector(other.vector_->get_value());
}
: BasicStateVector(other.vector_->CloneVector()) {}

private:
BasicStateVector<T>* DoClone() const override {
return new BasicStateVector<T>(*this);
}

/// Returns a mutable reference to the underlying VectorInterface.
VectorInterface<T>& get_wrapped_vector() { return *vector_; }
/// Returns a const reference to the underlying VectorInterface.
const VectorInterface<T>& get_wrapped_vector() const { return *vector_; }

private:
// Assignment of BasicStateVectors could change size, so we forbid it.
BasicStateVector& operator=(const BasicStateVector& other) = delete;

Expand Down
2 changes: 1 addition & 1 deletion drake/systems/framework/basic_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class BasicVector : public VectorInterface<T> {
///
/// Uses the Non-Virtual Interface idiom because smart pointers do not have
/// type covariance.
std::unique_ptr<VectorInterface<T>> Clone() const final {
std::unique_ptr<VectorInterface<T>> CloneVector() const final {
return std::unique_ptr<VectorInterface<T>>(DoClone());
}

Expand Down
3 changes: 2 additions & 1 deletion drake/systems/framework/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class Context : public ContextBase<T> {
context->inputs_.emplace_back(nullptr);
} else {
context->inputs_.emplace_back(
new FreestandingInputPort<T>(port->get_vector_data()->Clone()));
new FreestandingInputPort<T>(
port->get_vector_data()->CloneVector()));
}
}

Expand Down
2 changes: 1 addition & 1 deletion drake/systems/framework/system_output.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class OutputPort {
/// Returns a clone of this OutputPort containing a clone of the data, but
/// without any dependents.
std::unique_ptr<OutputPort<T>> Clone() const {
return std::make_unique<OutputPort<T>>(vector_data_->Clone());
return std::make_unique<OutputPort<T>>(vector_data_->CloneVector());
}

private:
Expand Down
7 changes: 7 additions & 0 deletions drake/systems/framework/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ target_link_libraries(basic_state_vector_test drakeSystemFramework
${GTEST_BOTH_LIBRARIES})
add_test(NAME basic_state_vector_test COMMAND basic_state_vector_test)

add_executable(basic_state_and_output_vector_test
basic_state_and_output_vector_test.cc)
target_link_libraries(basic_state_and_output_vector_test drakeSystemFramework
${GTEST_BOTH_LIBRARIES})
add_test(NAME basic_state_and_output_vector_test
COMMAND basic_state_and_output_vector_test)

add_executable(state_subvector_test state_subvector_test.cc)
target_link_libraries(state_subvector_test drakeSystemFramework
${GTEST_BOTH_LIBRARIES})
Expand Down
Loading

0 comments on commit 9bee5df

Please sign in to comment.