8000 BUG: fix choose refcount leak by charris · Pull Request #24328 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
BUG: fix choose refcount leak
* Fixes #22683

* use `copyswap` to avoid the reference count leaking
reported above when `np.choose` is used with `out`

* my impression from the ticket is that Sebastian doesn't
think `copyswap` is a perfect solution, but may suffice
short-term?
  • Loading branch information
tylerjereddy authored and charris committed Aug 3, 2023
commit ec9c0252db71dce8e738d0e2209c5b72156724e8
8 changes: 5 additions & 3 deletions numpy/core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,8 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
{
PyArrayObject *obj = NULL;
PyArray_Descr *dtype;
int n, elsize;
PyArray_CopySwapFunc *copyswap;
int n, elsize, swap;
npy_intp i;
char *ret_data;
PyArrayObject **mps, *ap;
Expand Down Expand Up @@ -1042,6 +1043,8 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
}
elsize = PyArray_DESCR(obj)->elsize;
ret_data = PyArray_DATA(obj);
copyswap = dtype->f->copyswap;
swap = !PyArray_ISNBO(dtype->byteorder);

while (PyArray_MultiIter_NOTDONE(multi)) {
mi = *((npy_intp *)PyArray_MultiIter_DATA(multi, n));
Expand Down Expand Up @@ -1074,12 +1077,11 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out,
break;
}
}
memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
copyswap(ret_data, PyArray_MultiIter_DATA(multi, mi), swap, NULL);
ret_data += elsize;
PyArray_MultiIter_NEXT(multi);
}

PyArray_INCREF(obj);
Py_DECREF(multi);
for (i = 0; i < n; i++) {
Py_XDECREF(mps[i]);
Expand Down
10 changes: 10 additions & 0 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10031,3 +10031,13 @@ def test_argsort_int(N, dtype):
arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype)
arr[N-1] = maxv
assert_arg_sorted(arr, np.argsort(arr, kind='quick'))


@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_gh_22683():
a = np.ones(10000, dtype=object)
refc_start = sys.getrefcount(1)
np.choose(np.zeros(10000, dtype=int), [a], out=a)
np.choose(np.zeros(10000, dtype=int), [a], out=a)
refc_end = sys.getrefcount(1)
assert refc_end - refc_start < 10
0