8000 BUG: Avoid data race in PyArray_CheckFromAny_int (#28154) · melissawm/numpy@1e10174 · GitHub
[go: up one dir, main page]

Skip to content
forked from numpy/numpy

Commit 1e10174

Browse files
authored
BUG: Avoid data race in PyArray_CheckFromAny_int (numpy#28154)
* BUG: Avoid data race in PyArray_CheckFromAny_int * TST: add test * MAINT: simplify byteswapping code in PyArray_CheckFromAny_int * MAINT: drop ISBYTESWAPPED check
1 parent bbf4836 commit 1e10174

File tree

4 files changed

+29
-13
lines changed

4 files changed

+29
-13
lines changed

numpy/_core/src/multiarray/ctors.c

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,18 +1829,12 @@ PyArray_CheckFromAny_int(PyObject *op, PyArray_Descr *in_descr,
18291829
{
18301830
PyObject *obj;
18311831
if (requires & NPY_ARRAY_NOTSWAPPED) {
1832-
if (!in_descr && PyArray_Check(op) &&
1833-
PyArray_ISBYTESWAPPED((PyArrayObject* )op)) {
1834-
in_descr = PyArray_DescrNew(PyArray_DESCR((PyArrayObject *)op));
1835-
if (in_descr == NULL) {
1836-
return NULL;
1837-
}
1838-
}
1839-
else if (in_descr && !PyArray_ISNBO(in_descr->byteorder)) {
1840-
PyArray_DESCR_REPLACE(in_descr);
1832+
if (!in_descr && PyArray_Check(op)) {
1833+
in_descr = PyArray_DESCR((PyArrayObject *)op);
1834+
Py_INCREF(in_descr);
18411835
}
1842-
if (in_descr && in_descr->byteorder != NPY_IGNORE) {
1843-
in_descr->byteorder = NPY_NATIVE;
1836+
if (in_descr) {
1837+
PyArray_DESCR_REPLACE_CANONICAL(in_descr);
18441838
}
18451839
}
18461840

numpy/_core/src/multiarray/dtypemeta.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ PyArray_SETITEM(PyArrayObject *arr, char *itemptr, PyObject *v)
285285
v, itemptr, arr);
286286
}
287287

288+
// Like PyArray_DESCR_REPLACE, but calls ensure_canonical instead of DescrNew
289+
#define PyArray_DESCR_REPLACE_CANONICAL(descr) do { \
290+
PyArray_Descr *_new_ = NPY_DT_CALL_ensure_canonical(descr); \
291+
Py_XSETREF(descr, _new_); \
292+
} while(0)
288293

289294

290295
#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ */

numpy/_core/tests/test_multithreading.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def closure(b):
134134

135135

136136
def test_parallel_flat_iterator():
137+
# gh-28042
137138
x = np.arange(20).reshape(5, 4).T
138139

139140
def closure(b):
@@ -142,3 +143,15 @@ def closure(b):
142143
list(x.flat)
143144

144145
run_threaded(closure, outer_iterations=100, pass_barrier=True)
146+
147+
# gh-28143
148+
def prepare_args():
149+
return [np.arange(10)]
150+
151+
def closure(x, b):
152+
b.wait()
153+
for _ in range(100):
154+
y = np.arange(10)
155+
y.flat[x] = x
156+
157+
run_threaded(closure, pass_barrier=True, prepare_args=prepare_args)

numpy/testing/_private/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2687,12 +2687,16 @@ def _get_glibc_version():
26872687

26882688

26892689
def run_threaded(func, iters=8, pass_count=False, max_workers=8,
2690-
pass_barrier=False, outer_iterations=1):
2690+
pass_barrier=False, outer_iterations=1,
2691+
prepare_args=None):
26912692
"""Runs a function many times in parallel"""
26922693
for _ in range(outer_iterations):
26932694
with (concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
26942695
as tpe):
2695-
args = []
2696+
if prepare_args is None:
2697+
args = []
2698+
else:
2699+
args = prepare_args()
26962700
if pass_barrier:
26972701
if max_workers != iters:
26982702
raise RuntimeError(

0 commit comments

Comments
 (0)
0