Skip to content

Commit

Permalink
bindings: Change context.FixInputPort() to port.FixValue() (RobotLoco…
Browse files Browse the repository at this point in the history
…motion#12079)

Python bindings: Change context.FixInputPort() to port.FixValue().
  • Loading branch information
mpetersen94 authored and sherm1 committed Sep 20, 2019
1 parent 0303727 commit 7ab56bf
Show file tree
Hide file tree
Showing 17 changed files with 109 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ def main():
cart_pole_context = diagram.GetMutableSubsystemContext(
cart_pole, diagram_context)

cart_pole_context.FixInputPort(
cart_pole.get_actuation_input_port().get_index(), [0])
cart_pole.get_actuation_input_port().FixValue(cart_pole_context, 0)

cart_slider = cart_pole.GetJointByName("CartSlider")
pole_pin = cart_pole.GetJointByName("PolePin")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ def main():
linearize_context = pendulum.CreateDefaultContext()
linearize_context.SetContinuousState(
np.array([upright_theta, 0.]))
actuation_port_index = pendulum.get_actuation_input_port().get_index()
linearize_context.FixInputPort(
actuation_port_index, np.zeros(1))
actuation_port = pendulum.get_actuation_input_port()
actuation_port.FixValue(linearize_context, 0)
controller = builder.AddSystem(
LinearQuadraticRegulator(
pendulum, linearize_context, Q, R,
np.zeros(0), actuation_port_index))
np.zeros(0), actuation_port.get_index()))

# Apply the torque limit.
torque_limit = args.torque_limit
Expand Down
2 changes: 1 addition & 1 deletion bindings/pydrake/examples/test/acrobot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_simulation(self):
# Set an input torque.
input = AcrobotInput()
input.set_tau(1.)
context.FixInputPort(0, input)
acrobot.GetInputPort("elbow_torque").FixValue(context, input)

# Set the initial state.
state = context.get_mutable_continuous_state_vector()
Expand Down
2 changes: 1 addition & 1 deletion bindings/pydrake/examples/test/compass_gait_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_simulation(self):
context.SetAccuracy(1e-8)

# Zero the input
context.FixInputPort(compass_gait.get_input_port(0).get_index(), [0.0])
compass_gait.get_input_port(0).FixValue(context, 0.0)

# Set the initial state.
state = context.get_mutable_continuous_state_vector()
Expand Down
2 changes: 1 addition & 1 deletion bindings/pydrake/examples/test/pendulum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_simulation(self):
# Set an input torque.
input = PendulumInput()
input.set_tau(1.)
context.FixInputPort(0, input)
pendulum.get_input_port().FixValue(context, input)

# Set the initial state.
state = context.get_mutable_continuous_state_vector()
Expand Down
3 changes: 2 additions & 1 deletion bindings/pydrake/multibody/test/plant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,8 @@ def test_contact_results_to_lcm(self):
plant.Finalize()
contact_results_to_lcm = ContactResultsToLcmSystem(plant)
context = contact_results_to_lcm.CreateDefaultContext()
context.FixInputPort(0, AbstractValue.Make(ContactResults_[float]()))
contact_results_to_lcm.get_input_port(0).FixValue(
context, ContactResults_[float]())
output = contact_results_to_lcm.AllocateOutput()
contact_results_to_lcm.CalcOutput(context, output)
result = output.get_data(0)
Expand Down
32 changes: 13 additions & 19 deletions bindings/pydrake/systems/test/controllers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
LinearProgrammingApproximateDynamicProgramming,
PeriodicBoundaryCondition, PidControlledSystem, PidController
)
from pydrake.systems.framework import BasicVector
from pydrake.systems.primitives import Integrator, LinearSystem


Expand Down Expand Up @@ -140,20 +139,15 @@ def test_inverse_dynamics_controller(self):
context = controller.CreateDefaultContext()
output = controller.AllocateOutput()

estimated_state_port = 0
desired_state_port = 1
desired_acceleration_port = 2
control_port = 0

