Skip to content

Commit

Permalink
Merge pull request numpy#6686 from pv/assert-fix
Browse files Browse the repository at this point in the history
BUG: testing: fix a bug in assert_string_equal
  • Loading branch information
charris committed Nov 14, 2015
2 parents f83d68b + eadc135 commit 8ae543c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
19 changes: 18 additions & 1 deletion numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
assert_array_almost_equal, build_err_msg, raises, assert_raises,
assert_warns, assert_no_warnings, assert_allclose, assert_approx_equal,
assert_array_almost_equal_nulp, assert_array_max_ulp,
clear_and_catch_warnings, run_module_suite
clear_and_catch_warnings, run_module_suite,
assert_string_equal
)
import unittest

Expand Down Expand Up @@ -715,6 +716,22 @@ def test_nan(self):
lambda: assert_array_max_ulp(nan, nzero,
maxulp=maxulp))

class TestStringEqual(unittest.TestCase):
def test_simple(self):
assert_string_equal("hello", "hello")
assert_string_equal("hello\nmultiline", "hello\nmultiline")

try:
assert_string_equal("foo\nbar", "hello\nbar")
except AssertionError as exc:
assert_equal(str(exc), "Differences in strings:\n- foo\n+ hello")
else:
raise AssertionError("exception not raised")

self.assertRaises(AssertionError,
lambda: assert_string_equal("foo", "hello"))


def assert_warn_len_equal(mod, n_in_context):
mod_warns = mod.__warningregistry__
# Python 3.4 appears to clear any pre-existing warnings of the same type,
Expand Down
11 changes: 6 additions & 5 deletions numpy/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,11 +1018,12 @@ def assert_string_equal(actual, desired):
if not d2.startswith('+ '):
raise AssertionError(repr(d2))
l.append(d2)
d3 = diff.pop(0)
if d3.startswith('? '):
l.append(d3)
else:
diff.insert(0, d3)
if diff:
d3 = diff.pop(0)
if d3.startswith('? '):
l.append(d3)
else:
diff.insert(0, d3)
if re.match(r'\A'+d2[2:]+r'\Z', d1[2:]):
continue
diff_list.extend(l)
Expand Down

0 comments on commit 8ae543c

Please sign in to comment.