8000 Merge pull request #27343 from lysnikolaou/flatindex-on-flatiter · ngoldbaum/numpy@760dbe9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 760dbe9

Browse files
authored
Merge pull request numpy#27343 from lysnikolaou/flatindex-on-flatiter
2 parents 4cffb9d + 5838393 commit 760dbe9

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

numpy/_core/src/multiarray/iterators.c

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,23 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
694694
obj = ind;
695695
}
696696

697-
/* Any remaining valid input is an array or has been turned into one */
698697
if (!PyArray_Check(obj)) {
699-
goto fail;
698+
PyArrayObject *tmp_arr = (PyArrayObject *) PyArray_FROM_O(obj);
699+
if (tmp_arr == NULL) {
700+
goto fail;
701+
}
702+
703+
if (PyArray_SIZE(tmp_arr) == 0) {
704+
PyArray_Descr *indtype = PyArray_DescrFromType(NPY_INTP);
705+
Py_SETREF(obj, PyArray_FromArray(tmp_arr, indtype, NPY_ARRAY_FORCECAST));
706+
Py_DECREF(tmp_arr);
707+
if (obj == NULL) {
708+
goto fail;
709+
}
710+
}
711+
else {
712+
Py_SETREF(obj, (PyObject *) tmp_arr);
713+
}
700714
}
701715

702716
/* Check for Boolean array */

numpy/_core/tests/test_indexing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,22 @@ def test_nontuple_ndindex(self):
622622
assert_equal(a[[0, 1], [0, 1]], np.array([0, 6]))
623623
assert_raises(IndexError, a.__getitem__, [slice(None)])
624624

625+
def test_flat_index_on_flatiter(self):
626+
a = np.arange(9).reshape((3, 3))
627+
b = np.array([0, 5, 6])
628+
assert_equal(a.flat[b.flat], np.array([0, 5, 6]))
629+
630+
def test_empty_string_flat_index_on_flatiter(self):
631+
a = np.arange(9).reshape((3, 3))
632+
b = np.array([], dtype="S")
633+
assert_equal(a.flat[b.flat], np.array([]))
634+
635+
def test_nonempty_string_flat_index_on_flatiter(self):
636+
a = np.arange(9).reshape((3, 3))
637+
b = np.array(["a"], dtype="S")
638+
with pytest.raises(IndexError, match="unsupported iterator index"):
639+
a.flat[b.flat]
640+
625641

626642
class TestFieldIndexing:
627643
def test_scalar_return_type(self):

numpy/typing/tests/data/pass/flatiter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@
1414
a[:]
1515
a.__array__()
1616
a.__array__(np.dtype(np.float64))
17+
18+
b = np.array([1]).flat
19+
a[b]

0 commit comments

Comments
 (0)
0