self.assertEqual(
controller.get_input_port(desired_acceleration_port).size(),
kNumVelocities)
self.assertEqual(
controller.get_input_port(estimated_state_port).size(), kStateSize)
self.assertEqual(
controller.get_input_port(desired_state_port).size(), kStateSize)
self.assertEqual(
controller.get_output_port(control_port).size(), kNumVelocities)
estimated_state_port = controller.get_input_port(0)
desired_state_port = controller.get_input_port(1)
desired_acceleration_port = controller.get_input_port(2)
control_port = controller.get_output_port(0)

self.assertEqual(desired_acceleration_port.size(), kNumVelocities)
self.assertEqual(estimated_state_port.size(), kStateSize)
self.assertEqual(desired_state_port.size(), kStateSize)
self.assertEqual(control_port.size(), kNumVelocities)

# Current state.
q = np.array([-0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3])
Expand All @@ -170,9 +164,9 @@ def test_inverse_dynamics_controller(self):

vd_d = vd_r + kp*(q_r-q) + kd*(v_r-v) + ki*integral_term

context.FixInputPort(estimated_state_port, BasicVector(x))
context.FixInputPort(desired_state_port, BasicVector(x_r))
context.FixInputPort(desired_acceleration_port, BasicVector(vd_r))
estimated_state_port.FixValue(context, x)
desired_state_port.FixValue(context, x_r)
desired_acceleration_port.FixValue(context, vd_r)
controller.set_integral_value(context, integral_term)

# Set the plant's context.
Expand Down Expand Up @@ -253,7 +247,7 @@ def test_linear_quadratic_regulator(self):
np.testing.assert_almost_equal(controller.D(), -K_expected)

context = double_integrator.CreateDefaultContext()
context.FixInputPort(0, BasicVector([0]))
double_integrator.get_input_port(0).FixValue(context, [0])
controller = LinearQuadraticRegulator(double_integrator, context, Q, R)
np.testing.assert_almost_equal(controller.D(), -K_expected)

Expand Down
31 changes: 20 additions & 11 deletions bindings/pydrake/systems/test/custom_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
call_vector_system_overrides,
)

from pydrake.common.test_utilities import numpy_compare
from pydrake.common.test_utilities.deprecation import catch_drake_warnings


Expand Down Expand Up @@ -128,10 +129,10 @@ def _create_adder_system(self):
system = CustomAdder(2, 3)
return system

def _fix_adder_inputs(self, context):
def _fix_adder_inputs(self, system, context):
self.assertEqual(context.num_input_ports(), 2)
context.FixInputPort(0, BasicVector([1, 2, 3]))
context.FixInputPort(1, BasicVector([4, 5, 6]))
system.get_input_port(0).FixValue(context, [1, 2, 3])
system.get_input_port(1).FixValue(context, [4, 5, 6])

def test_diagram_adder(self):
system = CustomDiagram(2, 3)
Expand All @@ -144,7 +145,7 @@ def test_adder_execution(self):
system = self._create_adder_system()
context = system.CreateDefaultContext()
self.assertEqual(context.num_output_ports(), 1)
self._fix_adder_inputs(context)
self._fix_adder_inputs(system, context)
output = system.AllocateOutput()
self.assertEqual(output.num_ports(), 1)
system.CalcOutput(context, output)
Expand All @@ -164,7 +165,7 @@ def test_adder_simulation(self):
builder.Connect(adder.get_output_port(0), zoh.get_input_port(0))
diagram = builder.Build()
context = diagram.CreateDefaultContext()
self._fix_adder_inputs(context)
self._fix_adder_inputs(diagram, context)

simulator = Simulator(diagram, context)
simulator.Initialize()
Expand Down Expand Up @@ -434,7 +435,7 @@ def test_vector_system_overrides(self):
context = system.CreateDefaultContext()

u = np.array([1.])
context.FixInputPort(0, BasicVector(u))
system.get_input_port(0).FixValue(context, u)

