8000 Merge pull request #27715 from ngoldbaum/fix-generic-fancy-index-cast · numpy/numpy@da32320 · GitHub
[go: up one dir, main page]

Skip to content

Commit da32320

Browse files
authored
Merge pull request #27715 from ngoldbaum/fix-generic-fancy-index-cast
BUG: fix incorrect output descriptor in fancy indexing
2 parents 060c28a + 6a85517 commit da32320

File tree

3 files changed

+30
-11
lines changed

3 files changed

+30
-11
lines changed

numpy/_core/src/multiarray/mapping.c

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,7 +1668,7 @@ array_subscript(PyArrayObject *self, PyObject *op)
16681668

16691669
if (PyArray_GetDTypeTransferFunction(1,
16701670
itemsize, itemsize,
1671-
PyArray_DESCR(self), PyArray_DESCR(self),
1671+
PyArray_DESCR(self), PyArray_DESCR(mit->extra_op),
16721672
0, &cast_info, &transfer_flags) != NPY_SUCCEED) {
16731673
goto finish;
16741674
}
@@ -2035,7 +2035,6 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
20352035
goto fail;
20362036
}
20372037

2038-
int allocated_array = 0;
20392038
if (tmp_arr == NULL) {
20402039
/* Fill extra op, need to swap first */
20412040
tmp_arr = mit->extra_op;
@@ -2049,7 +2048,11 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
20492048
if (PyArray_CopyObject(tmp_arr, op) < 0) {
20502049
goto fail;
20512050
}
2052-
allocated_array = 1;
2051+
/*
2052+
* In this branch we copy directly from a newly allocated array which
2053+
* may have a new descr:
2054+
*/
2055+
descr = PyArray_DESCR(tmp_arr);
20532056
}
20542057

20552058
if (PyArray_MapIterCheckIndices(mit) < 0) {
@@ -2097,8 +2100,7 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
20972100
// for non-REFCHK user DTypes. See gh-27057 for the prior discussion about this.
20982101
if (PyArray_GetDTypeTransferFunction(
20992102
1, itemsize, itemsize,
2100-
allocated_array ? PyArray_DESCR(mit->extra_op) : PyArray_DESCR(self),
2101-
PyArray_DESCR(self),
2103+
descr, PyArray_DESCR(self),
21022104
0, &cast_info, &transfer_flags) != NPY_SUCCEED) {
21032105
goto fail;
21042106
}

numpy/_core/src/multiarray/nditer_constr.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,8 +1315,10 @@ npyiter_check_casting(int nop, PyArrayObject **op,
13151315
printf("\n");
13161316
#endif
13171317
/* If the types aren't equivalent, a cast is necessary */
1318-
if (op[iop] != NULL && !PyArray_EquivTypes(PyArray_DESCR(op[iop]),
1319-
op_dtype[iop])) {
1318+
npy_intp view_offset = NPY_MIN_INTP;
1319+
if (op[iop] != NULL && !(PyArray_SafeCast(
1320+
PyArray_DESCR(op[iop]), op_dtype[iop], &view_offset,
1321+
NPY_NO_CASTING, 1) && view_offset == 0)) {
13201322
/* Check read (op -> temp) casting */
13211323
if ((op_itflags[iop] & NPY_OP_ITFLAG_READ) &&
13221324
!PyArray_CanCastArrayTo(op[iop],

numpy/_core/tests/test_stringdtype.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -495,15 +495,30 @@ def test_fancy_indexing(string_list):
495495
sarr = np.array(string_list, dtype="T")
496496
assert_array_equal(sarr, sarr[np.arange(sarr.shape[0])])
497497

498+
inds = [
499+
[True, True],
500+
[0, 1],
501+
...,
502+
np.array([0, 1], dtype='uint8'),
503+
]
504+
505+
lops = [
506+
['a'*25, 'b'*25],
507+
['', ''],
508+
['hello', 'world'],
509+
['hello', 'world'*25],
510+
]
511+
498512
# see gh-27003 and gh-27053
499-
for ind in [[True, True], [0, 1], ...]:
500-
for lop in [['a'*16, 'b'*16], ['', '']]:
513+
for ind in inds:
514+
for lop in lops:
501515
a = np.array(lop, dtype="T")
502-
rop = ['d'*16, 'e'*16]
516+
assert_array_equal(a[ind], a)
517+
rop = ['d'*25, 'e'*25]
503518
for b in [rop, np.array(rop, dtype="T")]:
504519
a[ind] = b
505520
assert_array_equal(a, b)
506-
assert a[0] == 'd'*16
521+
assert a[0] == 'd'*25
507522

508523

509524
def test_creation_functions():

0 commit comments

Comments
 (0)
0