8000 BUG: infinite recursion in str of 0d subclasses · hanjohn/numpy@f8ccee9 · GitHub
[go: up one dir, main page]

Skip to content

Commit f8ccee9

Browse files
ahaldanehanjohn
authored andcommitted
BUG: infinite recursion in str of 0d subclasses
Fixes numpy#10360
1 parent b5f6767 commit f8ccee9

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

numpy/core/arrayprint.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,17 @@ def wrapper(self, *args, **kwargs):
468468
# gracefully handle recursive calls, when object arrays contain themselves
469469
@_recursive_guard()
470470
def _array2string(a, options, separator=' ', prefix=""):
471-
# The formatter __init__s cannot deal with subclasses yet
472-
data = asarray(a)
471+
# The formatter __init__s in _get_format_function cannot deal with
472+
# subclasses yet, and we also need to avoid recursion issues in
473+
# _formatArray with subclasses which return 0d arrays in place of scalars
474+
a = asarray(a)
473475

474476
if a.size > options['threshold']:
475477
summary_insert = "..."
476-
data = _leading_trailing(data, options['edgeitems'])
478+
data = _leading_trailing(a, options['edgeitems'])
477479
else:
478480
summary_insert = ""
481+
data = a
479482

480483
# find the right formatting function for the array
481484
format_function = _get_format_function(data, **options)
@@ -501,7 +504,7 @@ def array2string(a, max_line_width=None, precision=None,
501504
502505
Parameters
503506
----------
504-
a : ndarray
507+
a : array_like
505508
Input array.
506509
max_line_width : int, optional
507510
The maximum number of columns the string should span. Newline
@@ -763,7 +766,7 @@ def recurser(index, hanging_indent, curr_width):
763766

764767
if show_summary:
765768
if legacy == '1.13':
766-
# trailing space, fixed number of newlines, and fixed separator
769+
# trailing space, fixed nbr of newlines, and fixed separator
767770
s += hanging_indent + summary_insert + ", \n"
768771
else:
769772
s += hanging_indent + summary_insert + line_sep
@@ -1413,6 +1416,8 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
14131416

14141417
return arr_str + spacer + dtype_str
14151418

1419+
_guarded_str = _recursive_guard()(str)
1420+
14161421
def array_str(a, max_line_width=None, precision=None, suppress_small=None):
14171422
"""
14181423
Return a string representation of the data in an array.
@@ -1455,7 +1460,10 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
14551460
# so floats are not truncated by `precision`, and strings are not wrapped
14561461
# in quotes. So we return the str of the scalar value.
14571462
if a.shape == ():
1458-
return str(a[()])
1463+
# obtain a scalar and call str on it, avoiding problems for subclasses
1464+
# for which indexing with () returns a 0d instead of a scalar by using
1465+
# ndarray's getindex. Also guard against recursive 0d object arrays.
1466+
return _guarded_str(np.ndarray.__getitem__(a, ()))
14591467

14601468
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
14611469

numpy/core/tests/test_arrayprint.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,55 @@ class sub(np.ndarray): pass
3434
" [(1,), (1,)]], dtype=[('a', '<i4')])"
3535
)
3636

37+
def test_0d_object_subclass(self):
38+
# make sure that subclasses which return 0ds instead
39+
# of scalars don't cause infinite recursion in str
40+
class sub(np.ndarray):
41+
def __new__(cls, inp):
42+
obj = np.asarray(inp).view(cls)
43+
return obj
44+
45+
def __getitem__(self, ind):
46+
ret = super(sub, self).__getitem__(ind)
47+
return sub(ret)
48+
49+
x = sub(1)
50+
assert_equal(repr(x), 'sub(1)')
51+
assert_equal(str(x), '1')
52+
53+
x = sub([1, 1])
54+
assert_equal(repr(x), 'sub([1, 1])')
55+
assert_equal(str(x), '[1 1]')
56+
57+
# check it works properly with object arrays too
58+
x = sub(None)
59+
assert_equal(repr(x), 'sub(None, dtype=object)')
60+
assert_equal(str(x), 'None')
61+
62+
# plus recursive object arrays (even depth > 1)
63+
y = sub(None)
64+
x[()] = y
65+
y[()] = x
66+
assert_equal(repr(x),
67+
'sub(sub(sub(..., dtype=object), dtype=object), dtype=object)')
68+
assert_equal(str(x), '...')
69+
70+
# nested 0d-subclass-object
71+
x = sub(None)
72+
x[()] = sub(None)
73+
assert_equal(repr(x), 'sub(sub(None, dtype=object), dtype=object)')
74+
assert_equal(str(x), 'None')
75+
76+
# test that object + subclass is OK:
77+
x = sub([None, None])
78+
assert_equal(repr(x), 'sub([None, None], dtype=object)')
79+
assert_equal(str(x), '[None None]')
80+
81+
x = sub([None, sub([None, None])])
82+
assert_equal(repr(x),
83+
'sub([None, sub([None, None], dtype=object)], dtype=object)')
84+
assert_equal(str(x), '[None sub([None, None], dtype=object)]')
85+
3786
def test_self_containing(self):
3887
arr0d = np.array(None)
3988
arr0d[()] = arr0d

0 commit comments

Comments
 (0)
0