# Dispatch virtual calls from C++.
output = call_vector_system_overrides(
Expand Down Expand Up @@ -604,6 +605,13 @@ def __init__(self, index):

def test_abstract_io_port(self):
test = self

def assert_value_equal(a, b):
a_name, a_value = a
b_name, b_value = b
self.assertEqual(a_name, b_name)
numpy_compare.assert_equal(a_value, b_value)

# N.B. Since this has trivial operations, we can test all scalar types.
for T in [float, AutoDiffXd, Expression]:
default_value = ("default", T(0.))
Expand All @@ -627,16 +635,16 @@ def DoCalcAbstractOutput(self, context, y_data):
context, 0).get_value()
# The allocator function will populate the output with
# the "input"
test.assertTupleEqual(input_value, expected_input_value)
assert_value_equal(input_value, expected_input_value)
y_data.set_value(expected_output_value)
test.assertTupleEqual(y_data.get_value(),
expected_output_value)
assert_value_equal(
y_data.get_value(), expected_output_value)

system = CustomAbstractSystem()
context = system.CreateDefaultContext()

self.assertEqual(context.num_input_ports(), 1)
context.FixInputPort(0, AbstractValue.Make(expected_input_value))
system.get_input_port(0).FixValue(context, expected_input_value)
output = system.AllocateOutput()
self.assertEqual(output.num_ports(), 1)
system.CalcOutput(context, output)
Expand Down Expand Up @@ -667,6 +675,7 @@ def _Out(self, context, y_data):
system = ParseFloatSystem()
context = system.CreateDefaultContext()
output = system.AllocateOutput()
context.FixInputPort(0, AbstractValue.Make(["22.2"]))
system.get_input_port(0).FixValue(context,
AbstractValue.Make(["22.2"]))
system.CalcOutput(context, output)
self.assertEqual(output.get_vector_data(0).GetAtIndex(0), 22.2)
31 changes: 25 additions & 6 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
LinearSystem,
PassThrough,
SignalLogger,
ZeroOrderHold,
)
from pydrake.common.test_utilities.deprecation import catch_drake_warnings

Expand Down Expand Up @@ -269,7 +270,7 @@ def test_instantiations(self):
def test_scalar_type_conversion(self):
float_system = Adder(1, 1)
float_context = float_system.CreateDefaultContext()
float_context.FixInputPort(0, [1.])
float_system.get_input_port(0).FixValue(float_context, 1.)
for T in [float, AutoDiffXd, Expression]:
system = Adder_[T](1, 1)
# N.B. Current scalar conversion does not permit conversion to and
Expand Down Expand Up @@ -397,11 +398,12 @@ def test_diagram_simulation(self):
# TODO(eric.cousineau): Not seeing any assertions being printed if no
# inputs are connected. Need to check this behavior.
input0 = np.array([0.1, 0.2, 0.3])
context.FixInputPort(0, input0)
diagram.get_input_port(0).FixValue(context, input0)
input1 = np.array([0.02, 0.03, 0.04])
context.FixInputPort(1, input1)
diagram.get_input_port(1).FixValue(context, input1)
# Test the BasicVector overload.
input2 = BasicVector([0.003, 0.004, 0.005])
context.FixInputPort(2, input2) # Test the BasicVector overload.
diagram.get_input_port(2).FixValue(context, input2)

# Test __str__ methods.
self.assertRegexpMatches(str(context), "integrator")
Expand Down Expand Up @@ -527,7 +529,8 @@ def test_abstract_input_port_eval(self):
model_value = AbstractValue.Make("Hello World")
system = PassThrough(copy.copy(model_value))
context = system.CreateDefaultContext()
fixed = context.FixInputPort(0, copy.copy(model_value))
fixed = system.get_input_port(0).FixValue(context,
copy.copy(model_value))
self.assertIsInstance(fixed.GetMutableData(), AbstractValue)
input_port = system.get_input_port(0)

Expand All @@ -544,7 +547,7 @@ def test_vector_input_port_eval(self):
model_value = AbstractValue.Make(BasicVector(np_value))
system = PassThrough(len(np_value))
context = system.CreateDefaultContext()
context.FixInputPort(0, np_value)
system.get_input_port(0).FixValue(context, np_value)
input_port = system.get_input_port(0)

