PERF: improve multithreaded ufunc scaling by ngoldbaum · Pull Request #27896 · numpy/numpy · GitHub

PERF: improve multithreaded ufunc scaling #27896


Merged: 10 commits, Dec 5, 2024
2 changes: 2 additions & 0 deletions doc/release/upcoming_changes/27896.performance.rst
@@ -0,0 +1,2 @@
* Improved multithreaded scaling on the free-threaded build when many threads
simultaneously call the same ufunc operations.
2 changes: 1 addition & 1 deletion numpy/_core/code_generators/genapi.py
@@ -85,7 +85,7 @@ def get_processor():
join('multiarray', 'stringdtype', 'static_string.c'),
join('multiarray', 'strfuncs.c'),
join('multiarray', 'usertypes.c'),
join('umath', 'dispatching.c'),
join('umath', 'dispatching.cpp'),
join('umath', 'extobj.c'),
join('umath', 'loops.c.src'),
join('umath', 'reduction.c'),
5 changes: 1 addition & 4 deletions numpy/_core/code_generators/generate_umath.py
@@ -1592,13 +1592,10 @@ def make_code(funcdict, filename):
#include "matmul.h"
#include "clip.h"
#include "dtypemeta.h"
#include "dispatching.h"
#include "_umath_doc_generated.h"

%s
/* Returns a borrowed ref of the second value in the matching info tuple */
PyObject *
get_info_no_cast(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtype,
int ndtypes);
Comment from the PR author (ngoldbaum):

not totally clear to me if originally not including this function in dispatching.h and forward-declaring it here was an intentional decision for some reason
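Presumably the prototype now lives in dispatching.h rather than being forward-declared in the generated C source. A minimal sketch of what that declaration could look like there (the extern "C" guards mirror the ones this PR adds to other headers; the exact placement is an assumption, not taken from the diff):

/* hypothetical dispatching.h excerpt */
#ifdef __cplusplus
extern "C" {
#endif

/* Returns a borrowed ref of the second value in the matching info tuple */
PyObject *
get_info_no_cast(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtype,
                 int ndtypes);

#ifdef __cplusplus
}
#endif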


static int
InitOperators(PyObject *dictionary) {
8 changes: 8 additions & 0 deletions numpy/_core/include/numpy/ndarraytypes.h
@@ -1,6 +1,10 @@
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_
#define NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_

#ifdef __cplusplus
extern "C" {
#endif

#include "npy_common.h"
#include "npy_endian.h"
#include "npy_cpu.h"
@@ -1922,4 +1926,8 @@ typedef struct {
*/
#undef NPY_DEPRECATED_INCLUDES

#ifdef __cplusplus
}
#endif

#endif /* NUMPY_CORE_INCLUDE_NUMPY_NDARRAYTYPES_H_ */
6 changes: 3 additions & 3 deletions numpy/_core/meson.build
@@ -713,7 +713,7 @@ py.extension_module('_multiarray_tests',
src_file.process('src/multiarray/_multiarray_tests.c.src'),
'src/common/mem_overlap.c',
'src/common/npy_argparse.c',
'src/common/npy_hashtable.c',
'src/common/npy_hashtable.cpp',
src_file.process('src/common/templ_common.h.src')
],
c_args: c_args_common,
@@ -1042,7 +1042,7 @@ src_multiarray_umath_common = [
'src/common/gil_utils.c',
'src/common/mem_overlap.c',
'src/common/npy_argparse.c',
'src/common/npy_hashtable.c',
'src/common/npy_hashtable.cpp',
'src/common/npy_import.c',
'src/common/npy_longdouble.c',
'src/common/ucsnarrow.c',
@@ -1153,7 +1153,7 @@ src_umath = umath_gen_headers + [
'src/umath/ufunc_type_resolution.c',
'src/umath/clip.cpp',
'src/umath/clip.h',
'src/umath/dispatching.c',
'src/umath/dispatching.cpp',
'src/umath/extobj.c',
'src/umath/legacy_array_method.c',
'src/umath/override.c',
numpy/_core/src/common/npy_hashtable.c → numpy/_core/src/common/npy_hashtable.cpp (renamed)
@@ -12,6 +12,9 @@
* case is likely desired.
*/

#include <mutex>
#include <shared_mutex>

#include "templ_common.h"
#include "npy_hashtable.h"

@@ -89,7 +92,7 @@ find_item(PyArrayIdentityHash const *tb, PyObject *const *key)
NPY_NO_EXPORT PyArrayIdentityHash *
PyArrayIdentityHash_New(int key_len)
{
PyArrayIdentityHash *res = PyMem_Malloc(sizeof(PyArrayIdentityHash));
PyArrayIdentityHash *res = (PyArrayIdentityHash *)PyMem_Malloc(sizeof(PyArrayIdentityHash));
if (res == NULL) {
PyErr_NoMemory();
return NULL;
@@ -100,12 +103,21 @@ PyArrayIdentityHash_New(int key_len)
res->size = 4; /* Start with a size of 4 */
res->nelem = 0;

res->buckets = PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
res->buckets = (PyObject **)PyMem_Calloc(4 * (key_len + 1), sizeof(PyObject *));
if (res->buckets == NULL) {
PyErr_NoMemory();
PyMem_Free(res);
return NULL;
}

#ifdef Py_GIL_DISABLED
res->mutex = new(std::nothrow) std::shared_mutex();
if (res->mutex == nullptr) {
PyErr_NoMemory();
PyMem_Free(res->buckets);
PyMem_Free(res);
return NULL;
}
#endif
return res;
}

@@ -115,6 +127,9 @@ PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb)
{
PyMem_Free(tb->buckets);
#ifdef Py_GIL_DISABLED
delete (std::shared_mutex *)tb->mutex;
#endif
PyMem_Free(tb);
}

Reviewer comment (on the delete call):

Seems OK to assume that delete will never throw an exception.


@@ -149,7 +164,7 @@ _resize_if_necessary(PyArrayIdentityHash *tb)
if (npy_mul_sizes_with_overflow(&alloc_size, new_size, tb->key_len + 1)) {
return -1;
}
tb->buckets = PyMem_Calloc(alloc_size, sizeof(PyObject *));
tb->buckets = (PyObject **)PyMem_Calloc(alloc_size, sizeof(PyObject *));
if (tb->buckets == NULL) {
tb->buckets = old_table;
PyErr_NoMemory();
11 changes: 11 additions & 0 deletions numpy/_core/src/common/npy_hashtable.h
@@ -7,12 +7,19 @@
#include "numpy/ndarraytypes.h"


#ifdef __cplusplus
extern "C" {
#endif

typedef struct {
int key_len; /* number of identities used */
/* Buckets stores: val1, key1[0], key1[1], ..., val2, key2[0], ... */
PyObject **buckets;
npy_intp size; /* current size */
npy_intp nelem; /* number of elements */
#ifdef Py_GIL_DISABLED
void *mutex;
#endif
} PyArrayIdentityHash;


@@ -29,4 +36,8 @@ PyArrayIdentityHash_New(int key_len);
NPY_NO_EXPORT void
PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb);

#ifdef __cplusplus
}
#endif

#endif /* NUMPY_CORE_SRC_COMMON_NPY_NPY_HASHTABLE_H_ */
20 changes: 15 additions & 5 deletions numpy/_core/src/multiarray/common.h
@@ -12,6 +12,10 @@
#include "npy_import.h"
#include <limits.h>

#ifdef __cplusplus
extern "C" {
#endif

#define error_converting(x) (((x) == -1) && PyErr_Occurred())

#ifdef NPY_ALLOW_THREADS
@@ -104,13 +108,13 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
/* Try to be as clear as possible about what went wrong. */
if (axis >= 0) {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for axis %d with size %"NPY_INTP_FMT,
"index %" NPY_INTP_FMT" is out of bounds "
"for axis %d with size %" NPY_INTP_FMT,
*index, axis, max_item);
} else {
PyErr_Format(PyExc_IndexError,
"index %"NPY_INTP_FMT" is out of bounds "
"for size %"NPY_INTP_FMT, *index, max_item);
"index %" NPY_INTP_FMT " is out of bounds "
"for size %" NPY_INTP_FMT, *index, max_item);
Reviewer comment:

We should just replace all NPY_INTP_FMT with %zd one of these days...

}
return -1;
}
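As a side note on the reviewer's %zd suggestion above, the same error message could be written without NPY_INTP_FMT roughly as follows (a sketch only, not part of this PR; the casts are defensive since %zd in PyErr_Format consumes a Py_ssize_t):

/* sketch of the %zd variant */
PyErr_Format(PyExc_IndexError,
        "index %zd is out of bounds for axis %d with size %zd",
        (Py_ssize_t)*index, axis, (Py_ssize_t)max_item);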
@@ -163,7 +167,9 @@ check_and_adjust_axis(int *axis, int ndim)
* <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=52023>.
* clang versions < 8.0.0 have the same bug.
*/
#if (!defined __STDC_VERSION__ || __STDC_VERSION__ < 201112 \
#ifdef __cplusplus
#define NPY_ALIGNOF(type) alignof(type)
Reviewer comment:

Wondering if we shouldn't just use alignof, but let's go with this.

#elif (!defined __STDC_VERSION__ || __STDC_VERSION__ < 201112 \
|| (defined __GNUC__ && __GNUC__ < 4 + (__GNUC_MINOR__ < 9) \
&& !defined __clang__) \
|| (defined __clang__ && __clang_major__ < 8))
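On the reviewer's note above about using alignof directly: C11 exposes alignof as a macro through <stdalign.h>, so a simplified definition could look like the sketch below. This is only an illustration; it would drop the workaround for the old GCC/clang _Alignof bugs mentioned in the surrounding comment, which is presumably why the compiler-version checks are kept.

/* sketch only: rely on standard alignof in both C11 and C++ */
#ifndef __cplusplus
#include <stdalign.h>   /* C11: defines alignof as _Alignof */
#endif
#define NPY_ALIGNOF(type) alignof(type)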
@@ -347,4 +353,8 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
*/
#define NPY_ITER_REDUCTION_AXIS(axis) (axis + (1 << (NPY_BITSOF_INT - 2)))

#ifdef __cplusplus
}
#endif

#endif /* NUMPY_CORE_SRC_MULTIARRAY_COMMON_H_ */
8 changes: 8 additions & 0 deletions numpy/_core/src/multiarray/npy_static_data.h
@@ -1,6 +1,10 @@
#ifndef NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_
#define NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_

#ifdef __cplusplus
extern "C" {
#endif

NPY_NO_EXPORT int
initialize_static_globals(void);

@@ -168,4 +172,8 @@ NPY_VISIBILITY_HIDDEN extern npy_interned_str_struct npy_interned_str;
NPY_VISIBILITY_HIDDEN extern npy_static_pydata_struct npy_static_pydata;
NPY_VISIBILITY_HIDDEN extern npy_static_cdata_struct npy_static_cdata;

#ifdef __cplusplus
}
#endif

#endif // NUMPY_CORE_SRC_MULTIARRAY_STATIC_DATA_H_
numpy/_core/src/umath/dispatching.c → numpy/_core/src/umath/dispatching.cpp (renamed)
@@ -38,6 +38,9 @@
#define _MULTIARRAYMODULE
#define _UMATHMODULE

#include <mutex>
#include <shared_mutex>

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <convert_datatype.h>
@@ -504,8 +507,9 @@ call_promoter_and_recurse(PyUFuncObject *ufunc, PyObject *info,
PyObject *promoter = PyTuple_GET_ITEM(info, 1);
if (PyCapsule_CheckExact(promoter)) {
/* We could also go the other way and wrap up the python function... */
PyArrayMethod_PromoterFunction *promoter_function = PyCapsule_GetPointer(
promoter, "numpy._ufunc_promoter");
PyArrayMethod_PromoterFunction *promoter_function =
(PyArrayMethod_PromoterFunction *)PyCapsule_GetPointer(
promoter, "numpy._ufunc_promoter");
if (promoter_function == NULL) {
return NULL;
}
@@ -770,8 +774,9 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
* 2. Check all registered loops/promoters to find the best match.
* 3. Fall back to the legacy implementation if no match was found.
*/
PyObject *info = PyArrayIdentityHash_GetItem(ufunc->_dispatch_cache,
(PyObject **)op_dtypes);
PyObject *info = PyArrayIdentityHash_GetItem(
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
(PyObject **)op_dtypes);
if (info != NULL && PyObject_TypeCheck(
PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
/* Found the ArrayMethod and NOT a promoter: return it */
@@ -793,8 +798,9 @@
* Found the ArrayMethod and NOT promoter. Before returning it
* add it to the cache for faster lookup in the future.
*/
if (PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
if (PyArrayIdentityHash_SetItem(
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
return NULL;
}
return info;
@@ -815,8 +821,9 @@
}
else if (info != NULL) {
/* Add result to the cache using the original types: */
if (PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
if (PyArrayIdentityHash_SetItem(
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
return NULL;
}
return info;
@@ -882,13 +889,51 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
}

/* Add this to the cache using the original types: */
if (cacheable && PyArrayIdentityHash_SetItem(ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
if (cacheable && PyArrayIdentityHash_SetItem(
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
(PyObject **)op_dtypes, info, 0) < 0) {
return NULL;
}
return info;
}

#ifdef Py_GIL_DISABLED
Reviewer comment:

Sorry for the drive-by comment, but Py_GIL_DISABLED means you're running on a free-threading-enabled Python, not necessarily that the GIL is concretely disabled, right?

If you're calling mutex->lock or mutex->lock_shared with the GIL held, then there is a risk of deadlock if, for example, promote_and_get_info_and_ufuncimpl released the GIL in another thread and is waiting to take the GIL back.

Reviewer comment (follow-up):

So you may want to wrap calls to mutex->lock[_shared] in a Py_BEGIN_ALLOW_THREADS/Py_END_ALLOW_THREADS pair.
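A minimal sketch of that suggestion, assuming the same mutex field used below: detach the thread state around the blocking acquisition so that a thread already holding the lock can reattach and make progress. Whether the merged code does this is not shown in this diff.

#ifdef Py_GIL_DISABLED
/* sketch only: wrap the blocking lock acquisition per the review suggestion */
static inline void
lock_shared_allowing_threads(std::shared_mutex *mutex)
{
    Py_BEGIN_ALLOW_THREADS
    mutex->lock_shared();
    Py_END_ALLOW_THREADS
}
#endif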

/*
* Fast path for promote_and_get_info_and_ufuncimpl.
* Acquires a read lock to check for a cache hit and then
* only acquires a write lock on a cache miss to fill the cache
*/
static inline PyObject *
promote_and_get_info_and_ufuncimpl_with_locking(
PyUFuncObject *ufunc,
PyArrayObject *const ops[],
PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *op_dtypes[],
npy_bool legacy_promotion_is_possible)
{
std::shared_mutex *mutex = ((std::shared_mutex *)((PyArrayIdentityHash *)ufunc->_dispatch_cache)->mutex);
mutex->lock_shared();
PyObject *info = PyArrayIdentityHash_GetItem(
(PyArrayIdentityHash *)ufunc->_dispatch_cache,
(PyObject **)op_dtypes);
mutex->unlock_shared();

if (info != NULL && PyObject_TypeCheck(
PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
/* Found the ArrayMethod and NOT a promoter: return it */
return info;
}

// cache miss, need to acquire a write lock and recursively calculate the
// correct dispatch resolution
mutex->lock();
info = promote_and_get_info_and_ufuncimpl(ufunc,
ops, signature, op_dtypes, legacy_promotion_is_possible);
mutex->unlock();

return info;
}
#endif
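As an aside on the design, the same read-mostly, double-checked pattern can also be written with RAII lock guards, which keeps the unlock paths exception-safe. The snippet below is only a sketch that assumes the surrounding names from the function above; it is not what the PR implements.

/* sketch: equivalent fast path using std::shared_lock / std::unique_lock */
PyObject *info;
{
    std::shared_lock<std::shared_mutex> read_guard(*mutex);
    info = PyArrayIdentityHash_GetItem(
            (PyArrayIdentityHash *)ufunc->_dispatch_cache,
            (PyObject **)op_dtypes);
}
if (info == NULL || !PyObject_TypeCheck(
            PyTuple_GET_ITEM(info, 1), &PyArrayMethod_Type)) {
    std::unique_lock<std::shared_mutex> write_guard(*mutex);
    info = promote_and_get_info_and_ufuncimpl(ufunc,
            ops, signature, op_dtypes, legacy_promotion_is_possible);
}
return info;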

/**
* The central entry-point for the promotion and dispatching machinery.
@@ -941,6 +986,8 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
{
int nin = ufunc->nin, nargs = ufunc->nargs;
npy_bool legacy_promotion_is_possible = NPY_TRUE;
PyObject *all_dtypes = NULL;
PyArrayMethodObject *method = NULL;

/*
* Get the actual DTypes we operate with by setting op_dtypes[i] from
Expand Down Expand Up @@ -976,18 +1023,20 @@ promote_and_get_ufuncimpl(PyUFuncObject *ufunc,
}
}

PyObject *info;
Py_BEGIN_CRITICAL_SECTION((PyObject *)ufunc);
info = promote_and_get_info_and_ufuncimpl(ufunc,
#ifdef Py_GIL_DISABLED
PyObject *info = promote_and_get_info_and_ufuncimpl_with_locking(ufunc,
ops, signature, op_dtypes, legacy_promotion_is_possible);
#else
PyObject *info = promote_and_get_info_and_ufuncimpl(ufunc,
ops, signature, op_dtypes, legacy_promotion_is_possible);
Py_END_CRITICAL_SECTION();
#endif

if (info == NULL) {
goto handle_error;
}

PyArrayMethodObject *method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
PyObject *all_dtypes = PyTuple_GET_ITEM(info, 0);
method = (PyArrayMethodObject *)PyTuple_GET_ITEM(info, 1);
all_dtypes = PyTuple_GET_ITEM(info, 0);

/*
* In certain cases (only the logical ufuncs really), the loop we found may
@@ -1218,7 +1267,7 @@ install_logical_ufunc_promoter(PyObject *ufunc)
if (dtype_tuple == NULL) {
return -1;
}
PyObject *promoter = PyCapsule_New(&logical_ufunc_promoter,
PyObject *promoter = PyCapsule_New((void *)&logical_ufunc_promoter,
"numpy._ufunc_promoter", NULL);
if (promoter == NULL) {
Py_DECREF(dtype_tuple);