8000 Merge pull request #27913 from charris/backport-27896 · numpy/numpy@7895ba6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7895ba6

Browse files
authored
Merge pull request #27913 from charris/backport-27896
PERF: improve multithreaded ufunc scaling
2 parents b30a338 + ee8d1cd commit 7895ba6

File tree

13 files changed

+152
-33
lines changed

13 files changed

+152
-33
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
* Improved multithreaded scaling on the free-threaded build when many threads
2+
simultaneously call the same ufunc operations.

numpy/_core/code_generators/genapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def get_processor():
8585
join('multiarray', 'stringdtype', 'static_string.c'),
8686
join('multiarray', 'strfuncs.c'),
8787
join('multiarray', 'usertypes.c'),
88-
join('umath', 'dispatching.c'),
88+
join('umath', 'dispatching.cpp'),
8989
join('umath', 'extobj.c'),
9090
join('umath', 'loops.c.src'),
9191
join('umath', 'reduction.c'),

numpy/_core/code_generators/generate_umath.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,13 +1592,10 @@ def make_code(funcdict, filename):
15921592
#include "matmul.h"
15931593
#include "clip.h"
15941594
#include "dtypemeta.h"
1595+
#include "dispatching.h"
1595 56D3 1596
#include "_umath_doc_generated.h"
15961597
15971598
%s
1598-
/* Returns a borrowed ref of the second value in the matching info tuple */
1599-
PyObject *
1600-
get_info_no_cast(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtype,
1601-
int ndtypes);
16021599
16031600
static int
16041601
InitOperators(PyObject *dictionary) {

numpy/_core/include/numpy/ndarraytypes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_
22
#define NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_
33

4+
#ifdef __cplusplus
5+
extern "C" {
6+
#endif
7+
48
#include "npy_common.h"
59
#include "npy_endian.h"
610
#include "npy_cpu.h"
@@ -1922,4 +1926,8 @@ typedef struct {
19221926
*/
19231927
#undef NPY_DEPRECATED_INCLUDES
19241928

1929+
#ifdef __cplusplus
1930+
}
1931+
#endif
1932+
19251933
#endif /* NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_ */

numpy/_core/meson.build

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ py.extension_module('_multiarray_tests',
713713
src_file.process('src/multiarray/_multiarray_tests.c.src'),
714714
'src/common/mem_overlap.c',
715715
'src/common/npy_argparse.c',
716-
'src/common/npy_hashtable.c',
716+
'src/common/npy_hashtable.cpp',
717717
src_file.process('src/common/templ_common.h.src')
718718
],
719719
c_args: c_args_common,
@@ -1042,7 +1042,7 @@ src_multiarray_umath_common = [
10421042
'src/common/gil_utils.c',
10431043
'src/common/mem_overlap.c',
10441044
'src/common/npy_argparse.c',
1045-
'src/common/npy_hashtable.c',
1045+
'src/common/npy_hashtable.cpp',
10461046
'src/common/npy_import.c',
10471047
'src/common/npy_longdouble.c',
10481048
'src/common/ucsnarrow.c',
@@ -1153,7 +1153,7 @@ src_umath = umath_gen_headers + [
11531153
'src/umath/ufunc_type_resolution.c',
11541154
'src/umath/clip.cpp',
11551155
'src/umath/clip.h',
1156-
'src/umath/dispatching.c',
1156+
'src/umath/dispatching.cpp',
11571157
'src/umath/extobj.c',
11581158
'src/umath/legacy_array_method.c',
11591159
'src/umath/override.c',

numpy/_core/src/common/npy_hashtable.c renamed to numpy/_core/src/common/npy_hashtable.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
* case is likely desired.
1313
*/
1414

15+
#include <mutex>
16+
#include <shared_mutex>
17+
1518
#include "templ_common.h"
1619
#include "npy_hashtable.h"
1720

@@ -89,7 +92,7 @@ find_item(PyArrayIdentityHash const *tb, PyObject *const *key)
8992
NPY_NO_EXPORT PyArrayIdentityHash *
9093
PyArrayIdentityHash_New(int key_len)
9194
{
92-
PyArrayIdentityHash *res = PyMem_Malloc(sizeof(PyArrayIdentityHash));
95+
PyArrayIdentityHash *res = (PyArrayIdentityHash *)PyMem_Malloc(sizeof(PyArrayIdentityHash));
9396
if (res == NULL) {
9497
PyErr_NoMemory();
9598
return NULL;
@@ -100,12 +103,21 @@ PyArrayIdentityHash_New(int key_len)
100103
res->size = 4; /* Start with a size of 4 */
101104
res->nelem = 0;
102105

103-
res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
106+
res->buckets = (PyObject **)PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
104107
if (res->buckets == NULL) {
105108
PyErr_NoMemory();
106109
PyMem_Free(res);
107110
return NULL;
108111
}
112+
113+
#ifdef Py_GIL_DISABLED
114+
res->mutex = new(std::nothrow) std::shared_mutex();
115+
if (res->mutex == nullptr) {
116+
PyErr_NoMemory();
117+
PyMem_Free(res);
118+
return NULL;
119+
}
120+
#endif
109121
return res;
110122
}
111123

@@ -115,6 +127,9 @@ PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
115127
{
116128
PyMem_Free(tb->buckets);
117129
PyMem_Free(tb);
130+
#ifdef Py_GIL_DISABLED
131+
delete (std::shared_mutex *)tb->mutex;
132+
#endif
118133
}
119134

120135

@@ -149,7 +164,7 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
149164
if (npy_mul_sizes_with_overflow(&alloc_size, new_size, tb->key_len + 1)) {
150165
return -1;
151166
}
152-
tb->buckets = PyMem_Calloc(alloc_size, sizeof(PyObject *));
167+
tb->buckets = (PyObject **)PyMem_Calloc(alloc_size, sizeof(PyObject *));
153168
if (tb->buckets == NULL) {
154169
tb->buckets = old_table;
155170
PyErr_NoMemory();

numpy/_core/src/common/npy_hashtable.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,19 @@
77
#include "numpy/ndarraytypes.h"
88

99

10+
#ifdef __cplusplus
11+
extern "C" {
12+
#endif
13+
1014
typedef struct {
1115
int key_len; /* number of identities used */
1216
/* Buckets stores: val1, key1[0], key1[1], ..., val2, key2[0], ... */
1317
PyObject **buckets;
1418
npy_intp size; /* current size */
1519
npy_intp nelem; /* number of elements */
20+
#ifdef Py_GIL_DISABLED
21+
void *mutex;
22+
#endif
1623
} PyArrayIdentityHash;
1724

1825

@@ -29,4 +36,8 @@ PyArrayIdentityHash_New(int key_len);
2936
NPY_NO_EXPORT void
3037
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb);
3138

39+
#ifdef __cplusplus
40+
}
41+
#endif
42+
3243
#endif /* NUMPY_CORE_SRC_COMMON_NPY_NPY_HASHTABLE_H_ */

numpy/_core/src/multiarray/common.h

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
#include "npy_import.h"
1313
#include <limits.h>
1414

15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
1519
#define error_converting(x) (((x) == -1) && PyErr_Occurred())
1620

1721
#ifdef NPY_ALLOW_THREADS
@@ -104,13 +108,13 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
104108
/* Try to be as clear as possible about what went wrong. */
105109
if (axis >= 0) {
106110
PyErr_Format(PyExc_IndexError,
107-
"index %"NPY_INTP_FMT" is out of bounds "
108-
"for axis %d with size %"NPY_INTP_FMT,
111+
"index %" NPY_INTP_FMT" is out of bounds "
112+
"for axis %d with size %" NPY_INTP_FMT,
109113
*index, axis, max_item);
110114
} else {
111115
PyErr_Format(PyExc_IndexError,
112-
"index %"NPY_INTP_FMT" is out of bounds "
113-
"for size %"NPY_INTP_FMT, *index, max_item);
116+
"index %" NPY_INTP_FMT " is out of bounds "
117+
"for size %" NPY_INTP_FMT, *index, max_item);
114118
}
115119
return -1;
116120
}
@@ -163,7 +167,9 @@ check_and_adjust_axis(int *axis, int ndim)
163167
* <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=52023>.
164168
* clang versions < 8.0.0 have the same bug.
165169
*/
166-
#if (!defined __STDC_VERSION__ || __STDC_VERSION__ < 201112 \
170+
#ifdef __cplusplus
171+
#define NPY_ALIGNOF(type) alignof(type)
172+
#elif (!defined __STDC_VERSION__ || __STDC_VERSION__ < 201112 \
167173
|| (defined __GNUC__ && __GNUC__ < 4 + (__GNUC_MINOR__ < 9) \
168174
&& !defined __clang__) \
169175
|| (defined __clang__ && __clang_major__ < 8))
@@ -347,4 +353,8 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
347353
*/
348354
#define NPY_ITER_REDUCTION_AXIS(axis) (axis + (1 << (NPY_BITSOF_INT - 2)))
349355

356+
#ifdef __cplusplus
357+
}
358+
#endif
359+
350360
#endif /* NUMPY_CORE_SRC_MULTIARRAY_COMMON_H_ */

numpy/_core/src/multiarray/npy_static_data.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#ifndef NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_
22
#define NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_
33

4+
#ifdef __cplusplus
5+
extern "C" {
6+
#endif
7+
48
NPY_NO_EXPORT int
59
initialize_static_globals(void);
610

@@ -168,4 +172,8 @@ NPY_VISIBILITY_HIDDEN extern npy_interned_str_struct npy_interned_str;
168172
NPY_VISIBILITY_HIDDEN extern npy_static_pydata_struct npy_static_pydata;
169173
NPY_VISIBILITY_HIDDEN extern npy_static_cdata_struct npy_static_cdata;
170174

175+
#ifdef __cplusplus
176+
}
177+
#endif
178+
171179
#endif // NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_

numpy/_core/src/umath/dispatching.c renamed to numpy/_core/src/umath/dispatching.cpp

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
#define _MULTIARRAYMODULE
3939
#define _UMATHMODULE
4040

41+
#include <mutex>
42+
#include <shared_mutex>
43+
4144
#define PY_SSIZE_T_CLEAN
4245
#include <Python.h>
4346
#include <convert_datatype.h>
@@ -504,8 +507,9 @@ call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *info,
504507
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
505508
if (PyCapsule_CheckExact(promoter)) {
506509
/* We could also go the other way and wrap up the python function... */
507-
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
508-
promoter, "numpy._ufunc_promoter");
510+
PyArrayMethod_PromoterFunction *promoter_function =
511+
(PyArrayMethod_PromoterFunction *)PyCapsule_GetPointer(
512+
promoter, "numpy._ufunc_promoter");
509513
if (promoter_function == NULL) {
510514
return NULL;
511515
}
@@ -770,8 +774,9 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
770774
* 2. Check all registered loops/promoters to find the best match.
771775
* 3. Fall back to the legacy implementation if no match was found.
772776
*/
773-
PyObject *info = PyArrayIdentityHash_GetItem(ufunc->_dispatch_cache,
774-
(PyObject **)op_dtypes);
777+
PyObject *info = PyArrayIdentityHash_GetItem(
778+
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
779+
(PyObject **)op_dtypes);
775780
if (info != NULL && PyObject_TypeCheck(
776781
PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
777782
/* Found the ArrayMethod and NOT a promoter: return it */
@@ -793,8 +798,9 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
793798
* Found the ArrayMethod and NOT promoter. Before returning it
794799
* add it to the cache for faster lookup in the future.
795800
*/
796-
if (PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
797-
(PyObject **)op_dtypes, info, 0) < 0) {
801+
if (PyArrayIdentityHash_SetItem(
802+
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
803+
(PyObject **)op_dtypes, info, 0) < 0) {
798804
return NULL;
799805
}
800806
return info;
@@ -815,8 +821,9 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
815821
}
816822
else if (info != NULL) {
817823
/* Add result to the cache using the original types: */
818-
if (PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
819-
(PyObject **)op_dtypes, info, 0) < 0) {
824+
if (PyArrayIdentityHash_SetItem(
825+
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
826+
(PyObject **)op_dtypes, info, 0) < 0) {
820827
return NULL;
821828
}
822829
return info;
@@ -882,13 +889,51 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
882889
}
883890

884891
/* Add this to the cache using the original types: */
885-
if (cacheable && PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
886-
(PyObject **)op_dtypes, info, 0) < 0) {
892+
if (cacheable && PyArrayIdentityHash_SetItem(
893+
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
894+
(PyObject **)op_dtypes, info, 0) < 0) {
887895
return NULL;
888896
}
889897
return info;
890898
}
891899

900+
#ifdef Py_GIL_DISABLED
901+
/*
902+
* Fast path for promote_and_get_info_and_ufuncimpl.
903+
* Acquires a read lock to check for a cache hit and then
904+
* only acquires a write lock on a cache miss to fill the cache
905+
*/
906+
static inline PyObject *
907+
promote_and_get_info_and_ufuncimpl_with_locking(< BCA1 /div>
908+
PyUFuncObject *ufunc,
909+
PyArrayObject *const ops[],
910+
PyArray_DTypeMeta *signature[],
911+
PyArray_DTypeMeta *op_dtypes[],
912+
npy_bool legacy_promotion_is_possible)
913+
{
914+
std::shared_mutex *mutex = ((std::shared_mutex *)((PyArrayIdentityHash *)ufunc->_dispatch_cache)->mutex);
915+
mutex->lock_shared();
916+
PyObject *info = PyArrayIdentityHash_GetItem(
917+
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
918+
(PyObject **)op_dtypes);
919+
mutex->unlock_shared();
920+
921+
if (info != NULL && PyObject_TypeCheck(
922+
PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
923+
/* Found the ArrayMethod and NOT a promoter: return it */
924+
return info;
925+
}
926+
927+
// cache miss, need to acquire a write lock and recursively calculate the
928+
// correct dispatch resolution
929+
mutex->lock();
930+
info = promote_and_get_info_and_ufuncimpl(ufunc,
931+
ops, signature, op_dtypes, legacy_promotion_is_possible);
932+
mutex->unlock();
933+
934+
return info;
935+
}
936+
#endif
892937

893938
/**
894939
* The central entry-point for the promotion and dispatching machinery.
@@ -941,6 +986,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
941986
{
942987
int nin = ufunc->nin, nargs = ufunc->nargs;
943988
npy_bool legacy_promotion_is_possible = NPY_TRUE;
989+
PyObject *all_dtypes = NULL;
990+
PyArrayMethodObject *method = NULL;
944991

945992
/*
946993
* Get the actual DTypes we operate with by setting op_dtypes[i] from
@@ -976,18 +1023,20 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
9761023
}
9771024
}
9781025

979-
PyObject *info;
980-
Py_BEGIN_CRITICAL_SECTION((PyObject *)ufunc);
981-
info = promote_and_get_info_and_ufuncimpl(ufunc,
1026+
#ifdef Py_GIL_DISABLED
1027+
PyObject *info = promote_and_get_info_and_ufuncimpl_with_locking(ufunc,
1028+
ops, signature, op_dtypes, legacy_promotion_is_possible);
1029+
#else
1030+
PyObject *info = promote_and_get_info_and_ufuncimpl(ufunc,
9821031
ops, signature, op_dtypes, legacy_promotion_is_possible);
983-
Py_END_CRITICAL_SECTION();
1032+
#endif
9841033

9851034
if (info == NULL) {
9861035
goto handle_error;
9871036
}
9881037

989-
PyArrayMethodObject *method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
990-
PyObject *all_dtypes = PyTuple_GET_ITEM(info, 0);
1038+
method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
1039+
all_dtypes = PyTuple_GET_ITEM(info, 0);
9911040

9921041
/*
9931042
* In certain cases (only the logical ufuncs really), the loop we found may
@@ -1218,7 +1267,7 @@ install_logical_ufunc_promoter(PyObject *ufunc)
12181267
if (dtype_tuple == NULL) {
12191268
return -1;
12201269
}
1221-
PyObject *promoter = PyCapsule_New(&logical_ufunc_promoter,
1270+
PyObject *promoter = PyCapsule_New((void *)&logical_ufunc_promoter,
12221271
"numpy._ufunc_promoter", NULL);
12231272
if (promoter == NULL) {
12241273
Py_DECREF(dtype_tuple);

0 commit comments

Comments
 (0)
0