8000 BUG: Convert non-array rhs for boolean assignment with correct dtype · seberg/numpy@8362e08 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8362e08

Browse files
committed
BUG: Convert non-array rhs for boolean assignment with correct dtype
Enforcing the left hand side datatype for a non-array right hand side argument in index assignments was the behavior before 1.7. and is the general behaviour here. (note this means a non-array right hand side checks for NaN, etc. if the left hand side is integer, but an array right hand side does not)
1 parent 00a8c0c commit 8362e08

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

numpy/core/src/multiarray/mapping.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,9 +1306,17 @@ array_ass_sub(PyArrayObject *self, PyObject *ind, PyObject *op)
13061306
PyArrayObject *op_arr;
13071307
PyArray_Descr *dtype = NULL;
13081308

1309-
op_arr = (PyArrayObject *)PyArray_FromAny(op, dtype, 0, 0, 0, NULL);
1310-
if (op_arr == NULL) {
1311-
return -1;
1309+
if (!PyArray_Check(op)) {
1310+
dtype = PyArray_DTYPE(self);
1311+
Py_INCREF(dtype);
1312+
op_arr = (PyArrayObject *)PyArray_FromAny(op, dtype, 0, 0, 0, NULL);
1313+
if (op_arr == NULL) {
1314+
return -1;
1315+
}
1316+
}
1317+
else {
1318+
op_arr = op;
1319+
Py_INCREF(op_arr);
13121320
}
13131321

13141322
if (PyArray_NDIM(op_arr) < 2) {

numpy/core/tests/test_regression.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def bfb(): x[:] = np.arange(3,dtype=float)
304304
self.assertRaises(ValueError, bfa)
305305
self.assertRaises(ValueError, bfb)
306306

307+
def test_nonarray_assignment(self):
308+
# See also Issue gh-2870, test for nonarray assignment
309+
# and equivalent unsafe casted array assignment
310+
a = np.arange(10)
311+
b = np.ones(10, dtype=bool)
312+
def assign(a, b, c):
313+
a[b] = c
314+
assert_raises(ValueError, assign, a, b, np.nan)
315+
a[b] = np.array(np.nan) # but not this.
316+
307317
def test_unpickle_dtype_with_object(self,level=rlevel):
308318
"""Implemented in r2840"""
309319
dt = np.dtype([('x',int),('y',np.object_),('z','O')])

0 commit comments

Comments
 (0)
0