Skip to content

Commit

Permalink
Rename timestep => time_step throughout trajectory optimization (Robo…
Browse files Browse the repository at this point in the history
  • Loading branch information
RussTedrake authored May 23, 2023
1 parent 3c2363e commit 04f0eb8
Show file tree
Hide file tree
Showing 12 changed files with 440 additions and 175 deletions.
1 change: 1 addition & 0 deletions bindings/pydrake/planning/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ drake_py_unittest(
name = "trajectory_optimization_test",
deps = [
":planning",
"//bindings/pydrake/common/test_utilities:deprecation_py",
"//bindings/pydrake/examples",
],
)
Expand Down
56 changes: 51 additions & 5 deletions bindings/pydrake/planning/planning_py_trajectory_optimization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

#include "drake/bindings/pydrake/common/deprecation_pybind.h"
#include "drake/bindings/pydrake/documentation_pybind.h"
#include "drake/bindings/pydrake/geometry/optimization_pybind.h"
#include "drake/bindings/pydrake/pydrake_pybind.h"
Expand Down Expand Up @@ -57,10 +58,10 @@ void DefinePlanningTrajectoryOptimization(py::module m) {
.def("time", &Class::time, cls_doc.time.doc)
.def("prog", overload_cast_explicit<MathematicalProgram&>(&Class::prog),
py_rvp::reference_internal, cls_doc.prog.doc)
.def("timestep", &Class::timestep, py::arg("index"),
cls_doc.timestep.doc)
.def("fixed_timestep", &Class::fixed_timestep,
cls_doc.fixed_timestep.doc)
.def("time_step", &Class::time_step, py::arg("index"),
cls_doc.time_step.doc)
.def("fixed_time_step", &Class::fixed_time_step,
cls_doc.fixed_time_step.doc)
// TODO(eric.cousineau): The original bindings returned references
// instead of copies using VectorXBlock. Restore this once dtype=custom
// is resolved.
Expand Down Expand Up @@ -200,6 +201,18 @@ void DefinePlanningTrajectoryOptimization(py::module m) {
const solvers::MathematicalProgramResult&>(
&Class::ReconstructStateTrajectory),
py::arg("result"), cls_doc.ReconstructStateTrajectory.doc);

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
cls.def("timestep",
WrapDeprecated(cls_doc.timestep.doc_deprecated, &Class::timestep),
py::arg("index"), cls_doc.timestep.doc_deprecated)
.def("fixed_timestep",
WrapDeprecated(
cls_doc.fixed_timestep.doc_deprecated, &Class::fixed_timestep),
cls_doc.fixed_timestep.doc_deprecated);
#pragma GCC diagnostic pop

RegisterAddConstraintToAllKnotPoints<solvers::BoundingBoxConstraint>(&cls);
RegisterAddConstraintToAllKnotPoints<solvers::LinearEqualityConstraint>(
&cls);
Expand All @@ -217,6 +230,21 @@ void DefinePlanningTrajectoryOptimization(py::module m) {
systems::InputPortIndex>,
bool, solvers::MathematicalProgram*>(),
py::arg("system"), py::arg("context"), py::arg("num_time_samples"),
py::arg("minimum_time_step"), py::arg("maximum_time_step"),
py::arg("input_port_index") =
systems::InputPortSelection::kUseFirstInputIfItExists,
py::arg("assume_non_continuous_states_are_fixed") = false,
py::arg("prog") = nullptr, cls_doc.ctor.doc)
.def(
py_init_deprecated<Class, const systems::System<double>*,
const systems::Context<double>&, int, double, double,
std::variant<systems::InputPortSelection,
systems::InputPortIndex>,
bool, solvers::MathematicalProgram*>(
"The arguments minimum_timestep and maximum_timestep have been "
"renamed to minimum_time_step and maximum_time_step. This "
"version will be removed on or after 2023-09-01."),
py::arg("system"), py::arg("context"), py::arg("num_time_samples"),
py::arg("minimum_timestep"), py::arg("maximum_timestep"),
py::arg("input_port_index") =
systems::InputPortSelection::kUseFirstInputIfItExists,
Expand All @@ -242,6 +270,13 @@ void DefinePlanningTrajectoryOptimization(py::module m) {
}

m.def("AddDirectCollocationConstraint", &AddDirectCollocationConstraint,
py::arg("constraint"), py::arg("time_step"), py::arg("state"),
py::arg("next_state"), py::arg("input"), py::arg("next_input"),
py::arg("prog"), doc.AddDirectCollocationConstraint.doc);
m.def("AddDirectCollocationConstraint",
WrapDeprecated("Argument timestep has been renamed to time_step. This "
"version will be removed on or after 2023-09-01.",
AddDirectCollocationConstraint),
py::arg("constraint"), py::arg("timestep"), py::arg("state"),
py::arg("next_state"), py::arg("input"), py::arg("next_input"),
py::arg("prog"), doc.AddDirectCollocationConstraint.doc);
Expand All @@ -268,12 +303,23 @@ void DefinePlanningTrajectoryOptimization(py::module m) {
py::arg("input_port_index") =
systems::InputPortSelection::kUseFirstInputIfItExists,
cls_doc.ctor.doc_4args)
.def(py_init_deprecated<Class, const systems::System<double>*,
const systems::Context<double>&, int, TimeStep,
std::variant<systems::InputPortSelection,
systems::InputPortIndex>>(
"Argument fixed_timestep has been renamed to fixed_time_step. "
"This version will be removed on or after 2023-09-01."),
py::arg("system"), py::arg("context"), py::arg("num_time_samples"),
py::arg("fixed_timestep"),
py::arg("input_port_index") =
systems::InputPortSelection::kUseFirstInputIfItExists,
cls_doc.ctor.doc_5args)
.def(py::init<const systems::System<double>*,
const systems::Context<double>&, int, TimeStep,
std::variant<systems::InputPortSelection,
systems::InputPortIndex>>(),
py::arg("system"), py::arg("context"), py::arg("num_time_samples"),
py::arg("fixed_timestep"),
py::arg("fixed_time_step"),
py::arg("input_port_index") =
systems::InputPortSelection::kUseFirstInputIfItExists,
cls_doc.ctor.doc_5args);
Expand Down
199 changes: 191 additions & 8 deletions bindings/pydrake/planning/test/trajectory_optimization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from pydrake.common.test_utilities.deprecation import catch_drake_warnings
from pydrake.examples import PendulumPlant
from pydrake.math import eq, BsplineBasis
from pydrake.planning import (
Expand Down Expand Up @@ -50,8 +51,8 @@ def test_direct_collocation(self):
plant,
context,
num_time_samples=num_time_samples,
minimum_timestep=0.2,
maximum_timestep=0.5,
minimum_time_step=0.2,
maximum_time_step=0.5,
input_port_index=InputPortSelection.kUseFirstInputIfItExists,
assume_non_continuous_states_are_fixed=False)
prog = dircol.prog()
Expand All @@ -61,7 +62,7 @@ def test_direct_collocation(self):
# as a consistent optimization. The goal is to check the bindings,
# not the implementation.
t = dircol.time()
dt = dircol.timestep(index=0)
dt = dircol.time_step(index=0)
x = dircol.state()
x2 = dircol.state(index=2)
x0 = dircol.initial_state()
Expand Down Expand Up @@ -125,7 +126,7 @@ def complete_callback(t, x, u, v):
dircol.ReconstructStateTrajectory(result=result)

constraint = DirectCollocationConstraint(plant, context)
AddDirectCollocationConstraint(constraint, dircol.timestep(0),
AddDirectCollocationConstraint(constraint, dircol.time_step(0),
dircol.state(0), dircol.state(1),
dircol.input(0), dircol.input(1),
prog)
Expand Down Expand Up @@ -161,8 +162,8 @@ def complete_callback(t, x, u, v):
plant,
context,
num_time_samples=num_time_samples,
minimum_timestep=0.2,
maximum_timestep=0.5,
minimum_time_step=0.2,
maximum_time_step=0.5,
input_port_index=InputPortSelection.kUseFirstInputIfItExists,
assume_non_continuous_states_are_fixed=False,
prog=prog)
Expand All @@ -182,7 +183,7 @@ def test_direct_transcription(self):
# as a consistent optimization. The goal is to check the bindings,
# not the implementation.
t = dirtran.time()
dt = dirtran.fixed_timestep()
dt = dirtran.fixed_time_step()
x = dirtran.state()
x2 = dirtran.state(2)
x0 = dirtran.initial_state()
Expand Down Expand Up @@ -213,7 +214,189 @@ def test_direct_transcription(self):
context = plant.CreateDefaultContext()
dirtran = DirectTranscription(
plant, context, num_time_samples=21,
fixed_timestep=DirectTranscription.TimeStep(0.1))
fixed_time_step=DirectTranscription.TimeStep(0.1))

def test_direct_collocation_deprecated(self):
plant = PendulumPlant()
context = plant.CreateDefaultContext()

num_time_samples = 21
with catch_drake_warnings(expected_count=1) as w:
dircol = DirectCollocation(
plant,
context,
num_time_samples=num_time_samples,
minimum_timestep=0.2,
maximum_timestep=0.5,
input_port_index=InputPortSelection.kUseFirstInputIfItExists,
assume_non_continuous_states_are_fixed=False)
prog = dircol.prog()
num_initial_vars = prog.num_vars()

# Spell out most of the methods, regardless of whether they make sense
# as a consistent optimization. The goal is to check the bindings,
# not the implementation.
t = dircol.time()
with catch_drake_warnings(expected_count=1) as w:
dt = dircol.timestep(index=0)
x = dircol.state()
x2 = dircol.state(index=2)
x0 = dircol.initial_state()
xf = dircol.final_state()
u = dircol.input()
u2 = dircol.input(index=2)
v = dircol.NewSequentialVariable(rows=1, name="test")
v2 = dircol.GetSequentialVariableAtIndex(name="test", index=2)

dircol.AddRunningCost(x.dot(x))
input_con = dircol.AddConstraintToAllKnotPoints(u[0] == 0)
self.assertEqual(len(input_con), 21)
interval_bound = dircol.AddTimeIntervalBounds(
lower_bound=0.3, upper_bound=0.4)
self.assertIsInstance(interval_bound.evaluator(),
mp.BoundingBoxConstraint)
equal_time_con = dircol.AddEqualTimeIntervalsConstraints()
self.assertEqual(len(equal_time_con), 19)
duration_bound = dircol.AddDurationBounds(
lower_bound=0.3*21, upper_bound=0.4*21)
self.assertIsInstance(duration_bound.evaluator(), mp.LinearConstraint)
final_cost = dircol.AddFinalCost(2*x.dot(x))
self.assertIsInstance(final_cost.evaluator(), mp.Cost)

initial_u = PiecewisePolynomial.ZeroOrderHold([0, 0.3*21],
np.zeros((1, 2)))
initial_x = PiecewisePolynomial()
dircol.SetInitialTrajectory(traj_init_u=initial_u,
traj_init_x=initial_x)

was_called = dict(
input=False,
state=False,
complete=False
)

def input_callback(t, u):
was_called["input"] = True

def state_callback(t, x):
was_called["state"] = True

def complete_callback(t, x, u, v):
was_called["complete"] = True

dircol.AddInputTrajectoryCallback(callback=input_callback)
dircol.AddStateTrajectoryCallback(callback=state_callback)
dircol.AddCompleteTrajectoryCallback(callback=complete_callback,
names=["test"])

result = mp.Solve(dircol.prog())
self.assertTrue(was_called["input"])
self.assertTrue(was_called["state"])
self.assertTrue(was_called["complete"])

dircol.GetSampleTimes(result=result)
dircol.GetInputSamples(result=result)
dircol.GetStateSamples(result=result)
dircol.GetSequentialVariableSamples(result=result, name="test")
dircol.ReconstructInputTrajectory(result=result)
dircol.ReconstructStateTrajectory(result=result)

constraint = DirectCollocationConstraint(plant, context)
with catch_drake_warnings(expected_count=1) as w:
AddDirectCollocationConstraint(constraint, dircol.timestep(0),
dircol.state(0), dircol.state(1),
dircol.input(0), dircol.input(1),
prog)

# Test AddConstraintToAllKnotPoints variants.
nc = len(prog.bounding_box_constraints())
c = dircol.AddConstraintToAllKnotPoints(
constraint=mp.BoundingBoxConstraint([0], [1]), vars=u)
self.assertIsInstance(c[0], mp.Binding[mp.BoundingBoxConstraint])
self.assertEqual(len(prog.bounding_box_constraints()),
nc + num_time_samples)
nc = len(prog.linear_equality_constraints())
c = dircol.AddConstraintToAllKnotPoints(
constraint=mp.LinearEqualityConstraint([1], [0]), vars=u)
self.assertIsInstance(c[0], mp.Binding[mp.LinearEqualityConstraint])
self.assertEqual(len(prog.linear_equality_constraints()),
nc + num_time_samples)
nc = len(prog.linear_constraints())
c = dircol.AddConstraintToAllKnotPoints(
constraint=mp.LinearConstraint([1], [0], [1]), vars=u)
self.assertIsInstance(c[0], mp.Binding[mp.LinearConstraint])
self.assertEqual(len(prog.linear_constraints()), nc + num_time_samples)
nc = len(prog.linear_equality_constraints())
# eq(x, 2) produces a 2-dimensional vector of Formula.
c = dircol.AddConstraintToAllKnotPoints(eq(x, 2))
self.assertIsInstance(c[0].evaluator(), mp.LinearEqualityConstraint)
self.assertEqual(len(prog.linear_equality_constraints()),
nc + 2*num_time_samples)

# Add a second direct collocation problem to the same prog.
num_vars = prog.num_vars()
with catch_drake_warnings(expected_count=1) as w:
dircol2 = DirectCollocation(
plant,
context,
num_time_samples=num_time_samples,
minimum_timestep=0.2,
maximum_timestep=0.5,
input_port_index=InputPortSelection.kUseFirstInputIfItExists,
assume_non_continuous_states_are_fixed=False,
prog=prog)
self.assertEqual(dircol.prog(), dircol2.prog())
self.assertEqual(prog.num_vars(), num_vars + num_initial_vars)

def test_direct_transcription_deprecated(self):
# Integrator.
plant = LinearSystem(
A=[0.0], B=[1.0], C=[1.0], D=[0.0], time_period=0.1)
context = plant.CreateDefaultContext()

# Constructor for discrete systems.
dirtran = DirectTranscription(plant, context, num_time_samples=21)

# Spell out most of the methods, regardless of whether they make sense
# as a consistent optimization. The goal is to check the bindings,
# not the implementation.
t = dirtran.time()
with catch_drake_warnings(expected_count=1) as w:
dt = dirtran.fixed_timestep()
x = dirtran.state()
x2 = dirtran.state(2)
x0 = dirtran.initial_state()
xf = dirtran.final_state()
u = dirtran.input()
u2 = dirtran.input(2)

dirtran.AddRunningCost(x.dot(x))
dirtran.AddConstraintToAllKnotPoints(u[0] == 0)
dirtran.AddFinalCost(2*x.dot(x))

initial_u = PiecewisePolynomial.ZeroOrderHold([0, 0.3*21],
np.zeros((1, 2)))
initial_x = PiecewisePolynomial()
dirtran.SetInitialTrajectory(initial_u, initial_x)

result = mp.Solve(dirtran.prog())
times = dirtran.GetSampleTimes(result)
inputs = dirtran.GetInputSamples(result)
states = dirtran.GetStateSamples(result)
input_traj = dirtran.ReconstructInputTrajectory(result)
state_traj = dirtran.ReconstructStateTrajectory(result)

# Confirm that the constructor for continuous systems works (and
# confirm binding of nested TimeStep).
plant = LinearSystem(
A=[0.0], B=[1.0], C=[1.0], D=[0.0], time_period=0.0)
context = plant.CreateDefaultContext()
with catch_drake_warnings(expected_count=1) as w:
dirtran = DirectTranscription(
plant,
context,
num_time_samples=21,
fixed_timestep=DirectTranscription.TimeStep(0.1))

def test_kinematic_trajectory_optimization(self):
trajopt = KinematicTrajectoryOptimization(num_positions=2,
Expand Down
Loading

0 comments on commit 04f0eb8

Please sign in to comment.