8000 Merge pull request #28420 from jakirkham/backport_treddy_issue_28354_… · numpy/numpy@a9d3796 · GitHub
[go: up one dir, main page]

Skip to content

Commit a9d3796

Browse files
authored
Merge pull request #28420 from jakirkham/backport_treddy_issue_28354_2.2.x
BUG: safer bincount casting (backport to 2.2.x)
2 parents fc594d4 + 1efec00 commit a9d3796

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

numpy/_core/src/multiarray/compiled_base.c

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,12 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
150150
}
151151
if (PyArray_SIZE(tmp1) > 0) {
152152
/* The input is not empty, so convert it to NPY_INTP. */
153-
lst = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)tmp1,
154-
NPY_INTP, 1, 1);
153+
int flags = NPY_ARRAY_WRITEABLE | NPY_ARRAY_ALIGNED | NPY_ARRAY_C_CONTIGUOUS;
154+
if (PyArray_ISINTEGER(tmp1)) {
155+
flags = flags | NPY_ARRAY_FORCECAST;
156+
}
157+
PyArray_Descr* local_dtype = PyArray_DescrFromType(NPY_INTP);
158+
lst = (PyArrayObject *)PyArray_FromAny((PyObject *)tmp1, local_dtype, 1, 1, flags, NULL);
155159
Py_DECREF(tmp1);
156160
if (lst == NULL) {
157161
/* Failed converting to NPY_INTP. */
@@ -177,7 +181,13 @@ arr_bincount(PyObject *NPY_UNUSED(self), PyObject *const *args,
177181
}
178182

179183
if (lst == NULL) {
180-
lst = (PyArrayObject *)PyArray_ContiguousFromAny(list, NPY_INTP, 1, 1);
184+
int flags = NPY_ARRAY_WRITEABLE | NPY_ARRAY_ALIGNED | NPY_ARRAY_C_CONTIGUOUS;
185+
if (PyArray_Check((PyObject *)list) &&
186+
PyArray_ISINTEGER((PyArrayObject *)list)) {
187+
flags = flags | NPY_ARRAY_FORCECAST;
188+
}
189+
PyArray_Descr* local_dtype = PyArray_DescrFromType(NPY_INTP);
190+
lst = (PyArrayObject *)PyArray_FromAny(list, local_dtype, 1, 1, flags, NULL);
181191
if (lst == NULL) {
182192
goto fail;
183193
}

numpy/lib/tests/test_function_base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2925,6 +2925,27 @@ def test_error_not_1d(self, vals):
29252925
with assert_raises(ValueError):
29262926
np.bincount(vals)
29272927

2928+
@pytest.mark.parametrize("dt", np.typecodes["AllInteger"])
2929+
def test_gh_28354(self, dt):
2930+
a = np.array([0, 1, 1, 3, 2, 1, 7], dtype=dt)
2931+
actual = np.bincount(a)
2932+
expected = [1, 3, 1, 1, 0, 0, 0, 1]
2933+
assert_array_equal(actual, expected)
2934+
2935+
def test_contiguous_handling(self):
2936+
# check for absence of hard crash
2937+
np.bincount(np.arange(10000)[::2])
2938+
2939+
def test_gh_28354_array_like(self):
2940+
class A:
2941+
def __array__(self):
2942+
return np.array([0, 1, 1, 3, 2, 1, 7], dtype=np.uint64)
2943+
2944+
a = A()
2945+
actual = np.bincount(a)
2946+
expected = [1, 3, 1, 1, 0, 0, 0, 1]
2947+
assert_array_equal(actual, expected)
2948+
29282949

29292950
class TestInterp:
29302951

0 commit comments

Comments
 (0)
0