Skip to content

Commit

Permalink
[pydrake] Improve BsplineTrajectory usability (RobotLocomotion#18182)
Browse files Browse the repository at this point in the history
  • Loading branch information
RussTedrake authored Oct 26, 2022
1 parent 8fde778 commit aa2d1b8
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 10 deletions.
13 changes: 10 additions & 3 deletions bindings/pydrake/test/trajectories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,16 @@ def test_bspline_trajectory(self, T):
BsplineBasis = BsplineBasis_[T]
BsplineTrajectory = BsplineTrajectory_[T]

# Call the default constructor.
bspline = BsplineTrajectory()
self.assertIsInstance(bspline, BsplineTrajectory)
self.assertEqual(BsplineBasis().num_basis_functions(), 0)
# Call the vector<vector<T>> constructor.
bspline = BsplineTrajectory(basis=BsplineBasis(2, [0, 1, 2, 3]),
control_points=np.zeros((4, 2)))
self.assertEqual(bspline.rows(), 4)
self.assertEqual(bspline.cols(), 1)
# Call the vector<MatrixX<T>> constructor.
bspline = BsplineTrajectory(
basis=BsplineBasis(2, [0, 1, 2, 3]),
control_points=[np.zeros((3, 4)), np.ones((3, 4))])
Expand All @@ -101,9 +108,9 @@ def test_bspline_trajectory(self, T):
bspline.CopyBlock(start_row=1, start_col=2,
block_rows=2, block_cols=1),
BsplineTrajectory)
bspline = BsplineTrajectory(
basis=BsplineBasis(2, [0, 1, 2, 3]),
control_points=[np.zeros(3), np.ones(3)])
bspline = BsplineTrajectory(basis=BsplineBasis(2, [0, 1, 2, 3]),
control_points=np.array([[0, 1], [0, 1],
[0, 1]]))
self.assertIsInstance(bspline.CopyHead(n=2), BsplineTrajectory)
# Ensure we can copy.
self.assertEqual(copy.copy(bspline).rows(), 3)
Expand Down
9 changes: 9 additions & 0 deletions bindings/pydrake/trajectories_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ struct Impl {
m, "BsplineTrajectory", param, cls_doc.doc);
cls // BR
.def(py::init<>())
// This overload will match 2d numpy arrays before
// std::vector<MatrixX<T>>. We want each column of the numpy array as
// a MatrixX of control points, but the std::vectors here are
// associated with the rows in numpy.
.def(py::init([](math::BsplineBasis<T> basis,
std::vector<std::vector<T>> control_points) {
return Class(basis, MakeEigenFromRowMajorVectors(control_points));
}),
py::arg("basis"), py::arg("control_points"), cls_doc.ctor.doc)
.def(py::init<math::BsplineBasis<T>, std::vector<MatrixX<T>>>(),
py::arg("basis"), py::arg("control_points"), cls_doc.ctor.doc)
.def("Clone", &Class::Clone, cls_doc.Clone.doc)
Expand Down
8 changes: 4 additions & 4 deletions common/trajectories/bspline_trajectory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ template <typename T>
BsplineTrajectory<T>::BsplineTrajectory(BsplineBasis<T> basis,
std::vector<MatrixX<T>> control_points)
: basis_(std::move(basis)), control_points_(std::move(control_points)) {
DRAKE_DEMAND(CheckInvariants());
CheckInvariants();
}

template <typename T>
Expand Down Expand Up @@ -234,9 +234,9 @@ boolean<T> BsplineTrajectory<T>::operator==(
}

template <typename T>
bool BsplineTrajectory<T>::CheckInvariants() const {
return static_cast<int>(control_points_.size()) ==
basis_.num_basis_functions();
void BsplineTrajectory<T>::CheckInvariants() const {
DRAKE_THROW_UNLESS(static_cast<int>(control_points_.size()) ==
basis_.num_basis_functions());
}

DRAKE_DEFINE_CLASS_TEMPLATE_INSTANTIATIONS_ON_DEFAULT_SCALARS(
Expand Down
4 changes: 2 additions & 2 deletions common/trajectories/bspline_trajectory.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class BsplineTrajectory final : public trajectories::Trajectory<T> {
Serialize(Archive* a) {
a->Visit(MakeNameValue("basis", &basis_));
a->Visit(MakeNameValue("control_points", &control_points_));
DRAKE_THROW_UNLESS(CheckInvariants());
CheckInvariants();
}

private:
Expand All @@ -140,7 +140,7 @@ class BsplineTrajectory final : public trajectories::Trajectory<T> {
std::unique_ptr<trajectories::Trajectory<T>> DoMakeDerivative(
int derivative_order) const override;

bool CheckInvariants() const;
void CheckInvariants() const;

math::BsplineBasis<T> basis_;
std::vector<MatrixX<T>> control_points_;
Expand Down
2 changes: 1 addition & 1 deletion common/trajectories/test/bspline_trajectory_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ const char* const not_enough_control_points = R"""(
GTEST_TEST(BsplineTrajectorySerializeTests, NotEnoughControlPointsTest) {
DRAKE_EXPECT_THROWS_MESSAGE(
LoadYamlString<BsplineTrajectory<double>>(not_enough_control_points),
".*CheckInvariants.*");
".*num_basis_functions.*");
}

} // namespace trajectories
Expand Down

0 comments on commit aa2d1b8

Please sign in to comment.