value = input_port.Eval(context)
Expand Down Expand Up @@ -663,3 +666,19 @@ def test_vector_input_port_fix(self):
# A RuntimeError occurs when the Context detects that the
# type-erased Value objects are incompatible.
input_port.FixValue(context, AbstractValue.Make("string"))

def test_context_fix_input_port(self):
# WARNING: This is not the recommend workflow; instead, use
# `InputPort.FixValue` instead. This is here just for testing /
# coverage purposes.
dt = 0.1 # Arbitrary.
system_vec = ZeroOrderHold(period_sec=dt, vector_size=1)
context_vec = system_vec.CreateDefaultContext()
context_vec.FixInputPort(index=0, data=[0.])
context_vec.FixInputPort(index=0, vec=BasicVector([0.]))
# Test abstract.
model_value = AbstractValue.Make("Hello")
system_abstract = ZeroOrderHold(
period_sec=dt, abstract_model_value=model_value.Clone())
context_abstract = system_abstract.CreateDefaultContext()
context_abstract.FixInputPort(index=0, value=model_value.Clone())
2 changes: 1 addition & 1 deletion bindings/pydrake/systems/test/lcm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_subscriber_wait_for_message(self):

def _fix_and_publish(self, dut, value):
context = dut.CreateDefaultContext()
context.FixInputPort(0, value)
dut.get_input_port(0).FixValue(context, value)
dut.Publish(context)

def test_publisher(self):
Expand Down
2 changes: 1 addition & 1 deletion bindings/pydrake/systems/test/lifetime_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def test_ownership_vector(self):
context = system.CreateDefaultContext()
info = Info()
vector = DeleteListenerVector(info.record_deletion)
context.FixInputPort(0, vector)
system.get_input_port(0).FixValue(context, vector)
del context
# Same as above applications, using `py::keep_alive`.
self.assertFalse(info.deleted)
Expand Down
21 changes: 10 additions & 11 deletions bindings/pydrake/systems/test/meshcat_visualizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def test_cart_pole(self):
cart_pole_context = diagram.GetMutableSubsystemContext(
cart_pole, diagram_context)

cart_pole_context.FixInputPort(
cart_pole.get_actuation_input_port().get_index(), [0])
cart_pole.get_actuation_input_port().FixValue(cart_pole_context, 0)

cart_slider = cart_pole.GetJointByName("CartSlider")
pole_pin = cart_pole.GetJointByName("PolePin")
Expand Down Expand Up @@ -101,9 +100,9 @@ def test_kuka(self):
kuka_context = diagram.GetMutableSubsystemContext(
kuka, diagram_context)

kuka_context.FixInputPort(
kuka.get_actuation_input_port().get_index(), np.zeros(
kuka.get_actuation_input_port().size()))
kuka_actuation_port = kuka.get_actuation_input_port()
kuka_actuation_port.FixValue(kuka_context,
np.zeros(kuka_actuation_port.size()))

simulator = Simulator(diagram, diagram_context)
simulator.set_publish_every_time_step(False)
Expand Down Expand Up @@ -253,15 +252,15 @@ def show_cloud(pc, pc2=None, use_native=False, **kwargs):
diagram_context = diagram.CreateDefaultContext()
context = diagram.GetMutableSubsystemContext(
pc_viz, diagram_context)
context.FixInputPort(
pc_viz.GetInputPort("point_cloud_P").get_index(),
AbstractValue.Make(pc))
# TODO(eric.cousineau): Replace `AbstractValue.Make(pc)` with just
# `pc` (#12086).
pc_viz.GetInputPort("point_cloud_P").FixValue(
context, AbstractValue.Make(pc))
if pc2:
context = diagram.GetMutableSubsystemContext(
pc_viz2, diagram_context)
context.FixInputPort(
pc_viz2.GetInputPort("point_cloud_P").get_index(),
AbstractValue.Make(pc2))
pc_viz2.GetInputPort("point_cloud_P").FixValue(
context, AbstractValue.Make(pc2))
simulator = Simulator(diagram, diagram_context)
simulator.set_publish_every_time_step(False)
simulator.AdvanceTo(sim_time)
Expand Down
Loading

0 comments on commit 7ab56bf

Please sign in to comment.