8000 MAINT: ndarray.__repr__ should not rely on __array_function__ · shoyer/numpy@d991702 · GitHub
[go: up one dir, main page]

Skip to content

Commit d991702

Browse files
committed
MAINT: ndarray.__repr__ should not rely on __array_function__
``ndarray.__repr__`` and ``ndarray.__str__`` should not rely upon ``__array_function__`` internally, so they are still well defined on subclasses even if ``array_repr`` and ``array_str`` are not implemented. Fixes numpygh-12162
1 parent 2c4c93a commit d991702

File tree

2 files changed

+91
-53
lines changed

2 files changed

+91
-53
lines changed

numpy/core/arrayprint.py

Lines changed: 76 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1381,6 +1381,53 @@ def dtype_short_repr(dtype):
13811381
return typename
13821382

13831383

1384+
def _array_repr_implementation(
1385+
arr, max_line_width=None, precision=None, suppress_small=None,
1386+
array2string=array2string):
1387+
"""Internal version of array_repr() that allows overriding array2string."""
1388+
if max_line_width is None:
1389+
max_line_width = _format_options['linewidth']
1390+
1391+
if type(arr) is not ndarray:
1392+
class_name = type(arr).__name__
1393+
else:
1394+
class_name = "array"
1395+
1396+
skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0
1397+
1398+
prefix = class_name + "("
1399+
suffix = ")" if skipdtype else ","
1400+
1401+
if (_format_options['legacy'] == '1.13' and
1402+
arr.shape == () and not arr.dtype.names):
1403+
lst = repr(arr.item())
1404+
elif arr.size > 0 or arr.shape == (0,):
1405+
lst = array2string(arr, max_line_width, precision, suppress_small,
1406+
', ', prefix, suffix=suffix)
1407+
else: # show zero-length shape unless it is (0,)
1408+
lst = "[], shape=%s" % (repr(arr.shape),)
1409+
1410+
arr_str = prefix + lst + suffix
1411+
1412+
if skipdtype:
1413+
return arr_str
1414+
1415+
dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype))
1416+
1417+
# compute whether we should put dtype on a new line: Do so if adding the
1418+
# dtype would extend the last line past max_line_width.
< 8000 code>1419+
# Note: This line gives the correct result even when rfind returns -1.
1420+
last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
1421+
spacer = " "
1422+
if _format_options['legacy'] == '1.13':
1423+
if issubclass(arr.dtype.type, flexible):
1424+
spacer = '\n' + ' '*len(class_name + "(")
1425+
elif last_line_len + len(dtype_str) + 1 > max_line_width:
1426+
spacer = '\n' + ' '*len(class_name + "(")
1427+
1428+
return arr_str + spacer + dtype_str
1429+
1430+
13841431
def _array_repr_dispatcher(
13851432
arr, max_line_width=None, precision=None, suppress_small=None):
13861433
return (arr,)
@@ -1429,50 +1476,31 @@ def array_repr(arr, max_line_width=None, precision=None, suppress_small=None):
14291476
'array([ 0.000001, 0. , 2. , 3. ])'
14301477
14311478
"""
1432-
if max_line_width is None:
1433-
max_line_width = _format_options['linewidth']
1479+
return _array_repr_implementation(
1480+
arr, max_line_width, precision, suppress_small)
14341481

1435-
if type(arr) is not ndarray:
1436-
class_name = type(arr).__name__
1437-
else:
1438-
class_name = "array"
14391482

1440-
skipdtype = dtype_is_implied(arr.dtype) and arr.size > 0
1483+
_guarded_str = _recursive_guard()(str)
14411484

1442-
prefix = class_name + "("
1443-
suffix = ")" if skipdtype else ","
14441485

1486+
def _array_str_implementation(
1487+
a, max_line_width=None, precision=None, suppress_small=None,
1488+
array2string=array2string):
1489+
"""Internal version of array_str() that allows overriding array2string."""
14451490
if (_format_options['legacy'] == '1.13' and
1446-
arr.shape == () and not arr.dtype.names):
1447-
lst = repr(arr.item())
1448-
elif arr.size > 0 or arr.shape == (0,):
1449-
lst = array2string(arr, max_line_width, precision, suppress_small,
1450-
', ', prefix, suffix=suffix)
1451-
else: # show zero-length shape unless it is (0,)
1452-
lst = "[], shape=%s" % (repr(arr.shape),)
1453-
1454-
arr_str = prefix + lst + suffix
1455-
1456-
if skipdtype:
1457-
return arr_str
1458-
1459-
dtype_str = "dtype={})".format(dtype_short_repr(arr.dtype))
1460-
1461-
# compute whether we should put dtype on a new line: Do so if adding the
1462-
# dtype would extend the last line past max_line_width.
1463-
# Note: This line gives the correct result even when rfind returns -1.
1464-
last_line_len = len(arr_str) - (arr_str.rfind('\n') + 1)
1465-
spacer = " "
1466-
if _format_options['legacy'] == '1.13':
1467-
if issubclass(arr.dtype.type, flexible):
1468-
spacer = '\n' + ' '*len(class_name + "(")
1469-
elif last_line_len + len(dtype_str) + 1 > max_line_width:
1470-
spacer = '\n' + ' '*len(class_name + "(")
1471-
1472-
return arr_str + spacer + dtype_str
1491+
a.shape == () and not a.dtype.names):
1492+
return str(a.item())
14731493

1494+
# the str of 0d arrays is a special case: It should appear like a scalar,
1495+
# so floats are not truncated by `precision`, and strings are not wrapped
1496+
# in quotes. So we return the str of the scalar value.
1497+
if a.shape == ():
1498+
# obtain a scalar and call str on it, avoiding problems for subclasses
1499+
# for which indexing with () returns a 0d instead of a scalar by using
1500+
# ndarray's getindex. Also guard against recursive 0d object arrays.
1501+
return _guarded_str(np.ndarray.__getitem__(a, ()))
14741502

1475-
_guarded_str = _recursive_guard()(str)
1503+
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
14761504

14771505

14781506
def _array_str_dispatcher(
@@ -1515,20 +1543,15 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None):
15151543
'[0 1 2]'
15161544
15171545
"""
1518-
if (_format_options['legacy'] == '1.13' and
1519-
a.shape == () and not a.dtype.names):
1520-
return str(a.item())
1546+
return _array_str_implementation(
1547+
a, max_line_width, precision, suppress_small)
15211548

