forked from sammy-tri/drake
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request RobotLocomotion#11386 from EricCousineau-TRI/issue…
…/11385_initial py: Add `numpy_compare` testing utilities
- Loading branch information
Showing
9 changed files
with
647 additions
and
428 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from pydrake.autodiffutils import AutoDiffXd | ||
import pydrake.common.test_utilities.numpy_compare as npc | ||
from pydrake.symbolic import Expression, Variable | ||
|
||
|
||
class Custom(object): | ||
def __init__(self, value): | ||
assert isinstance(value, str) | ||
self._str = value | ||
|
||
def __str__(self): | ||
return self._str | ||
|
||
def assert_eq(self, other): | ||
assert self._str == other._str, (self, other) | ||
|
||
def assert_ne(self, other): | ||
if self._str == other._str: | ||
raise npc._UnwantedEquality(str((self, other))) | ||
|
||
|
||
# Hack into private API to register custom comparisons. | ||
registry = npc._registry | ||
registry.register_to_float(Custom, lambda x: float(str(x))) | ||
registry.register_comparator( | ||
Custom, Custom, Custom.assert_eq, Custom.assert_ne) | ||
registry.register_comparator(Custom, str, npc._str_eq, npc._str_ne) | ||
|
||
|
||
class TestNumpyCompareSimple(unittest.TestCase): | ||
def test_to_float(self): | ||
# Scalar. | ||
xi = 1 | ||
xf = npc.to_float(xi) | ||
self.assertEqual(xf.dtype, float) | ||
self.assertEqual(xi, xf) | ||
# Array. | ||
Xi = np.array([1, 2, 3], np.int) | ||
Xf = npc.to_float(Xi) | ||
self.assertEqual(Xf.dtype, float) | ||
np.testing.assert_equal(Xi, Xf) | ||
# Custom. | ||
a = Custom("1.") | ||
b = Custom("2.") | ||
self.assertEqual(npc.to_float(a), 1.) | ||
A = np.array([a, b]) | ||
np.testing.assert_equal(npc.to_float(A), [1., 2.]) | ||
|
||
def test_asserts_builtin(self): | ||
a = 1. | ||
b = 0. | ||
# Scalar. | ||
npc.assert_equal(a, a) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_equal(a, b) | ||
npc.assert_not_equal(a, b) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_not_equal(a, a) | ||
# Array. | ||
A = np.array([a, a]) | ||
C = np.array([1., 2.]) | ||
npc.assert_equal(A, a) | ||
npc.assert_equal(C, C) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_equal(A, b) | ||
npc.assert_not_equal(A, A + [0, 0.1]) | ||
npc.assert_not_equal(A, b) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_not_equal(C, C) | ||
|
||
def test_asserts_custom(self): | ||
a = Custom("a") | ||
b = Custom("b") | ||
# Scalar. | ||
npc.assert_equal(a, a) | ||
npc.assert_equal(a, "a") | ||
with self.assertRaises(AssertionError): | ||
npc.assert_equal(a, b) | ||
npc.assert_not_equal(a, b) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_not_equal(a, a) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_not_equal(a, "a") | ||
# Array. | ||
A = np.array([a, a]) | ||
C = np.array([Custom("c0"), Custom("c1")]) | ||
npc.assert_equal(A, a) | ||
npc.assert_equal(A, "a") | ||
npc.assert_equal(C, C) | ||
npc.assert_equal(C, ["c0", "c1"]) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_equal(A, b) | ||
npc.assert_not_equal(A, [a, b]) | ||
npc.assert_not_equal(A, ["a", "b"]) | ||
with self.assertRaises(AssertionError): | ||
npc.assert_not_equal(C, C) | ||
|
||
def test_asserts_autodiff(self): | ||
# Test only scalar; other cases are handled by above test case. | ||
a = AutoDiffXd(1., [1., 0.]) | ||
b = AutoDiffXd(1., [0., 1.]) | ||
c = AutoDiffXd(2., [3., 4.]) | ||
npc.assert_equal(a, a) | ||
npc.assert_not_equal(a, b) | ||
npc.assert_not_equal(a, c) | ||
|
||
def test_asserts_symbolic(self): | ||
x = Variable("x") | ||
y = Variable("y") | ||
e = x + y | ||
npc.assert_equal(x, x) | ||
npc.assert_equal(x, "x") | ||
npc.assert_not_equal(x, y) | ||
npc.assert_equal(e, x + y) | ||
npc.assert_equal(e, "(x + y)") | ||
npc.assert_not_equal(e, x - y) | ||
npc.assert_not_equal(e, "(x - y)") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
163 changes: 163 additions & 0 deletions
163
bindings/pydrake/common/test_utilities/numpy_compare.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
""" | ||
Provides consistent utilities for comparing NumPy matrices and scalars. | ||
Prefer comparisons in the following order: | ||
- Methods from this module. | ||
- Methods from `np.testing.*`, if the dtypes are guaranteed to be NumPy | ||
builtins and the API definitely won't support other scalar types. | ||
""" | ||
|
||
# TODO(eric.cousineau): Make custom assert-vectorize which will output | ||
# coordinates and stuff. | ||
|
||
from collections import namedtuple | ||
from itertools import product | ||
|
||
import numpy as np | ||
|
||
|
||
class _UnwantedEquality(AssertionError): | ||
pass | ||
|
||
|
||
class _Registry(object): | ||
# Scalar comparator. | ||
# `assert_eq` will be vectorized; it should raise an assertion error upon | ||
# first inequality. | ||
# `assert_ne` will stay as scalar; this should raise `_UnwantedEquality` to | ||
# make intent explicit. | ||
# TODO(eric.cousineau): Add `assert_near` when it's necessary. | ||
AssertComparator = namedtuple( | ||
'AssertComparator', ['assert_eq', 'assert_ne']) | ||
|
||
def __init__(self): | ||
self._comparators = {} | ||
self._to_float = {} | ||
|
||
def register_comparator(self, cls_a, cls_b, assert_eq, assert_ne=None): | ||
key = (cls_a, cls_b) | ||
assert key not in self._comparators, key | ||
assert_eq = np.vectorize(assert_eq) | ||
self._comparators[key] = self.AssertComparator(assert_eq, assert_ne) | ||
|
||
def get_comparator_from_arrays(self, a, b): | ||
# Ensure all types are homogeneous. | ||
a_type, = {type(np.asarray(x).item()) for x in a.flat} | ||
b_type, = {type(np.asarray(x).item()) for x in b.flat} | ||
key = (a_type, b_type) | ||
return self._comparators[key] | ||
|
||
def register_to_float(self, cls, func): | ||
assert cls not in self._to_float, cls | ||
self._to_float[cls] = func | ||
|
||
def get_to_float(self, cls): | ||
return self._to_float[cls] | ||
|
||
|
||
@np.vectorize | ||
def to_float(x): | ||
"""Converts scalar or array to floats.""" | ||
x = np.asarray(x) | ||
if x.dtype == object: | ||
x = x.item() | ||
cls = type(x) | ||
return _registry.get_to_float(cls)(x) | ||
else: | ||
return np.float64(x) | ||
|
||
|
||
def assert_equal(a, b): | ||
"""Compare scalars or arrays directly, requiring equality.""" | ||
a, b = map(np.asarray, (a, b)) | ||
if a.size == 0 and b.size == 0: | ||
return | ||
if a.dtype != object and b.dtype != object: | ||
np.testing.assert_equal(a, b) | ||
else: | ||
_registry.get_comparator_from_arrays(a, b).assert_eq(a, b) | ||
|
||
|
||
def _raw_ne(a, b): | ||
if a == b: | ||
raise _UnwantedEquality(str((a, b))) | ||
|
||
|
||
def assert_not_equal(a, b): | ||
"""Compare scalars or arrays directly, requiring inequality.""" | ||
a, b = map(np.asarray, (a, b)) | ||
assert not (a.size == 0 and b.size == 0) | ||
if a.dtype != object and b.dtype != object: | ||
assert_ne = _raw_ne | ||
else: | ||
assert_ne = _registry.get_comparator_from_arrays(a, b).assert_ne | ||
# For this to fail, all items must have failed. | ||
br = np.broadcast(a, b) | ||
errs = [] | ||
for ai, bi in br: | ||
e = None | ||
try: | ||
assert_ne(ai, bi) | ||
except _UnwantedEquality as e: | ||
errs.append(str(e)) | ||
all_equal = len(errs) == br.size | ||
if all_equal: | ||
raise AssertionError("Unwanted equality: {}".format(errs)) | ||
|
||
|
||
def _str_eq(a, b): | ||
# b is a string, a is to be converted. | ||
a = str(a) | ||
assert a == b, (a, b) | ||
|
||
|
||
def _str_ne(a, b): | ||
# b is a string, a is to be converted. | ||
a = str(a) | ||
if a == b: | ||
raise _UnwantedEquality(str((a, b))) | ||
|
||
|
||
def _register_autodiff(): | ||
from pydrake.autodiffutils import AutoDiffXd | ||
|
||
def autodiff_eq(a, b): | ||
assert a.value() == b.value(), (a.value(), b.value()) | ||
np.testing.assert_equal(a.derivatives(), b.derivatives()) | ||
|
||
def autodiff_ne(a, b): | ||
if (a.value() == b.value() and | ||
(a.derivatives() == b.derivatives()).all()): | ||
raise _UnwantedEquality(str(a.value(), b.derivatives())) | ||
|
||
_registry.register_to_float(AutoDiffXd, AutoDiffXd.value) | ||
_registry.register_comparator( | ||
AutoDiffXd, AutoDiffXd, autodiff_eq, autodiff_ne) | ||
|
||
|
||
def _register_symbolic(): | ||
from pydrake.symbolic import ( | ||
Expression, Formula, Monomial, Polynomial, Variable) | ||
|
||
def sym_struct_eq(a, b): | ||
assert a.EqualTo(b), (a, b) | ||
|
||
def sym_struct_ne(a, b): | ||
assert not a.EqualTo(b), (a, b) | ||
|
||
_registry.register_to_float(Expression, Expression.Evaluate) | ||
_registry.register_comparator(Formula, str, _str_eq, _str_ne) | ||
lhs_types = [Variable, Expression, Polynomial, Monomial] | ||
rhs_types = lhs_types + [float] | ||
for lhs_type in lhs_types: | ||
_registry.register_comparator(lhs_type, str, _str_eq, _str_ne) | ||
for lhs_type, rhs_type in product(lhs_types, rhs_types): | ||
_registry.register_comparator( | ||
lhs_type, rhs_type, sym_struct_eq, sym_struct_ne) | ||
|
||
|
||
# Globals. | ||
_registry = _Registry() | ||
_register_autodiff() | ||
_register_symbolic() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.