8000 BUG: move reduction initialization to ufunc initialization (#28123) · numpy/numpy@c412bed · GitHub
[go: up one dir, main page]

Skip to content

Commit c412bed

Browse files
ngoldbaumseberg
andauthored
BUG: move reduction initialization to ufunc initialization (#28123)
* BUG: move reduction initialization to ufunc initialization * MAINT: refactor to call get_initial_from_ufunc during init * TST: add test for multithreaded reductions * MAINT: fix linter * Apply suggestions from code review Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net> * MAINT: simplify further --------- Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
1 parent 2777fa6 commit c412bed

File tree

4 files changed

+72
-28
lines changed

4 files changed

+72
-28
lines changed

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5033,6 +5033,24 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
50335033
goto err;
50345034
}
50355035

5036+
/*
5037+
* Initialize the default PyDataMem_Handler capsule singleton.
5038+
*/
5039+
PyDataMem_DefaultHandler = PyCapsule_New(
5040+
&default_handler, MEM_HANDLER_CAPSULE_NAME, NULL);
5041+
if (PyDataMem_DefaultHandler == NULL) {
5042+
goto err;
5043+
}
5044+
5045+
/*
5046+
* Initialize the context-local current handler
5047+
* with the default PyDataMem_Handler capsule.
5048+
*/
5049+
current_handler = PyContextVar_New("current_allocator", PyDataMem_DefaultHandler);
5050+
if (current_handler == NULL) {
5051+
goto err;
5052+
}
5053+
50365054
if (initumath(m) != 0) {
50375055
goto err;
50385056
}
@@ -5067,7 +5085,7 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
50675085
* init_string_dtype() but that needs to happen after
50685086
* the legacy dtypemeta classes are available.
50695087
*/
5070-
5088+
50715089
if (npy_cache_import_runtime(
50725090
"numpy.dtypes", "_add_dtype_helper",
50735091
&npy_runtime_imports._add_dtype_helper) == -1) {
@@ -5081,23 +5099,6 @@ PyMODINIT_FUNC PyInit__multiarray_umath(void) {
50815099
}
50825100
PyDict_SetItemString(d, "StringDType", (PyObject *)&PyArray_StringDType);
50835101

5084-
/*
5085-
* Initialize the default PyDataMem_Handler capsule singleton.
5086-
*/
5087-
PyDataMem_DefaultHandler = PyCapsule_New(
5088-
&default_handler, MEM_HANDLER_CAPSULE_NAME, NULL);
5089-
if (PyDataMem_DefaultHandler == NULL) {
5090-
goto err;
5091-
}
5092-
/*
5093-
* Initialize the context-local current handler
5094-
* with the default PyDataMem_Handler capsule.
5095-
*/
5096-
current_handler = PyContextVar_New("current_allocator", PyDataMem_DefaultHandler);
5097-
if (current_handler == NULL) {
5098-
goto err;
5099-
}
5100-
51015102
// initialize static reference to a zero-like array
51025103
npy_static_pydata.zero_pyint_like_arr = PyArray_ZEROS(
51035104
0, NULL, NPY_DEFAULT_INT, NPY_FALSE);

numpy/_core/src/umath/legacy_array_method.c

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ get_initial_from_ufunc(
311311
}
312312
}
313313
else if (context->descriptors[0]->type_num == NPY_OBJECT
314-
&& !reduction_is_empty) {
314+
&& !reduction_is_empty) {
315315
/* Allows `sum([object()])` to work, but use 0 when empty. */
316316
Py_DECREF(identity_obj);
317317
return 0;
@@ -323,13 +323,6 @@ get_initial_from_ufunc(
323323
return -1;
324324
}
325325

326-
if (PyTypeNum_ISNUMBER(context->descriptors[0]->type_num)) {
327-
/* For numbers we can cache to avoid going via Python ints */
328-
memcpy(context->method->legacy_initial, initial,
329-
context->descriptors[0]->elsize);
330-
context->method->get_reduction_initial = &copy_cached_initial;
331-
}
332-
333326
/* Reduction can use the initial value */
334327
return 1;
335328
}
@@ -427,11 +420,47 @@ PyArray_NewLegacyWrappingArrayMethod(PyUFuncObject *ufunc,
427420
};
428421

429422
PyBoundArrayMethodObject *bound_res = PyArrayMethod_FromSpec_int(&spec, 1);
423+
430424
if (bound_res == NULL) {
431425
return NULL;
432426
}
433427
PyArrayMethodObject *res = bound_res->method;
428+
429+
// set cached initial value for numeric reductions to avoid creating
430+
// a python int in every reduction
431+
if (PyTypeNum_ISNUMBER(bound_res->dtypes[0]->type_num) &&
432+
ufunc->nin == 2 && ufunc->nout == 1) {
433+
434+
PyArray_Descr *descrs[3];
435+
436+
for (int i = 0; i < 3; i++) {
437+
// only dealing with numeric legacy dtypes so this should always be
438+
// valid
439+
descrs[i] = bound_res->dtypes[i]->singleton;
440+
}
441+
442+
PyArrayMethod_Context context = {
443+
(PyObject *)ufunc,
444+
bound_res->method,
445+
descrs,
446+
};
447+
448+
int ret = get_initial_from_ufunc(&context, 0, context.method->legacy_initial);
449+
450+
if (ret < 0) {
451+
Py_DECREF(bound_res);
452+
return NULL;
453+
}
454+
455+
// only use the cached initial value if it's valid
456+
if (ret > 0) {
457+
context.method->get_reduction_initial = &copy_cached_initial;
458+
}
459+
}
460+
461+
434462
Py_INCREF(res);
435463
Py_DECREF(bound_res);
464+
436465
return res;
437466
}

numpy/_core/tests/test_multithreading.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,17 @@ def legacy_125():
120120

121121
task1.start()
122122
task2.start()
123+
124+
def test_parallel_reduction():
125+
# gh-28041
126+
NUM_THREADS = 50
127+
128+
b = threading.Barrier(NUM_THREADS)
129+
130+
x = np.arange(1000)
131+
132+
def closure():
133+
b.wait()
134+
np.sum(x)
135+
136+
run_threaded(closure, NUM_THREADS, max_workers=NUM_THREADS)

numpy/testing/_private/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,9 +2685,9 @@ def _get_glibc_version():
26852685
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
26862686

26872687

2688-
def run_threaded(func, iters, pass_count=False):
2688+
def run_threaded(func, iters, pass_count=False, max_workers=8):
26892689
"""Runs a function many times in parallel"""
2690-
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as tpe:
2690+
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as tpe:
26912691
if pass_count:
26922692
futures = [tpe.submit(func, i) for i in range(iters)]
26932693
else:

0 commit comments

Comments
 (0)
0