1522-
# the str of 0d arrays is a special case: It should appear like a scalar,
1523-
# so floats are not truncated by `precision`, and strings are not wrapped
1524-
# in quotes. So we return the str of the scalar value.
1525-
if a.shape == ():
1526-
# obtain a scalar and call str on it, avoiding problems for subclasses
1527-
# for which indexing with () returns a 0d instead of a scalar by using
1528-
# ndarray's getindex. Also guard against recursive 0d object arrays.
1529-
return _guarded_str(np.ndarray.__getitem__(a, ()))
15301549

1531-
return array2string(a, max_line_width, precision, suppress_small, ' ', "")
1550+
_default_array_str = functools.partial(_array_str_implementation,
1551+
array2string=array2string.__wrapped__)
10000 1552+
_default_array_repr = functools.partial(_array_repr_implementation,
1553+
array2string=array2string.__wrapped__)
1554+
15321555

15331556
def set_string_function(f, repr=True):
15341557
"""
@@ -1583,11 +1606,11 @@ def set_string_function(f, repr=True):
15831606
"""
15841607
if f is None:
15851608
if repr:
1586-
return multiarray.set_string_function(array_repr, 1)
1609+
return multiarray.set_string_function(_default_array_repr, 1)
15871610
else:
1588-
return multiarray.set_string_function(array_str, 0)
1611+
return multiarray.set_string_function(_default_array_str, 0)
15891612
else:
15901613
return multiarray.set_string_function(f, repr)
15911614

1592-
set_string_function(array_str, 0)
1593-
set_string_function(array_repr, 1)
1615+
set_string_function(_default_array_str, 0)
1616+
set_string_function(_default_array_repr, 1)

numpy/core/tests/test_overrides.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,18 @@ def func(array):
304304

305305
with assert_raises_regex(TypeError, 'no implementation found'):
306306
func(MyArray())
307+
308+
309+
class TestNDArrayMethods(object):
310+
311+
def test_repr(self):
312+
# gh-12162: should still be defined even if __array_function__ doesn't
313+
# implement np.array_repr()
314+
315+
class MyArray(np.ndarray):
316+
def __array_function__(*args, **kwargs):
317+
return NotImplemented
318+
319+
array = np.array(1).view(MyArray)
320+
assert_equal(repr(array), 'MyArray(1)')
321+
assert_equal(str(array), '1')

0 commit comments

Comments
 (0)
0