Skip to content

Commit

Permalink
Merge pull request RobotLocomotion#11386 from EricCousineau-TRI/issue…
Browse files Browse the repository at this point in the history
…/11385_initial

py: Add `numpy_compare` testing utilities
  • Loading branch information
EricCousineau-TRI authored May 7, 2019
2 parents c40ad7d + 9148ac3 commit 6ec656b
Show file tree
Hide file tree
Showing 9 changed files with 647 additions and 428 deletions.
2 changes: 2 additions & 0 deletions bindings/pydrake/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ drake_py_unittest(
":autodiffutils_py",
":autodiffutils_test_util_py",
":math_py",
"//bindings/pydrake/common/test_utilities:numpy_compare_py",
],
)

Expand Down Expand Up @@ -415,6 +416,7 @@ drake_py_unittest(
":algebra_test_util_py",
":symbolic_py",
"//bindings/pydrake/common:containers_py",
"//bindings/pydrake/common/test_utilities:numpy_compare_py",
],
)

Expand Down
7 changes: 7 additions & 0 deletions bindings/pydrake/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,13 @@ drake_py_unittest(
],
)

drake_py_unittest(
name = "numpy_compare_test",
deps = [
"//bindings/pydrake/common/test_utilities:numpy_compare_py",
],
)

drake_py_unittest(
name = "cpp_const_test",
deps = [
Expand Down
121 changes: 121 additions & 0 deletions bindings/pydrake/common/test/numpy_compare_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import unittest

import numpy as np

from pydrake.autodiffutils import AutoDiffXd
import pydrake.common.test_utilities.numpy_compare as npc
from pydrake.symbolic import Expression, Variable


class Custom(object):
def __init__(self, value):
assert isinstance(value, str)
self._str = value

def __str__(self):
return self._str

def assert_eq(self, other):
assert self._str == other._str, (self, other)

def assert_ne(self, other):
if self._str == other._str:
raise npc._UnwantedEquality(str((self, other)))


# Hack into private API to register custom comparisons.
registry = npc._registry
registry.register_to_float(Custom, lambda x: float(str(x)))
registry.register_comparator(
Custom, Custom, Custom.assert_eq, Custom.assert_ne)
registry.register_comparator(Custom, str, npc._str_eq, npc._str_ne)


class TestNumpyCompareSimple(unittest.TestCase):
def test_to_float(self):
# Scalar.
xi = 1
xf = npc.to_float(xi)
self.assertEqual(xf.dtype, float)
self.assertEqual(xi, xf)
# Array.
Xi = np.array([1, 2, 3], np.int)
Xf = npc.to_float(Xi)
self.assertEqual(Xf.dtype, float)
np.testing.assert_equal(Xi, Xf)
# Custom.
a = Custom("1.")
b = Custom("2.")
self.assertEqual(npc.to_float(a), 1.)
A = np.array([a, b])
np.testing.assert_equal(npc.to_float(A), [1., 2.])

def test_asserts_builtin(self):
a = 1.
b = 0.
# Scalar.
npc.assert_equal(a, a)
with self.assertRaises(AssertionError):
npc.assert_equal(a, b)
npc.assert_not_equal(a, b)
with self.assertRaises(AssertionError):
npc.assert_not_equal(a, a)
# Array.
A = np.array([a, a])
C = np.array([1., 2.])
npc.assert_equal(A, a)
npc.assert_equal(C, C)
with self.assertRaises(AssertionError):
npc.assert_equal(A, b)
npc.assert_not_equal(A, A + [0, 0.1])
npc.assert_not_equal(A, b)
with self.assertRaises(AssertionError):
npc.assert_not_equal(C, C)

def test_asserts_custom(self):
a = Custom("a")
b = Custom("b")
# Scalar.
npc.assert_equal(a, a)
npc.assert_equal(a, "a")
with self.assertRaises(AssertionError):
npc.assert_equal(a, b)
npc.assert_not_equal(a, b)
with self.assertRaises(AssertionError):
npc.assert_not_equal(a, a)
with self.assertRaises(AssertionError):
npc.assert_not_equal(a, "a")
# Array.
A = np.array([a, a])
C = np.array([Custom("c0"), Custom("c1")])
npc.assert_equal(A, a)
npc.assert_equal(A, "a")
npc.assert_equal(C, C)
npc.assert_equal(C, ["c0", "c1"])
with self.assertRaises(AssertionError):
npc.assert_equal(A, b)
npc.assert_not_equal(A, [a, b])
npc.assert_not_equal(A, ["a", "b"])
with self.assertRaises(AssertionError):
npc.assert_not_equal(C, C)

def test_asserts_autodiff(self):
# Test only scalar; other cases are handled by above test case.
a = AutoDiffXd(1., [1., 0.])
b = AutoDiffXd(1., [0., 1.])
c = AutoDiffXd(2., [3., 4.])
npc.assert_equal(a, a)
npc.assert_not_equal(a, b)
npc.assert_not_equal(a, c)

def test_asserts_symbolic(self):
x = Variable("x")
y = Variable("y")
e = x + y
npc.assert_equal(x, x)
npc.assert_equal(x, "x")
npc.assert_not_equal(x, y)
npc.assert_equal(e, x + y)
npc.assert_equal(e, "(x + y)")
npc.assert_not_equal(e, x - y)
npc.assert_not_equal(e, "(x - y)")
10 changes: 10 additions & 0 deletions bindings/pydrake/common/test_utilities/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ drake_py_library(
deps = [":module_py"],
)

drake_py_library(
name = "numpy_compare_py",
srcs = ["numpy_compare.py"],
deps = [
":module_py",
"//bindings/pydrake:autodiffutils_py",
"//bindings/pydrake:symbolic_py",
],
)

# Package roll-up (for Bazel dependencies).
drake_py_library(
name = "test_utilities",
Expand Down
163 changes: 163 additions & 0 deletions bindings/pydrake/common/test_utilities/numpy_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
Provides consistent utilities for comparing NumPy matrices and scalars.
Prefer comparisons in the following order:
- Methods from this module.
- Methods from `np.testing.*`, if the dtypes are guaranteed to be NumPy
builtins and the API definitely won't support other scalar types.
"""

# TODO(eric.cousineau): Make custom assert-vectorize which will output
# coordinates and stuff.

from collections import namedtuple
from itertools import product

import numpy as np


class _UnwantedEquality(AssertionError):
pass


class _Registry(object):
# Scalar comparator.
# `assert_eq` will be vectorized; it should raise an assertion error upon
# first inequality.
# `assert_ne` will stay as scalar; this should raise `_UnwantedEquality` to
# make intent explicit.
# TODO(eric.cousineau): Add `assert_near` when it's necessary.
AssertComparator = namedtuple(
'AssertComparator', ['assert_eq', 'assert_ne'])

def __init__(self):
self._comparators = {}
self._to_float = {}

def register_comparator(self, cls_a, cls_b, assert_eq, assert_ne=None):
key = (cls_a, cls_b)
assert key not in self._comparators, key
assert_eq = np.vectorize(assert_eq)
self._comparators[key] = self.AssertComparator(assert_eq, assert_ne)

def get_comparator_from_arrays(self, a, b):
# Ensure all types are homogeneous.
a_type, = {type(np.asarray(x).item()) for x in a.flat}
b_type, = {type(np.asarray(x).item()) for x in b.flat}
key = (a_type, b_type)
return self._comparators[key]

def register_to_float(self, cls, func):
assert cls not in self._to_float, cls
self._to_float[cls] = func

def get_to_float(self, cls):
return self._to_float[cls]


@np.vectorize
def to_float(x):
"""Converts scalar or array to floats."""
x = np.asarray(x)
if x.dtype == object:
x = x.item()
cls = type(x)
return _registry.get_to_float(cls)(x)
else:
return np.float64(x)


def assert_equal(a, b):
"""Compare scalars or arrays directly, requiring equality."""
a, b = map(np.asarray, (a, b))
if a.size == 0 and b.size == 0:
return
if a.dtype != object and b.dtype != object:
np.testing.assert_equal(a, b)
else:
_registry.get_comparator_from_arrays(a, b).assert_eq(a, b)


def _raw_ne(a, b):
if a == b:
raise _UnwantedEquality(str((a, b)))


def assert_not_equal(a, b):
"""Compare scalars or arrays directly, requiring inequality."""
a, b = map(np.asarray, (a, b))
assert not (a.size == 0 and b.size == 0)
if a.dtype != object and b.dtype != object:
assert_ne = _raw_ne
else:
assert_ne = _registry.get_comparator_from_arrays(a, b).assert_ne
# For this to fail, all items must have failed.
br = np.broadcast(a, b)
errs = []
for ai, bi in br:
e = None
try:
assert_ne(ai, bi)
except _UnwantedEquality as e:
errs.append(str(e))
all_equal = len(errs) == br.size
if all_equal:
raise AssertionError("Unwanted equality: {}".format(errs))


def _str_eq(a, b):
# b is a string, a is to be converted.
a = str(a)
assert a == b, (a, b)


def _str_ne(a, b):
# b is a string, a is to be converted.
a = str(a)
if a == b:
raise _UnwantedEquality(str((a, b)))


def _register_autodiff():
from pydrake.autodiffutils import AutoDiffXd

def autodiff_eq(a, b):
assert a.value() == b.value(), (a.value(), b.value())
np.testing.assert_equal(a.derivatives(), b.derivatives())

def autodiff_ne(a, b):
if (a.value() == b.value() and
(a.derivatives() == b.derivatives()).all()):
raise _UnwantedEquality(str(a.value(), b.derivatives()))

_registry.register_to_float(AutoDiffXd, AutoDiffXd.value)
_registry.register_comparator(
AutoDiffXd, AutoDiffXd, autodiff_eq, autodiff_ne)


def _register_symbolic():
from pydrake.symbolic import (
Expression, Formula, Monomial, Polynomial, Variable)

def sym_struct_eq(a, b):
assert a.EqualTo(b), (a, b)

def sym_struct_ne(a, b):
assert not a.EqualTo(b), (a, b)

_registry.register_to_float(Expression, Expression.Evaluate)
_registry.register_comparator(Formula, str, _str_eq, _str_ne)
lhs_types = [Variable, Expression, Polynomial, Monomial]
rhs_types = lhs_types + [float]
for lhs_type in lhs_types:
_registry.register_comparator(lhs_type, str, _str_eq, _str_ne)
for lhs_type, rhs_type in product(lhs_types, rhs_types):
_registry.register_comparator(
lhs_type, rhs_type, sym_struct_eq, sym_struct_ne)


# Globals.
_registry = _Registry()
_register_autodiff()
_register_symbolic()
9 changes: 9 additions & 0 deletions bindings/pydrake/pydrake_doxygen.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ used in Python.
For binding functions, methods, properties, and classes, docstrings should be
provided. These should be provided as described @ref PydrakeDoc "here".
## Testing
In general, since the Python bindings wrap tested C++ code, you do not (and
should not) repeat intricate testing logic done in C++. Instead, ensure you
exercise the Pythonic portion of the API, using kwargs when appropriate.
When testing the values of NumPy matrices, please review the documentation in
`pydrake.common.test_utilities.numpy_compare` for guidance.
## Target Conventions
### Names
Expand Down
Loading

0 comments on commit 6ec656b

Please sign in to comment.