8000 BUG,ENH: Fix internal ``__array_wrap__`` for direct calls (#27807) · numpy/numpy@16b210c · GitHub
[go: up one dir, main page]

Skip to content

Commit 16b210c

Browse files
authored
BUG,ENH: Fix internal __array_wrap__ for direct calls (#27807)
* BUG,ENH: Fix internal ``__array_wrap__`` for direct calls Since adding `return_scalar` as am argument, the array-wrap implementations were slightly wrong when that argument was actually passed and the function called directly. NumPy itself rarely (or never) did so for our builtin types now so that was not a problem within NumPy. Further, the scalar version was completely broken, converting to scalar even when such a conversion was impossible. As explained in the code. For array subclasses we NEVER want to convert to scalar by default. The subclass must make that choice explicitly. (There are plenty of tests for this behavior.) * BUG: Ensure cast to self in ndarray.__array_wrap__ and other review fixes
1 parent 4f589ff commit 16b210c

File tree

6 files changed

+112
-29
lines changed

6 files changed

+112
-29
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
* Calling ``__array_wrap__`` directly on NumPy arrays or scalars
2+
now does the right thing when ``return_scalar`` is passed
3+
(Added in NumPy 2). It is further safe now to call the scalar
4+
``__array_wrap__`` on a non-scalar result.

numpy/_core/src/multiarray/methods.c

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -888,28 +888,39 @@ array_finalizearray(PyArrayObject *self, PyObject *obj)
888888
}
889889

890890

891+
/*
892+
* Default `__array_wrap__` implementation.
893+
*
894+
* If `self` is not a base class, we always create a new view, even if
895+
* `return_scalar` is set. This way we preserve the (presumably important)
896+
* subclass information.
897+
* If the type is a base class array, we honor `return_scalar` and call
898+
* PyArray_Return to convert any array with ndim=0 to scalar.
899+
*
900+
* By default, do not return a scalar (because this was always the default).
901+
*/
891902
static PyObject *
892903
array_wraparray(PyArrayObject *self, PyObject *args)
893904
{
894905
PyArrayObject *arr;
895-
PyObject *obj;
906+
PyObject *UNUSED = NULL; /* for the context argument */
907+
int return_scalar = 0;
896908

897-
if (PyTuple_Size(args) < 1) {
898-
PyErr_SetString(PyExc_TypeError,
899-
"only accepts 1 argument");
900-
return NULL;
901-
}
902-
obj = PyTuple_GET_ITEM(args, 0);
903-
if (obj == NULL) {
909+
if (!PyArg_ParseTuple(args, "O!|OO&:__array_wrap__",
910+
&PyArray_Type, &arr, &UNUSED,
911+
&PyArray_OptionalBoolConverter, &return_scalar)) {
904912
return NULL;
905913
}
906-
if (!PyArray_Check(obj)) {
907-
PyErr_SetString(PyExc_TypeError,
908-
"can only be called with ndarray object");
909-
return NULL;
914+
915+
if (return_scalar && Py_TYPE(self) == &PyArray_Type && PyArray_NDIM(arr) == 0) {
916+
/* Strict scalar return here (but go via PyArray_Return anyway) */
917+
Py_INCREF(arr);
918+
return PyArray_Return(arr);
910919
}
911-
arr = (PyArrayObject *)obj;
912920

921+
/*
922+
* Return an array, but should ensure it has the type of self
923+
*/
913924
if (Py_TYPE(self) != Py_TYPE(arr)) {
914925
PyArray_Descr *dtype = PyArray_DESCR(arr);
915926
Py_INCREF(dtype);
@@ -919,7 +930,7 @@ array_wraparray(PyArrayObject *self, PyObject *args)
919930
PyArray_NDIM(arr),
920931
PyArray_DIMS(arr),
921932
PyArray_STRIDES(arr), PyArray_DATA(arr),
922-
PyArray_FLAGS(arr), (PyObject *)self, obj);
933+
PyArray_FLAGS(arr), (PyObject *)self, (PyObject *)arr);
923934
}
924935
else {
925936
/*

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "numpyos.h"
2626
#include "can_cast_table.h"
2727
#include "common.h"
28+
#include "conversion_utils.h"
2829
#include "flagsobject.h"
2930
#include "scalartypes.h"
3031
#include "_datetime.h"
@@ -2035,29 +2036,33 @@ gentype_getarray(PyObject *scalar, PyObject *args)
20352036
return ret;
20362037
}
20372038

2038-
static char doc_sc_wraparray[] = "sc.__array_wrap__(obj) return scalar from array";
2039+
static char doc_sc_wraparray[] = "__array_wrap__ implementation for scalar types";
20392040

2041+
/*
2042+
* __array_wrap__ for scalars, returning a scalar if possible.
2043+
* (note that NumPy itself may well never call this itself).
2044+
*/
20402045
static PyObject *
20412046
gentype_wraparray(PyObject *NPY_UNUSED(scalar), PyObject *args)
20422047
{
2043-
PyObject *obj;
20442048
PyArrayObject *arr;
2049+
PyObject *UNUSED = NULL; /* for the context argument */
2050+
/* return_scalar should be passed, but we're scalar, so return scalar by default */
2051+
int return_scalar = 1;
20452052

2046-
if (PyTuple_Size(args) < 1) {
2047-
PyErr_SetString(PyExc_TypeError,
2048-
"only accepts 1 argument.");
2053+
if (!PyArg_ParseTuple(args, "O!|OO&:__array_wrap__",
2054+
&PyArray_Type, &arr, &UNUSED,
2055+
&PyArray_OptionalBoolConverter, &return_scalar)) {
20492056
return NULL;
20502057
}
2051-
obj = PyTuple_GET_ITEM(args, 0);
2052-
if (!PyArray_Check(obj)) {
2053-
PyErr_SetString(PyExc_TypeError,
2054-
"can only be called with ndarray object");
2055-
return NULL;
2056-
}
2057-
arr = (PyArrayObject *)obj;
20582058

2059-
return PyArray_Scalar(PyArray_DATA(arr),
2060-
PyArray_DESCR(arr), (PyObject *)arr);
2059+
Py_INCREF(arr);
2060+
if (!return_scalar) {
2061+
return (PyObject *)arr;
2062+
}
2063+
else {
2064+
return PyArray_Return(arr);
2065+
}
20612066
}
20622067

20632068
/*

numpy/_core/tests/test_arrayobject.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,45 @@ def test_matrix_transpose_equals_swapaxes(shape):
3131
tgt = np.swapaxes(arr, num_of_axes - 2, num_of_axes - 1)
3232
mT = arr.mT
3333
assert_array_equal(tgt, mT)
34+
35+
36+
class MyArr(np.ndarray):
37+
def __array_wrap__(self, arr, context=None, return_scalar=None):
38+
return super().__array_wrap__(arr, context, return_scalar)
39+
40+
41+
class MyArrNoWrap(np.ndarray):
42+
pass
43+
44+
45+
@pytest.mark.parametrize("subclass_self", [np.ndarray, MyArr, MyArrNoWrap])
46+
@pytest.mark.parametrize("subclass_arr", [np.ndarray, MyArr, MyArrNoWrap])
47+
def test_array_wrap(subclass_self, subclass_arr):
48+
# NumPy should allow `__array_wrap__` to be called on arrays, it's logic
49+
# is designed in a way that:
50+
#
51+
# * Subclasses never return scalars by default (to preserve their
52+
# information). They can choose to if they wish.
53+
# * NumPy returns scalars, if `return_scalar` is passed as True to allow
54+
# manual calls to `arr.__array_wrap__` to do the right thing.
55+
# * The type of the input should be ignored (it should be a base-class
56+
# array, but I am not sure this is guaranteed).
57+
58+
arr = np.arange(3).view(subclass_self)
59+
60+
arr0d = np.array(3, dtype=np.int8).view(subclass_arr)
61+
# With third argument True, ndarray allows "decay" to scalar.
62+
# (I don't think NumPy would pass `None`, but it seems clear to support)
63+
if subclass_self is np.ndarray:
64+
assert type(arr.__array_wrap__(arr0d, None, True)) is np.int8
65+
else:
66+
assert type(arr.__array_wrap__(arr0d, None, True)) is type(arr)
67+
68+
# Otherwise, result should be viewed as the subclass
69+
assert type(arr.__array_wrap__(arr0d)) is type(arr)
70+
assert type(arr.__array_wrap__(arr0d, None, None)) is type(arr)
71+
assert type(arr.__array_wrap__(arr0d, None, False)) is type(arr)
72+
73+
# Non 0-D array can't be converted to scalar, so we ignore that
74+
arr1d = np.array([3], dtype=np.int8).view(subclass_arr)
75+
assert type(arr.__array_wrap__(arr1d, None, True)) is type(arr)

numpy/_core/tests/test_multiarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9807,7 +9807,7 @@ class MyArr(np.ndarray):
98079807

98089808
def __array_wrap__(self, new, context=None, return_scalar=False):
98099809
type(self).called_wrap += 1
9810-
return super().__array_wrap__(new)
9810+
return super().__array_wrap__(new, context, return_scalar)
98119811

98129812
numpy_arr = np.zeros(5, dtype=dt1)
98139813
my_arr = np.zeros(5, dtype=dt2).view(MyArr)

numpy/_core/tests/test_scalar_methods.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,24 @@ def test_to_device(self, scalar):
223223
@pytest.mark.parametrize("scalar", scalars)
224224
def test___array_namespace__(self, scalar):
225225
assert scalar.__array_namespace__() is np
226+
227+
228+
@pytest.mark.parametrize("scalar", [np.bool(True), np.int8(1), np.float64(1)])
229+
def test_array_wrap(scalar):
230+
# Test scalars array wrap as long as it exists. NumPy itself should
231+
# probably not use it, so it may not be necessary to keep it around.
232+
233+
arr0d = np.array(3, dtype=np.int8)
234+
# Third argument not passed, None, or True "decays" to scalar.
235+
# (I don't think NumPy would pass `None`, but it seems clear to support)
236+
assert type(scalar.__array_wrap__(arr0d)) is np.int8
237+
assert type(scalar.__array_wrap__(arr0d, None, None)) is np.int8
238+
assert type(scalar.__array_wrap__(arr0d, None, True)) is np.int8
239+
240+
# Otherwise, result should be the input
241+
assert scalar.__array_wrap__(arr0d, None, False) is arr0d
242+
243+
# An old bug. A non 0-d array cannot be converted to scalar:
244+
arr1d = np.array([3], dtype=np.int8)
245+
assert scalar.__array_wrap__(arr1d) is arr1d
246+
assert scalar.__array_wrap__(arr1d, None, True) is arr1d

0 commit comments

Comments
 (0)
0