From 91e3dccea12c6b5fa73ec215d8af0ff4ee7f1876 Mon Sep 17 00:00:00 2001 From: Jeremy Nimmer Date: Tue, 14 Mar 2023 14:42:00 -0700 Subject: [PATCH] [pydrake] Use nice names for default template classes (#18972) When a template class has a default type parameter, name the default instantiation directly using the default name, instead of via an alias. This makes IDE auto-complete and type annotations more natural. --- bindings/pydrake/_math_extra.py | 16 ++------ bindings/pydrake/common/__init__.py | 7 ++++ bindings/pydrake/common/cpp_template.py | 35 +++++++++-------- bindings/pydrake/common/cpp_template_pybind.h | 39 ++++++++++++------- .../common/test/eigen_geometry_test.py | 2 +- bindings/pydrake/multibody/_math_extra.py | 10 +---- bindings/pydrake/multibody/test/plant_test.py | 10 ++--- bindings/pydrake/systems/scalar_conversion.py | 4 +- bindings/pydrake/systems/test/value_test.py | 7 ++-- doc/pydrake/pydrake_sphinx_extension.py | 21 +++------- 10 files changed, 71 insertions(+), 80 deletions(-) diff --git a/bindings/pydrake/_math_extra.py b/bindings/pydrake/_math_extra.py index bdf185164a11..86b423e9a013 100644 --- a/bindings/pydrake/_math_extra.py +++ b/bindings/pydrake/_math_extra.py @@ -112,27 +112,18 @@ def _indented_repr(o): return repr(o).replace("\n", "\n ") -def _remove_float_suffix(typename): - suffix = "_[float]" - if typename.endswith(suffix): - return typename[:-len(suffix)] - return typename - - def _roll_pitch_yaw_repr(rpy): - cls_name = _remove_float_suffix(_pretty_class_name(type(rpy))) return ( - f"{cls_name}(" + f"{_pretty_class_name(type(rpy))}(" f"roll={repr(rpy.roll_angle())}, " f"pitch={repr(rpy.pitch_angle())}, " f"yaw={repr(rpy.yaw_angle())})") def _rotation_matrix_repr(R): - cls_name = _remove_float_suffix(_pretty_class_name(type(R))) M = R.matrix().tolist() return ( - f"{cls_name}([\n" + f"{_pretty_class_name(type(R))}([\n" f" {_indented_repr(M[0])},\n" f" {_indented_repr(M[1])},\n" f" {_indented_repr(M[2])},\n" @@ -140,9 +131,8 @@ def _rotation_matrix_repr(R): def _rigid_transform_repr(X): - cls_name = _remove_float_suffix(_pretty_class_name(type(X))) return ( - f"{cls_name}(\n" + f"{_pretty_class_name(type(X))}(\n" f" R={_indented_repr(X.rotation())},\n" f" p={_indented_repr(X.translation().tolist())},\n" f")") diff --git a/bindings/pydrake/common/__init__.py b/bindings/pydrake/common/__init__.py index b26a45fb33e6..83a6797c4415 100644 --- a/bindings/pydrake/common/__init__.py +++ b/bindings/pydrake/common/__init__.py @@ -227,6 +227,13 @@ def __getattr__(name): name = _MangledName.mangle(name) if name in module_globals: return module_globals[name] + float_tag = "_{}float{}".format( + _MangledName.UNICODE_LEFT_BRACKET, + _MangledName.UNICODE_RIGHT_BRACKET) + if name.endswith(float_tag): + shorter_name = name[:-len(float_tag)] + if shorter_name in module_globals: + return module_globals[shorter_name] raise AttributeError( f"module {module_name!r} has no attribute {name!r}") diff --git a/bindings/pydrake/common/cpp_template.py b/bindings/pydrake/common/cpp_template.py index e837aa6df5b9..940603d037a9 100644 --- a/bindings/pydrake/common/cpp_template.py +++ b/bindings/pydrake/common/cpp_template.py @@ -174,7 +174,8 @@ def get_instantiation(self, param=None, throw_error=True): if instantiation is TemplateBase._deferred: assert self._instantiation_func is not None instantiation = self._instantiation_func(param) - self._add_instantiation_internal(param, instantiation) + self._add_instantiation_internal(param, instantiation, + skip_rename=False) elif instantiation is None and throw_error: raise RuntimeError("Invalid instantiation: {}".format( self._instantiation_name(param))) @@ -183,7 +184,7 @@ def get_instantiation(self, param=None, throw_error=True): _warn_deprecated(deprecation.message, date=deprecation.date) return (instantiation, param) - def add_instantiation(self, param, instantiation): + def add_instantiation(self, param, instantiation, skip_rename=False): """Adds a unique instantiation. Note: @@ -196,15 +197,15 @@ def add_instantiation(self, param, instantiation): "Parameter instantiation already registered: {}".format(param)) # Register it. self.param_list.append(param) - self._add_instantiation_internal(param, instantiation) + self._add_instantiation_internal(param, instantiation, skip_rename) return param - def _add_instantiation_internal(self, param, instantiation): + def _add_instantiation_internal(self, param, instantiation, skip_rename): # Adds instantiation. Permits overwriting for deferred cases. assert instantiation is not None if instantiation is not TemplateBase._deferred: old = instantiation - instantiation = self._on_add(param, instantiation) + instantiation = self._on_add(param, instantiation, skip_rename) assert instantiation is not None, (self, param, old) if instantiation is not old: self._instantiation_alias_map[old] = instantiation @@ -322,7 +323,7 @@ def __str__(self): cls_name = pretty_class_name(type(self)) return "<{} {}>".format(cls_name, self._full_name()) - def _on_add(self, param, instantiation): + def _on_add(self, param, instantiation, skip_rename): # To be overridden by child classes. return instantiation @@ -370,18 +371,16 @@ def decorator(instantiation_func): class TemplateClass(TemplateBase): """Extension of `TemplateBase` for classes.""" - def __init__(self, name, override_meta=True, scope=None, **kwargs): + def __init__(self, name, *, scope=None, **kwargs): if scope is None: scope = _get_module_from_stack() TemplateBase.__init__(self, name, scope=scope, **kwargs) - self._override_meta = override_meta - - def _on_add(self, param, cls): - if self._override_meta: - # Rename the class now to reflect its `template_name` and `param`. - # C++ templates are initially bound using a `TemporaryClassName()` - # which we overwrite here. Python templates are usually declared as - # a nested class helper, which likewise we need to replace. + + def _on_add(self, param, cls, skip_rename): + # Unless this class was a default template instantiation, we need to + # rename it now to describe its template arguments. (Most templated + # C++ classes are bound using the TemporaryClassName() function.) + if not skip_rename: cls._original_name = cls.__name__ cls._original_qualname = getattr(cls, "__qualname__", cls.__name__) cls.__name__ = self._instantiation_name(param, mangle=True) @@ -444,7 +443,8 @@ def f(*args, **kwargs): return orig(*args, **kwargs) class TemplateFunction(TemplateBase): """Extension of `TemplateBase` for functions.""" - def _on_add(self, param, func): + def _on_add(self, param, func, skip_rename): + assert skip_rename is False new_name = self._instantiation_name(param, mangle=True) func = _rename_callable(func, self._scope, new_name) setattr(self._scope, func.__name__, func) @@ -461,7 +461,8 @@ def __init__(self, name, cls, scope=None, **kwargs): # only. self._cls = cls - def _on_add(self, param, func): + def _on_add(self, param, func, skip_rename): + assert skip_rename is False new_name = self._instantiation_name(param, mangle=True) func = _rename_callable(func, self._scope, new_name, self._cls) setattr(self._cls, func.__name__, func) diff --git a/bindings/pydrake/common/cpp_template_pybind.h b/bindings/pydrake/common/cpp_template_pybind.h index 4dd42315ed39..efe8ae150764 100644 --- a/bindings/pydrake/common/cpp_template_pybind.h +++ b/bindings/pydrake/common/cpp_template_pybind.h @@ -31,9 +31,9 @@ inline py::object GetOrInitTemplate( // BR } // Adds instantiation to a Python template. -inline void AddInstantiation( - py::handle py_template, py::handle obj, py::tuple param) { - py_template.attr("add_instantiation")(param, obj); +inline void AddInstantiation(py::handle py_template, py::handle obj, + py::tuple param, bool skip_rename = false) { + py_template.attr("add_instantiation")(param, obj, skip_rename); } // Gets name for a given instantiation. @@ -69,10 +69,10 @@ std::string TemporaryClassName(const std::string& name = "TemporaryName") { /// @param param Parameters for the instantiation. inline py::object AddTemplateClass( // BR py::handle scope, const std::string& template_name, py::handle py_class, - py::tuple param) { + py::tuple param, bool skip_rename = false) { py::object py_template = internal::GetOrInitTemplate(scope, template_name, "TemplateClass"); - internal::AddInstantiation(py_template, py_class, param); + internal::AddInstantiation(py_template, py_class, param, skip_rename); return py_template; } @@ -85,16 +85,29 @@ template py::class_ DefineTemplateClassWithDefault( // BR py::handle scope, const std::string& default_name, py::tuple param, const char* doc_string = "", const std::string& template_suffix = "_") { + // The default instantiation is immediately assigned its correct class name. + // Other instantiations use a temporary name here that will be overwritten + // by the AddTemplateClass function during registration. + const bool is_default = !py::hasattr(scope, default_name.c_str()); + const std::string class_name = + is_default ? default_name : TemporaryClassName(); const std::string template_name = default_name + template_suffix; - // Define class with temporary name. - py::class_ py_class( - scope, TemporaryClassName().c_str(), doc_string); - // Register instantiation. - AddTemplateClass(scope, template_name, py_class, param); - // Declare default instantiation if it does not already exist. - if (!py::hasattr(scope, default_name.c_str())) { - scope.attr(default_name.c_str()) = py_class; + // Define the class. + std::string doc; + if (is_default) { + doc = fmt::format( + "{}\n\nNote:\n\n" + " This class is templated; see :class:`{}`\n" + " for the list of instantiations.", + doc_string, template_name); + } else { + doc = doc_string; } + py::class_ py_class( + scope, class_name.c_str(), doc.c_str()); + // Register it as a template instantiation. + const bool skip_rename = is_default; + AddTemplateClass(scope, template_name, py_class, param, skip_rename); return py_class; } diff --git a/bindings/pydrake/common/test/eigen_geometry_test.py b/bindings/pydrake/common/test/eigen_geometry_test.py index 6fd976d187e3..47e25776c44a 100644 --- a/bindings/pydrake/common/test/eigen_geometry_test.py +++ b/bindings/pydrake/common/test/eigen_geometry_test.py @@ -51,7 +51,7 @@ def test_quaternion(self, T): if T == float: self.assertEqual( str(q_identity), - "Quaternion_[float](w=1.0, x=0.0, y=0.0, z=0.0)") + "Quaternion(w=1.0, x=0.0, y=0.0, z=0.0)") else: self.assertIn("Quaternion_[", str(q_identity)) self.check_cast(mut.Quaternion_, T) diff --git a/bindings/pydrake/multibody/_math_extra.py b/bindings/pydrake/multibody/_math_extra.py index 969f570cf059..08c6aba11ad5 100644 --- a/bindings/pydrake/multibody/_math_extra.py +++ b/bindings/pydrake/multibody/_math_extra.py @@ -11,21 +11,13 @@ def _indented_repr(o): return repr(o).replace("\n", "\n ") -def _remove_float_suffix(typename): - suffix = "_[float]" - if typename.endswith(suffix): - return typename[:-len(suffix)] - return typename - - def _spatial_vector_repr(rotation_name, translation_name): def repr_with_closure(self): - cls_name = _remove_float_suffix(_pretty_class_name(type(self))) rotation = self.rotational().tolist() translation = self.translational().tolist() return ( - f"{cls_name}(\n" + f"{_pretty_class_name(type(self))}(\n" f" {rotation_name}={_indented_repr(rotation)},\n" f" {translation_name}={_indented_repr(translation)},\n" f")") diff --git a/bindings/pydrake/multibody/test/plant_test.py b/bindings/pydrake/multibody/test/plant_test.py index e41fdad68f75..ca64fc12c57e 100644 --- a/bindings/pydrake/multibody/test/plant_test.py +++ b/bindings/pydrake/multibody/test/plant_test.py @@ -328,7 +328,7 @@ def check_repr(element, expected): self._test_joint_api(T, shoulder) check_repr( shoulder, - "") np.testing.assert_array_equal( shoulder.position_lower_limits(), [-np.inf]) @@ -352,12 +352,12 @@ def check_repr(element, expected): self.assertEqual(len(plant.GetBodyIndices(model_instance)), 2) check_repr( link1, - "") + "") self._test_frame_api(T, plant.GetFrameByName(name="Link1")) link1_frame = plant.GetFrameByName(name="Link1") check_repr( link1_frame, - "") + "") self.assertIs( link1_frame, plant.GetFrameByName(name="Link1", model_instance=model_instance)) @@ -400,7 +400,7 @@ def check_repr(element, expected): plant.GetJointActuatorIndices(model_instance=model_instance)) check_repr( joint_actuator, - "") self.assertIsInstance( plant.get_frame(frame_index=world_frame_index()), Frame) @@ -806,7 +806,7 @@ def test_multibody_force_element(self, T): if T == float: self.assertEqual( repr(linear_spring), - "") + "") revolute_joint = plant.AddJoint(RevoluteJoint_[T]( name="revolve_joint", frame_on_parent=body_a.body_frame(), frame_on_child=body_b.body_frame(), axis=[0, 0, 1], diff --git a/bindings/pydrake/systems/scalar_conversion.py b/bindings/pydrake/systems/scalar_conversion.py index 02b403745527..5a3bcebd0ac2 100644 --- a/bindings/pydrake/systems/scalar_conversion.py +++ b/bindings/pydrake/systems/scalar_conversion.py @@ -145,8 +145,8 @@ def wrapped(param): return decorator - def _on_add(self, param, cls): - TemplateClass._on_add(self, param, cls) + def _on_add(self, param, cls, skip_rename): + TemplateClass._on_add(self, param, cls, skip_rename) T, = param # Check that the user has not defined `__init__`, and has defined diff --git a/bindings/pydrake/systems/test/value_test.py b/bindings/pydrake/systems/test/value_test.py index 2246161e5980..6301d0555a32 100644 --- a/bindings/pydrake/systems/test/value_test.py +++ b/bindings/pydrake/systems/test/value_test.py @@ -119,18 +119,17 @@ def test_str_and_repr(self): vector_f = [1.] value_f = BasicVector_[float](vector_f) self.assertEqual(str(value_f), "[1.0]") - self.assertEqual(repr(value_f), "BasicVector_[float]([1.0])") + self.assertEqual(repr(value_f), "BasicVector([1.0])") # Check repr() invariant. self.assert_basic_vector_equal(value_f, eval(repr(value_f))) # - Empty. value_f_empty = BasicVector_[float]([]) self.assertEqual(str(value_f_empty), "[]") - self.assertEqual(repr(value_f_empty), "BasicVector_[float]([])") + self.assertEqual(repr(value_f_empty), "BasicVector([])") # - Multiple values. value_f_multi = BasicVector_[float]([1., 2.]) self.assertEqual(str(value_f_multi), "[1.0, 2.0]") - self.assertEqual( - repr(value_f_multi), "BasicVector_[float]([1.0, 2.0])") + self.assertEqual(repr(value_f_multi), "BasicVector([1.0, 2.0])") # TODO(eric.cousineau): Make repr() for AutoDiffXd and Expression be # semi-usable. # T=AutoDiffXd diff --git a/doc/pydrake/pydrake_sphinx_extension.py b/doc/pydrake/pydrake_sphinx_extension.py index 5dc830145259..47f77b3cbe12 100644 --- a/doc/pydrake/pydrake_sphinx_extension.py +++ b/doc/pydrake/pydrake_sphinx_extension.py @@ -137,26 +137,16 @@ class TemplateDocumenter(autodoc.ModuleLevelDocumenter): # Take priority over attributes. priority = 1 + autodoc.AttributeDocumenter.priority - option_spec = { - 'show-all-instantiations': autodoc.bool_option, - } - # Permit propagation of class-specific properties. - option_spec.update(autodoc.ClassDocumenter.option_spec) - @classmethod def can_document_member(cls, member, membername, isattr, parent): """Overrides base to check for template objects.""" return isinstance(member, TemplateBase) def get_object_members(self, want_all): - """Overrides base to return instantiations from templates.""" - members = [] - for param in self.object.param_list: - instantiation = self.object[param] - members.append((instantiation.__name__, instantiation)) - if not self.options.show_all_instantiations: - break - return False, members + """Overrides base; we shouldn't show any details beyond the list of + instantiations. + """ + return False, [] def check_module(self): """Overrides base to show template objects given the correct module.""" @@ -199,8 +189,7 @@ def tpl_attrgetter(obj, name, *defargs): """ # N.B. Rather than try to evaluate parameters from the string, we instead # match based on instantiation name. - if "[" in name: - assert name.endswith(']'), name + if isinstance(obj, TemplateBase) and name[0] != "_": for param in obj.param_list: inst = obj[param] if inst.__name__ == name: