Skip to content

Commit

Permalink
BUG: infinite recursion in str of 0d subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
ahaldane committed Feb 11, 2018
1 parent c768b9d commit ff12de3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
20 changes: 14 additions & 6 deletions numpy/core/arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,14 +468,17 @@ def wrapper(self, *args, **kwargs):
# gracefully handle recursive calls, when object arrays contain themselves
@_recursive_guard()
def _array2string(a, options, separator=' ', prefix=""):
# The formatter __init__s cannot deal with subclasses yet
data = asarray(a)
# The formatter __init__s in _get_format_function cannot deal with
# subclasses yet, and we also need to avoid recursion issues in
# _formatArray with subclasses which return 0d arrays in place of scalars
a = asarray(a)

if a.size > options['threshold']:
summary_insert = "..."
data = _leading_trailing(data, options['edgeitems'])
data = _leading_trailing(a, options['edgeitems'])
else:
summary_insert = ""
data = a

# find the right formatting function for the array
format_function = _get_format_function(data, **options)
Expand All @@ -501,7 +504,7 @@ def array2string(a, max_line_width=None, precision=None,
Parameters
----------
a : ndarray
a : array_like
Input array.
max_line_width : int, optional
The maximum number of columns the string should span. Newline
Expand Down Expand Up @@ -763,7 +766,7 @@ def recurser(index, hanging_indent, curr_width):

if show_summary:
if legacy == '1.13':
# trailing space, fixed number of newlines, and fixed separator
# trailing space, fixed nbr of newlines, and fixed separator
s += hanging_indent + summary_insert + ", \n"
else:
s += hanging_indent + summary_insert + line_sep
Expand Down Expand Up @@ -1413,6 +1416,8 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):

return arr_str + spacer + dtype_str

_guarded_str = _recursive_guard()(str)

def array_str(a, max_line_width=None, precision=None, suppress_small=None):
"""
Return a string representation of the data in an array.
Expand Down Expand Up @@ -1455,7 +1460,10 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
# so floats are not truncated by `precision`, and strings are not wrapped
# in quotes. So we return the str of the scalar value.
if a.shape == ():
return str(a[()])
# obtain a scalar and call str on it, avoiding problems for subclasses
# for which indexing with () returns a 0d instead of a scalar by using
# ndarray's getindex. Also guard against recursive 0d object arrays.
return _guarded_str(np.ndarray.__getitem__(a, ()))

return array2string(a, max_line_width, precision, suppress_small, ' ', "")

Expand Down
49 changes: 49 additions & 0 deletions numpy/core/tests/test_arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,55 @@ class sub(np.ndarray): pass
" [(1,), (1,)]], dtype=[('a', '<i4')])"
)

def test_0d_object_subclass(self):
# make sure that subclasses which return 0ds instead
# of scalars don't cause infinite recursion in str
class sub(np.ndarray):
def __new__(cls, inp):
obj = np.asarray(inp).view(cls)
return obj

def __getitem__(self, ind):
ret = super(sub, self).__getitem__(ind)
return sub(ret)

x = sub(1)
assert_equal(repr(x), 'sub(1)')
assert_equal(str(x), '1')

x = sub([1, 1])
assert_equal(repr(x), 'sub([1, 1])')
assert_equal(str(x), '[1 1]')

# check it works properly with object arrays too
x = sub(None)
assert_equal(repr(x), 'sub(None, dtype=object)')
assert_equal(str(x), 'None')

# plus recursive object arrays (even depth > 1)
y = sub(None)
x[()] = y
y[()] = x
assert_equal(repr(x),
'sub(sub(sub(..., dtype=object), dtype=object), dtype=object)')
assert_equal(str(x), '...')

# nested 0d-subclass-object
x = sub(None)
x[()] = sub(None)
assert_equal(repr(x), 'sub(sub(None, dtype=object), dtype=object)')
assert_equal(str(x), 'None')

# test that object + subclass is OK:
x = sub([None, None])
assert_equal(repr(x), 'sub([None, None], dtype=object)')
assert_equal(str(x), '[None None]')

x = sub([None, sub([None, None])])
assert_equal(repr(x),
'sub([None, sub([None, None], dtype=object)], dtype=object)')
assert_equal(str(x), '[None sub([None, None], dtype=object)]')

def test_self_containing(self):
arr0d = np.array(None)
arr0d[()] = arr0d
Expand Down

0 comments on commit ff12de3

Please sign in to comment.