diff --git a/numpy/_core/src/umath/ufunc_type_resolution.c b/numpy/_core/src/umath/ufunc_type_resolution.c index f6f231223f63..c02559f47c05 100644 --- a/numpy/_core/src/umath/ufunc_type_resolution.c +++ b/numpy/_core/src/umath/ufunc_type_resolution.c @@ -1427,6 +1427,25 @@ PyUFunc_RemainderTypeResolver(PyUFuncObject *ufunc, } +static PyObject *default_truediv_type_tup = NULL; + +NPY_NO_EXPORT int +init_ufunc_type_resolution_cache() { + /* Initialize default_truediv_type_tup global */ + PyArray_Descr *tmp = PyArray_DescrFromType(NPY_DOUBLE); + + if (tmp == NULL) { + return -1; + } + default_truediv_type_tup = PyTuple_Pack(3, tmp, tmp, tmp); + if (default_truediv_type_tup == NULL) { + Py_DECREF(tmp); + return -1; + } + Py_DECREF(tmp); + return 0; +} + /* * True division should return float64 results when both inputs are integer * types. The PyUFunc_DefaultTypeResolver promotes 8 bit integers to float16 @@ -1441,22 +1460,6 @@ PyUFunc_TrueDivisionTypeResolver(PyUFuncObject *ufunc, PyArray_Descr **out_dtypes) { int type_num1, type_num2; - static PyObject *default_type_tup = NULL; - - /* Set default type for integer inputs to NPY_DOUBLE */ - if (default_type_tup == NULL) { - PyArray_Descr *tmp = PyArray_DescrFromType(NPY_DOUBLE); - - if (tmp == NULL) { - return -1; - } - default_type_tup = PyTuple_Pack(3, tmp, tmp, tmp); - if (default_type_tup == NULL) { - Py_DECREF(tmp); - return -1; - } - Py_DECREF(tmp); - } type_num1 = PyArray_DESCR(operands[0])->type_num; type_num2 = PyArray_DESCR(operands[1])->type_num; @@ -1465,7 +1468,7 @@ PyUFunc_TrueDivisionTypeResolver(PyUFuncObject *ufunc, (PyTypeNum_ISINTEGER(type_num1) || PyTypeNum_ISBOOL(type_num1)) && (PyTypeNum_ISINTEGER(type_num2) || PyTypeNum_ISBOOL(type_num2))) { return PyUFunc_DefaultTypeResolver(ufunc, casting, operands, - default_type_tup, out_dtypes); + default_truediv_type_tup, out_dtypes); } return PyUFunc_DivisionTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); diff --git a/numpy/_core/src/umath/ufunc_type_resolution.h b/numpy/_core/src/umath/ufunc_type_resolution.h index 3f8e7505ea39..f18e481c90bd 100644 --- a/numpy/_core/src/umath/ufunc_type_resolution.h +++ b/numpy/_core/src/umath/ufunc_type_resolution.h @@ -71,6 +71,9 @@ PyUFunc_MultiplicationTypeResolver(PyUFuncObject *ufunc, PyObject *type_tup, PyArray_Descr **out_dtypes); +NPY_NO_EXPORT int +init_ufunc_type_resolution_cache(); + NPY_NO_EXPORT int PyUFunc_TrueDivisionTypeResolver(PyUFuncObject *ufunc, NPY_CASTING casting, diff --git a/numpy/_core/src/umath/umathmodule.c b/numpy/_core/src/umath/umathmodule.c index 7c774f9fffc3..bb28c36a5e02 100644 --- a/numpy/_core/src/umath/umathmodule.c +++ b/numpy/_core/src/umath/umathmodule.c @@ -31,6 +31,7 @@ #include "stringdtype_ufuncs.h" #include "special_integer_comparisons.h" #include "extobj.h" /* for _extobject_contextvar exposure */ +#include "ufunc_type_resolution.h" /* Automatically generated code to define all ufuncs: */ #include "funcs.inc" @@ -346,5 +347,9 @@ int initumath(PyObject *m) return -1; } + if (init_ufunc_type_resolution_cache() < 0) { + return -1; + } + return 0; }