8000 Merge pull request #6553 from yashmehrotra/partition-fix · numpy/numpy@c0e48cf · GitHub
[go: up one dir, main page]

Skip to content

Commit c0e48cf

Browse files
committed
Merge pull request #6553 from yashmehrotra/partition-fix
BUG: Fix partition and argpartition error for empty input. Closes #6530
2 parents 522a0f7 + 4d9bf8a commit c0e48cf

File tree

3 files changed

+28
-2< 8000 /div>lines changed

3 files changed

+28
-2
lines changed

numpy/core/src/multiarray/item_selection.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
809809
PyArrayIterObject *it;
810810
npy_intp size;
811811

812-
int ret = -1;
812+
int ret = 0;
813813

814814
NPY_BEGIN_THREADS_DEF;
815815

@@ -829,6 +829,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
829829
if (needcopy) {
830830
buffer = PyDataMem_NEW(N * elsize);
831831
if (buffer == NULL) {
832+
ret = -1;
832833
goto fail;
833834
}
834835
}
@@ -947,7 +948,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
947948
PyArrayIterObject *it, *rit;
948949
npy_intp size;
949950

950-
int ret = -1;
951+
int ret = 0;
951952

952953
NPY_BEGIN_THREADS_DEF;
953954

@@ -969,6 +970,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
969970
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
970971
rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)rop, &axis);
971972
if (it == NULL || rit == NULL) {
973+
ret = -1;
972974
goto fail;
973975
}
974976
size = it->size;
@@ -978,13 +980,15 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
978980
if (needcopy) {
979981
valbuffer = PyDataMem_NEW(N * elsize);
980982
if (valbuffer == NULL) {
983+
ret = -1;
981984
goto fail;
982985
}
983986
}
984987

985988
if (needidxbuffer) {
986989
idxbuffer = (npy_intp *)PyDataMem_NEW(N * sizeof(npy_intp));
987990
if (idxbuffer == NULL) {
991+
ret = -1;
988992
goto fail;
989993
}
990994
}

numpy/core/tests/test_item_selection.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,24 @@ def test_unicode_mode(self):
6868
k = b'\xc3\xa4'.decode("UTF8")
6969
assert_raises(ValueError, d.take, 5, mode=k)
7070

71+
def test_empty_partition(self):
72+
# In reference to github issue #6530
73+
a_original = np.array([0, 2, 4, 6, 8, 10])
74+
a = a_original.copy()
75+
76+
# An empty partition should be a successful no-op
77+
a.partition(np.array([], dtype=np.int16))
78+
79+
assert_array_equal(a, a_original)
80+
81+
def test_empty_argpartition(self):
82+
# In reference to github issue #6530
83+
a = np.array([0, 2, 4, 6, 8, 10])
84+
a = a.argpartition(np.array([], dtype=np.int16))
85+
86+
b = np.array([0, 1, 2, 3, 4, 5])
87+
assert_array_equal(a, b)
88+
7189

7290
if __name__ == "__main__":
7391
run_module_suite()

numpy/core/tests/test_regression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,5 +2173,9 @@ def test_leak_in_structured_dtype_comparison(self):
21732173
after = sys.getrefcount(a)
21742174
assert_equal(before, after)
21752175

2176+
def test_empty_percentile(self):
2177+
# gh-6530 / gh-6553
2178+
assert_array_equal(np.percentile(np.arange(10), []), np.array([]))
2179+
21762180
if __name__ == "__main__":
21772181
run_module_suite()

0 commit comments

Comments
 (0)
0