diff --git a/numpy/_core/include/numpy/_dtype_api.h b/numpy/_core/include/numpy/_dtype_api.h index 2bd06e1a1158..c1cb11ae8fbf 100644 --- a/numpy/_core/include/numpy/_dtype_api.h +++ b/numpy/_core/include/numpy/_dtype_api.h @@ -5,7 +5,7 @@ #ifndef NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ #define NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ -#define __EXPERIMENTAL_DTYPE_API_VERSION 13 +#define __EXPERIMENTAL_DTYPE_API_VERSION 14 struct PyArrayMethodObject_tag; @@ -129,16 +129,17 @@ typedef struct { * SLOTS IDs For the ArrayMethod creation, once fully public, IDs are fixed * but can be deprecated and arbitrarily extended. */ -#define NPY_METH_resolve_descriptors 1 +#define _NPY_METH_resolve_descriptors_with_scalars 1 +#define NPY_METH_resolve_descriptors 2 /* We may want to adapt the `get_loop` signature a bit: */ -#define _NPY_METH_get_loop 2 -#define NPY_METH_get_reduction_initial 3 +#define _NPY_METH_get_loop 3 +#define NPY_METH_get_reduction_initial 4 /* specific loops for constructions/default get_loop: */ -#define NPY_METH_strided_loop 4 -#define NPY_METH_contiguous_loop 5 -#define NPY_METH_unaligned_strided_loop 6 -#define NPY_METH_unaligned_contiguous_loop 7 -#define NPY_METH_contiguous_indexed_loop 8 +#define NPY_METH_strided_loop 5 +#define NPY_METH_contiguous_loop 6 +#define NPY_METH_unaligned_strided_loop 7 +#define NPY_METH_unaligned_contiguous_loop 8 +#define NPY_METH_contiguous_indexed_loop 9 /* * The resolve descriptors function, must be able to handle NULL values for @@ -162,6 +163,30 @@ typedef NPY_CASTING (resolve_descriptors_function)( npy_intp *view_offset); +/* + * Rarely needed, slightly more powerful version of `resolve_descriptors`. + * See also `resolve_descriptors_function` for details on shared arguments. + * + * NOTE: This function is private now as it is unclear how and what to pass + * exactly as additional information to allow dealing with the scalars. + * See also gh-24915. 
+ */ +typedef NPY_CASTING (resolve_descriptors_with_scalars_function)( + struct PyArrayMethodObject_tag *method, + PyArray_DTypeMeta **dtypes, + /* Unlike above, these can have any DType and we may allow NULL. */ + PyArray_Descr **given_descrs, + /* + * Input scalars or NULL. Only ever passed for python scalars. + * WARNING: In some cases, a loop may be explicitly selected and the + * value passed is not available (NULL) or does not have the + * expected type. + */ + PyObject *const *input_scalars, + PyArray_Descr **loop_descrs, + npy_intp *view_offset); + + typedef int (PyArrayMethod_StridedLoop)(PyArrayMethod_Context *context, char *const *data, const npy_intp *dimensions, const npy_intp *strides, NpyAuxData *transferdata); diff --git a/numpy/_core/meson.build b/numpy/_core/meson.build index 13b32adc7290..56dabf9c168a 100644 --- a/numpy/_core/meson.build +++ b/numpy/_core/meson.build @@ -1097,6 +1097,7 @@ src_umath = umath_gen_headers + [ src_file.process('src/umath/scalarmath.c.src'), 'src/umath/ufunc_object.c', 'src/umath/umathmodule.c', + 'src/umath/special_integer_comparisons.cpp', 'src/umath/string_ufuncs.cpp', 'src/umath/wrapping_array_method.c', # For testing. Eventually, should use public API and be separate: diff --git a/numpy/_core/src/multiarray/abstractdtypes.h b/numpy/_core/src/multiarray/abstractdtypes.h index a3f6ceb056e1..212994a422ea 100644 --- a/numpy/_core/src/multiarray/abstractdtypes.h +++ b/numpy/_core/src/multiarray/abstractdtypes.h @@ -5,6 +5,10 @@ #include "dtypemeta.h" +#ifdef __cplusplus +extern "C" { +#endif + /* * These are mainly needed for value based promotion in ufuncs. 
It * may be necessary to make them (partially) public, to allow user-defined @@ -70,4 +74,8 @@ npy_mark_tmp_array_if_pyscalar( return 0; } +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_MULTIARRAY_ABSTRACTDTYPES_H_ */ diff --git a/numpy/_core/src/multiarray/array_method.c b/numpy/_core/src/multiarray/array_method.c index 87515bb03aa8..24a7203c0d60 100644 --- a/numpy/_core/src/multiarray/array_method.c +++ b/numpy/_core/src/multiarray/array_method.c @@ -221,12 +221,6 @@ validate_spec(PyArrayMethod_Spec *spec) "(method: %s)", spec->dtypes[i], spec->name); return -1; } - if (NPY_DT_is_abstract(spec->dtypes[i])) { - PyErr_Format(PyExc_TypeError, - "abstract DType %S are currently not supported." - "(method: %s)", spec->dtypes[i], spec->name); - return -1; - } } return 0; } @@ -261,6 +255,16 @@ fill_arraymethod_from_slots( */ for (PyType_Slot *slot = &spec->slots[0]; slot->slot != 0; slot++) { switch (slot->slot) { + case _NPY_METH_resolve_descriptors_with_scalars: + if (!private) { + PyErr_SetString(PyExc_RuntimeError, + "the _NPY_METH_resolve_descriptors_with_scalars " + "slot private due to uncertainty about the best " + "signature (see gh-24915)"); + return -1; + } + meth->resolve_descriptors_with_scalars = slot->pfunc; + continue; case NPY_METH_resolve_descriptors: meth->resolve_descriptors = slot->pfunc; continue; @@ -272,7 +276,6 @@ fill_arraymethod_from_slots( * (as in: we should not worry about changing it, but of * course that would not break it immediately.) 
*/ - /* Only allow override for private functions initially */ meth->get_strided_loop = slot->pfunc; continue; /* "Typical" loops, supported used by the default `get_loop` */ diff --git a/numpy/_core/src/multiarray/array_method.h b/numpy/_core/src/multiarray/array_method.h index c82a968cd136..a25bfe91d527 100644 --- a/numpy/_core/src/multiarray/array_method.h +++ b/numpy/_core/src/multiarray/array_method.h @@ -45,6 +45,7 @@ typedef struct PyArrayMethodObject_tag { NPY_CASTING casting; /* default flags. The get_strided_loop function can override these */ NPY_ARRAYMETHOD_FLAGS flags; + resolve_descriptors_with_scalars_function *resolve_descriptors_with_scalars; resolve_descriptors_function *resolve_descriptors; get_loop_function *get_strided_loop; get_reduction_initial_function *get_reduction_initial; diff --git a/numpy/_core/src/multiarray/arrayobject.h b/numpy/_core/src/multiarray/arrayobject.h index b71354a5e4dd..476b87a9d7e1 100644 --- a/numpy/_core/src/multiarray/arrayobject.h +++ b/numpy/_core/src/multiarray/arrayobject.h @@ -5,6 +5,10 @@ #ifndef NUMPY_CORE_SRC_MULTIARRAY_ARRAYOBJECT_H_ #define NUMPY_CORE_SRC_MULTIARRAY_ARRAYOBJECT_H_ +#ifdef __cplusplus +extern "C" { +#endif + extern NPY_NO_EXPORT npy_bool numpy_warn_if_no_mem_policy; NPY_NO_EXPORT PyObject * @@ -51,4 +55,8 @@ static const int NPY_ARRAY_WAS_PYTHON_COMPLEX = (1 << 28); static const int NPY_ARRAY_WAS_INT_AND_REPLACED = (1 << 27); static const int NPY_ARRAY_WAS_PYTHON_LITERAL = (1 << 30 | 1 << 29 | 1 << 28); +#ifdef __cplusplus +} +#endif + #endif /* NUMPY_CORE_SRC_MULTIARRAY_ARRAYOBJECT_H_ */ diff --git a/numpy/_core/src/umath/dispatching.c b/numpy/_core/src/umath/dispatching.c index 9556ec0b8a9b..b85774e8e531 100644 --- a/numpy/_core/src/umath/dispatching.c +++ b/numpy/_core/src/umath/dispatching.c @@ -152,6 +152,13 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate) */ NPY_NO_EXPORT int PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec) +{ + return 
PyUFunc_AddLoopFromSpec_int(ufunc, spec, 0); +} + + +NPY_NO_EXPORT int +PyUFunc_AddLoopFromSpec_int(PyObject *ufunc, PyArrayMethod_Spec *spec, int priv) { if (!PyObject_TypeCheck(ufunc, &PyUFunc_Type)) { PyErr_SetString(PyExc_TypeError, @@ -159,7 +166,7 @@ PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec) return -1; } PyBoundArrayMethodObject *bmeth = - (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec(spec); + (PyBoundArrayMethodObject *)PyArrayMethod_FromSpec_int(spec, priv); if (bmeth == NULL) { return -1; } @@ -275,6 +282,12 @@ resolve_implementation_info(PyUFuncObject *ufunc, == PyTuple_GET_ITEM(curr_dtypes, 2)) { continue; } + /* + * This should be a reduce, but doesn't follow the reduce + * pattern. So (for now?) consider this not a match. + */ + matches = NPY_FALSE; + continue; } if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) { diff --git a/numpy/_core/src/umath/dispatching.h b/numpy/_core/src/umath/dispatching.h index 513b50d75341..3e84106da59a 100644 --- a/numpy/_core/src/umath/dispatching.h +++ b/numpy/_core/src/umath/dispatching.h @@ -20,6 +20,9 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate); NPY_NO_EXPORT int PyUFunc_AddLoopFromSpec(PyObject *ufunc, PyArrayMethod_Spec *spec); +NPY_NO_EXPORT int +PyUFunc_AddLoopFromSpec_int(PyObject *ufunc, PyArrayMethod_Spec *spec, int priv); + NPY_NO_EXPORT PyArrayMethodObject * promote_and_get_ufuncimpl(PyUFuncObject *ufunc, PyArrayObject *const ops[], diff --git a/numpy/_core/src/umath/legacy_array_method.h b/numpy/_core/src/umath/legacy_array_method.h index 498fb1aa27c2..750de06c7992 100644 --- a/numpy/_core/src/umath/legacy_array_method.h +++ b/numpy/_core/src/umath/legacy_array_method.h @@ -5,13 +5,14 @@ #include "numpy/ufuncobject.h" #include "array_method.h" +#ifdef __cplusplus +extern "C" { +#endif NPY_NO_EXPORT PyArrayMethodObject * PyArray_NewLegacyWrappingArrayMethod(PyUFuncObject *ufunc, PyArray_DTypeMeta *signature[]); - - /* * The following two 
symbols are in the header so that other places can use * them to probe for special cases (or whether an ArrayMethod is a "legacy" @@ -29,5 +30,8 @@ NPY_NO_EXPORT NPY_CASTING wrapped_legacy_resolve_descriptors(PyArrayMethodObject *, PyArray_DTypeMeta **, PyArray_Descr **, PyArray_Descr **, npy_intp *); +#ifdef __cplusplus +} +#endif #endif /*_NPY_LEGACY_ARRAY_METHOD_H */ diff --git a/numpy/_core/src/umath/scalarmath.c.src b/numpy/_core/src/umath/scalarmath.c.src index 743ecc128659..f43e1493db2d 100644 --- a/numpy/_core/src/umath/scalarmath.c.src +++ b/numpy/_core/src/umath/scalarmath.c.src @@ -1842,6 +1842,7 @@ static PyObject * * LONG, ULONG, LONGLONG, ULONGLONG, * HALF, FLOAT, DOUBLE, LONGDOUBLE, * CFLOAT, CDOUBLE, CLONGDOUBLE# + * #isint = 1*10, 0*7# * #simp = def*10, def_half, def*3, fcmplx, cmplx, lcmplx# */ #define IS_@name@ @@ -1852,6 +1853,27 @@ static PyObject* npy_@name@ arg1, arg2; int out = 0; +#if @isint@ + /* Special case comparison with python integers */ + if (PyLong_CheckExact(other)) { + PyObject *self_val = PyNumber_Index(self); + if (self_val == NULL) { + return NULL; + } + int res = PyObject_RichCompareBool(self_val, other, cmp_op); + Py_DECREF(self_val); + if (res < 0) { + return NULL; + } + else if (res) { + PyArrayScalar_RETURN_TRUE; + } + else { + PyArrayScalar_RETURN_FALSE; + } + } +#endif + /* * Extract the other value (if it is compatible). 
*/ diff --git a/numpy/_core/src/umath/special_integer_comparisons.cpp b/numpy/_core/src/umath/special_integer_comparisons.cpp new file mode 100644 index 000000000000..f706aa706815 --- /dev/null +++ b/numpy/_core/src/umath/special_integer_comparisons.cpp @@ -0,0 +1,473 @@ +#include + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define _MULTIARRAYMODULE +#define _UMATHMODULE + +#include "numpy/ndarraytypes.h" +#include "numpy/npy_math.h" +#include "numpy/ufuncobject.h" + +#include "abstractdtypes.h" +#include "dispatching.h" +#include "dtypemeta.h" +#include "common_dtype.h" +#include "convert_datatype.h" + +#include "legacy_array_method.h" /* For `get_wrapped_legacy_ufunc_loop`. */ +#include "special_integer_comparisons.h" + + +/* + * Helper for templating, avoids warnings about uncovered switch paths. + */ +enum class COMP { + EQ, NE, LT, LE, GT, GE, +}; + +static char const * +comp_name(COMP comp) { + switch(comp) { + case COMP::EQ: return "equal"; + case COMP::NE: return "not_equal"; + case COMP::LT: return "less"; + case COMP::LE: return "less_equal"; + case COMP::GT: return "greater"; + case COMP::GE: return "greater_equal"; + default: + assert(0); + return nullptr; + } +} + + +template +static int +fixed_result_loop(PyArrayMethod_Context *NPY_UNUSED(context), + char *const data[], npy_intp const dimensions[], + npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata)) +{ + npy_intp N = dimensions[0]; + char *out = data[2]; + npy_intp stride = strides[2]; + + while (N--) { + *reinterpret_cast(out) = result; + out += stride; + } + return 0; +} + +static inline void +get_min_max(int typenum, long long *min, unsigned long long *max) +{ + *min = 0; + switch (typenum) { + case NPY_BYTE: + *min = NPY_MIN_BYTE; + *max = NPY_MAX_BYTE; + break; + case NPY_UBYTE: + *max = NPY_MAX_UBYTE; + break; + case NPY_SHORT: + *min = NPY_MIN_SHORT; + *max = NPY_MAX_SHORT; + break; + case NPY_USHORT: + *max = NPY_MAX_USHORT; + break; + case NPY_INT: + *min = NPY_MIN_INT; + *max 
= NPY_MAX_INT; + break; + case NPY_UINT: + *max = NPY_MAX_UINT; + break; + case NPY_LONG: + *min = NPY_MIN_LONG; + *max = NPY_MAX_LONG; + break; + case NPY_ULONG: + *max = NPY_MAX_ULONG; + break; + case NPY_LONGLONG: + *min = NPY_MIN_LONGLONG; + *max = NPY_MAX_LONGLONG; + break; + case NPY_ULONGLONG: + *max = NPY_MAX_ULONGLONG; + break; + default: + assert(0); + } +} + + +/* + * Determine if a Python long is within the typenums range, smaller, or larger. + * + * Function returns -1 for errors. + */ +static inline int +get_value_range(PyObject *value, int type_num, int *range) +{ + long long min; + unsigned long long max; + get_min_max(type_num, &min, &max); + + int overflow; + long long val = PyLong_AsLongLongAndOverflow(value, &overflow); + if (val == -1 && overflow == 0 && PyErr_Occurred()) { + return -1; + } + + if (overflow == 0) { + if (val < min) { + *range = -1; + } + else if (val > 0 && (unsigned long long)val > max) { + *range = 1; + } + else { + *range = 0; + } + } + else if (overflow < 0) { + *range = -1; + } + else if (max <= NPY_MAX_LONGLONG) { + *range = 1; + } + else { + /* + * If we are checking for unsigned long long, the value may be larger + * than long long, but within range of unsigned long long. Check this + * by doing the normal Python integer comparison. + */ + PyObject *obj = PyLong_FromUnsignedLongLong(max); + if (obj == NULL) { + return -1; + } + int cmp = PyObject_RichCompareBool(value, obj, Py_GT); + Py_DECREF(obj); + if (cmp < 0) { + return -1; + } + if (cmp) { + *range = 1; + } + else { + *range = 0; + } + } + return 0; +} + + +/* + * Find the type resolution for any numpy_int with pyint comparison. This + * function supports *both* directions for all types. 
+ */ +static NPY_CASTING +resolve_descriptors_with_scalars( + PyArrayMethodObject *self, PyArray_DTypeMeta **dtypes, + PyArray_Descr **given_descrs, PyObject *const *input_scalars, + PyArray_Descr **loop_descrs, npy_intp *view_offset) +{ + int value_range = 0; + + npy_bool first_is_pyint = dtypes[0] == &PyArray_PyIntAbstractDType; + int arr_idx = first_is_pyint? 1 : 0; + int scalar_idx = first_is_pyint? 0 : 1; + PyObject *scalar = input_scalars[scalar_idx]; + assert(PyTypeNum_ISINTEGER(dtypes[arr_idx]->type_num)); + PyArray_DTypeMeta *arr_dtype = dtypes[arr_idx]; + + /* + * Three way decision (with hack) on value range: + * 0: The value fits within the range of the dtype. + * 1: The value came second and is larger or came first and is smaller. + * -1: The value came second and is smaller or came first and is larger + */ + if (scalar != NULL && PyLong_CheckExact(scalar)) { + if (get_value_range(scalar, arr_dtype->type_num, &value_range) < 0) { + return _NPY_ERROR_OCCURRED_IN_CAST; + } + if (first_is_pyint == 1) { + value_range *= -1; + } + } + + /* + * Very small/large values always need to be encoded as `object` dtype + * in order to never fail casting (NumPy will store the Python integer + * in a 0-D object array this way -- even if we never inspect it). + * + * TRICK: We encode the value range by whether or not we use the object + * singleton! This information is then available in `get_loop()` + * to pick a loop that returns always True or False. 
+ */ + if (value_range == 0) { + Py_INCREF(arr_dtype->singleton); + loop_descrs[scalar_idx] = arr_dtype->singleton; + } + else if (value_range < 0) { + loop_descrs[scalar_idx] = PyArray_DescrFromType(NPY_OBJECT); + } + else { + loop_descrs[scalar_idx] = PyArray_DescrNewFromType(NPY_OBJECT); + if (loop_descrs[scalar_idx] == NULL) { + return _NPY_ERROR_OCCURRED_IN_CAST; + } + } + Py_INCREF(arr_dtype->singleton); + loop_descrs[arr_idx] = arr_dtype->singleton; + loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL); + + return NPY_NO_CASTING; +} + + +template +static int +get_loop(PyArrayMethod_Context *context, + int aligned, int move_references, const npy_intp *strides, + PyArrayMethod_StridedLoop **out_loop, NpyAuxData **out_transferdata, + NPY_ARRAYMETHOD_FLAGS *flags) +{ + if (context->descriptors[1]->type_num == context->descriptors[0]->type_num) { + /* + * Fall back to the current implementation, which wraps legacy loops. + */ + return get_wrapped_legacy_ufunc_loop( + context, aligned, move_references, strides, + out_loop, out_transferdata, flags); + } + else { + PyArray_Descr *other_descr; + if (context->descriptors[1]->type_num == NPY_OBJECT) { + other_descr = context->descriptors[1]; + } + else { + assert(context->descriptors[0]->type_num == NPY_OBJECT); + other_descr = context->descriptors[0]; + } + /* HACK: If the descr is the singleton the result is smaller */ + PyArray_Descr *obj_singleton = PyArray_DescrFromType(NPY_OBJECT); + if (other_descr == obj_singleton) { + /* Singleton came second and is smaller, or first and is larger */ + switch (comp) { + case COMP::EQ: + case COMP::LT: + case COMP::LE: + *out_loop = &fixed_result_loop; + break; + case COMP::NE: + case COMP::GT: + case COMP::GE: + *out_loop = &fixed_result_loop; + break; + } + } + else { + /* Singleton came second and is larger, or first and is smaller */ + switch (comp) { + case COMP::EQ: + case COMP::GT: + case COMP::GE: + *out_loop = &fixed_result_loop; + break; + case COMP::NE: + case COMP::LT: 
+ case COMP::LE: + *out_loop = &fixed_result_loop; + break; + } + } + Py_DECREF(obj_singleton); + } + *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + return 0; +} + + +/* + * Machinery to add the python integer to NumPy integer comparisons as well + * as a special promotion to special case Python int with Python int + * comparisons. + */ + +/* + * Simple promoter that ensures we use the object loop when the input + * is python integers only. + * Note that if a user would pass the Python `int` abstract DType explicitly + * they promise to actually pass a Python int and we accept that we never + * check for that. + */ +static int +pyint_comparison_promoter(PyUFuncObject *NPY_UNUSED(ufunc), + PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[], + PyArray_DTypeMeta *new_op_dtypes[]) +{ + new_op_dtypes[0] = PyArray_DTypeFromTypeNum(NPY_OBJECT); + new_op_dtypes[1] = PyArray_DTypeFromTypeNum(NPY_OBJECT); + new_op_dtypes[2] = PyArray_DTypeFromTypeNum(NPY_BOOL); + return 0; +} + + +/* + * This function replaces the strided loop with the passed in one, + * and registers it with the given ufunc. + * It additionally adds a promoter for (pyint, pyint, bool) to use the + * (object, object, bool) implementation. + */ +template +static int +add_dtype_loops(PyObject *umath, PyArrayMethod_Spec *spec, PyObject *info) +{ + PyArray_DTypeMeta *PyInt = &PyArray_PyIntAbstractDType; + + PyObject *name = PyUnicode_FromString(comp_name(comp)); + if (name == nullptr) { + return -1; + } + PyUFuncObject *ufunc = (PyUFuncObject *)PyObject_GetItem(umath, name); + Py_DECREF(name); + if (ufunc == nullptr) { + return -1; + } + if (Py_TYPE(ufunc) != &PyUFunc_Type) { + PyErr_SetString(PyExc_RuntimeError, + "internal NumPy error: comparison not a ufunc"); + goto fail; + } + + /* + * NOTE: Iterates all type numbers, it would be nice to reduce this. + * (that would be easier if we consolidate int DTypes in general.) 
+ */ + for (int typenum = NPY_BYTE; typenum <= NPY_ULONGLONG; typenum++) { + spec->slots[0].pfunc = (void *)get_loop; + + PyArray_DTypeMeta *Int = PyArray_DTypeFromTypeNum(typenum); + + /* Register the spec/loop for both forward and backward direction */ + spec->dtypes[0] = Int; + spec->dtypes[1] = PyInt; + int res = PyUFunc_AddLoopFromSpec_int((PyObject *)ufunc, spec, 1); + if (res < 0) { + Py_DECREF(Int); + goto fail; + } + spec->dtypes[0] = PyInt; + spec->dtypes[1] = Int; + res = PyUFunc_AddLoopFromSpec_int((PyObject *)ufunc, spec, 1); + Py_DECREF(Int); + if (res < 0) { + goto fail; + } + } + + /* + * Install the promoter info to allow two Python integers to work. + */ + return PyUFunc_AddLoop((PyUFuncObject *)ufunc, info, 0); + + Py_DECREF(ufunc); + return 0; + + fail: + Py_DECREF(ufunc); + return -1; +} + + +template +struct add_loops; + +template<> +struct add_loops<> { + int operator()(PyObject*, PyArrayMethod_Spec*, PyObject *) { + return 0; + } +}; + + +template +struct add_loops { + int operator()(PyObject* umath, PyArrayMethod_Spec* spec, PyObject *info) { + if (add_dtype_loops(umath, spec, info) < 0) { + return -1; + } + else { + return add_loops()(umath, spec, info); + } + } +}; + + +NPY_NO_EXPORT int +init_special_int_comparisons(PyObject *umath) +{ + int res = -1; + PyObject *info = NULL, *promoter = NULL; + PyArray_DTypeMeta *Bool = PyArray_DTypeFromTypeNum(NPY_BOOL); + + /* All loops have a boolean out DType (others filled in later) */ + PyArray_DTypeMeta *dtypes[] = {NULL, NULL, Bool}; + /* + * We only have one loop right now, the strided one. The default type + * resolver ensures native byte order/canonical representation. 
+ */ + PyType_Slot slots[] = { + {_NPY_METH_get_loop, nullptr}, + {_NPY_METH_resolve_descriptors_with_scalars, + (void *)&resolve_descriptors_with_scalars}, + {0, NULL}, + }; + + PyArrayMethod_Spec spec = {}; + spec.name = "templated_pyint_to_integers_comparisons"; + spec.nin = 2; + spec.nout = 1; + spec.dtypes = dtypes; + spec.slots = slots; + spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + + /* + * The following sets up the correct promoter to make comparisons like + * `np.equal(2, 4)` (with two python integers) use an object loop. + */ + PyObject *dtype_tuple = PyTuple_Pack(3, + &PyArray_PyIntAbstractDType, &PyArray_PyIntAbstractDType, Bool); + if (dtype_tuple == NULL) { + goto finish; + } + promoter = PyCapsule_New( + (void *)&pyint_comparison_promoter, "numpy._ufunc_promoter", NULL); + if (promoter == NULL) { + Py_DECREF(dtype_tuple); + goto finish; + } + info = PyTuple_Pack(2, dtype_tuple, promoter); + Py_DECREF(dtype_tuple); + Py_DECREF(promoter); + if (info == NULL) { + goto finish; + } + + /* Add all combinations of PyInt and NumPy integer comparisons */ + using comp_looper = add_loops; + if (comp_looper()(umath, &spec, info) < 0) { + goto finish; + } + + res = 0; + finish: + + Py_XDECREF(info); + Py_DECREF(Bool); + return res; +} diff --git a/numpy/_core/src/umath/special_integer_comparisons.h b/numpy/_core/src/umath/special_integer_comparisons.h new file mode 100644 index 000000000000..2312bcae1e65 --- /dev/null +++ b/numpy/_core/src/umath/special_integer_comparisons.h @@ -0,0 +1,15 @@ +#ifndef _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ +#define _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT int +init_special_int_comparisons(PyObject *umath); + +#ifdef __cplusplus +} +#endif + +#endif /* _NPY_CORE_SRC_UMATH_SPECIAL_COMPARISONS_H_ */ diff --git a/numpy/_core/src/umath/ufunc_object.c b/numpy/_core/src/umath/ufunc_object.c index 62e7f23ec918..eacd8bc9969e 100644 --- 
a/numpy/_core/src/umath/ufunc_object.c +++ b/numpy/_core/src/umath/ufunc_object.c @@ -118,7 +118,8 @@ static int resolve_descriptors(int nop, PyUFuncObject *ufunc, PyArrayMethodObject *ufuncimpl, PyArrayObject *operands[], PyArray_Descr *dtypes[], - PyArray_DTypeMeta *signature[], NPY_CASTING casting); + PyArray_DTypeMeta *signature[], PyObject *inputs_tup, + NPY_CASTING casting); /*UFUNC_API*/ @@ -2804,7 +2805,7 @@ reducelike_promote_and_resolve(PyUFuncObject *ufunc, * (although this should possibly happen through a deprecation) */ if (resolve_descriptors(3, ufunc, ufuncimpl, - ops, out_descrs, signature, casting) < 0) { + ops, out_descrs, signature, NULL, casting) < 0) { return NULL; } @@ -4478,11 +4479,58 @@ static int resolve_descriptors(int nop, PyUFuncObject *ufunc, PyArrayMethodObject *ufuncimpl, PyArrayObject *operands[], PyArray_Descr *dtypes[], - PyArray_DTypeMeta *signature[], NPY_CASTING casting) + PyArray_DTypeMeta *signature[], PyObject *inputs_tup, + NPY_CASTING casting) { int retval = -1; + NPY_CASTING safety; PyArray_Descr *original_dtypes[NPY_MAXARGS]; + NPY_UF_DBG_PRINT("Resolving the descriptors\n"); + + if (NPY_UNLIKELY(ufuncimpl->resolve_descriptors_with_scalars != NULL)) { + /* + * Allow a somewhat more powerful approach which: + * 1. Has access to scalars (currently only ever Python ones) + * 2. Can in principle customize `PyArray_CastDescrToDType()` + * (also because we want to avoid calling it for the scalars). + */ + int nin = ufunc->nin; + PyObject *input_scalars[NPY_MAXARGS]; + for (int i = 0; i < nop; i++) { + if (operands[i] == NULL) { + original_dtypes[i] = NULL; + } + else { + /* For abstract DTypes, we might want to change what this is */ + original_dtypes[i] = PyArray_DTYPE(operands[i]); + Py_INCREF(original_dtypes[i]); + } + if (i < nin + && NPY_DT_is_abstract(signature[i]) + && inputs_tup != NULL) { + /* + * TODO: We may wish to allow any scalar here. 
Checking for + * abstract assumes this works out for Python scalars, + * which is the important case (especially for now). + * + * One possible check would be `DType->type == type(obj)`. + */ + input_scalars[i] = PyTuple_GET_ITEM(inputs_tup, i); + } + else { + input_scalars[i] = NULL; + } + } + + npy_intp view_offset = NPY_MIN_INTP; /* currently ignored */ + safety = ufuncimpl->resolve_descriptors_with_scalars( + ufuncimpl, signature, original_dtypes, input_scalars, + dtypes, &view_offset + ); + goto check_safety; + } + for (int i = 0; i < nop; ++i) { if (operands[i] == NULL) { original_dtypes[i] = NULL; @@ -4501,26 +4549,13 @@ resolve_descriptors(int nop, } } - NPY_UF_DBG_PRINT("Resolving the descriptors\n"); - if (ufuncimpl->resolve_descriptors != &wrapped_legacy_resolve_descriptors) { /* The default: use the `ufuncimpl` as nature intended it */ npy_intp view_offset = NPY_MIN_INTP; /* currently ignored */ - NPY_CASTING safety = ufuncimpl->resolve_descriptors(ufuncimpl, + safety = ufuncimpl->resolve_descriptors(ufuncimpl, signature, original_dtypes, dtypes, &view_offset); - if (safety < 0) { - goto finish; - } - if (NPY_UNLIKELY(PyArray_MinCastSafety(safety, casting) != casting)) { - /* TODO: Currently impossible to reach (specialized unsafe loop) */ - PyErr_Format(PyExc_TypeError, - "The ufunc implementation for %s with the given dtype " - "signature is not possible under the casting rule %s", - ufunc_get_name_cstr(ufunc), npy_casting_to_string(casting)); - goto finish; - } - retval = 0; + goto check_safety; } else { /* @@ -4528,7 +4563,22 @@ resolve_descriptors(int nop, * for datetime64/timedelta64 and custom ufuncs (in pyerfa/astropy). 
*/ retval = ufunc->type_resolver(ufunc, casting, operands, NULL, dtypes); + goto finish; + } + + check_safety: + if (safety < 0) { + goto finish; + } + if (NPY_UNLIKELY(PyArray_MinCastSafety(safety, casting) != casting)) { + /* TODO: Currently impossible to reach (specialized unsafe loop) */ + PyErr_Format(PyExc_TypeError, + "The ufunc implementation for %s with the given dtype " + "signature is not possible under the casting rule %s", + ufunc_get_name_cstr(ufunc), npy_casting_to_string(casting)); + goto finish; } + retval = 0; finish: for (int i = 0; i < nop; i++) { @@ -4857,7 +4907,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc, /* Find the correct descriptors for the operation */ if (resolve_descriptors(nop, ufunc, ufuncimpl, - operands, operation_descrs, signature, casting) < 0) { + operands, operation_descrs, signature, full_args.in, casting) < 0) { goto fail; } @@ -6229,7 +6279,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) /* Find the correct operation_descrs for the operation */ int resolve_result = resolve_descriptors(nop, ufunc, ufuncimpl, - tmp_operands, operation_descrs, signature, NPY_UNSAFE_CASTING); + tmp_operands, operation_descrs, signature, NULL, NPY_UNSAFE_CASTING); for (int i = 0; i < 3; i++) { Py_XDECREF(signature[i]); Py_XDECREF(operand_DTypes[i]); @@ -6558,7 +6608,8 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context, /* Find the correct descriptors for the operation */ if (resolve_descriptors(ufunc->nargs, ufunc, ufuncimpl, - dummy_arrays, operation_descrs, signature, casting) < 0) { + dummy_arrays, operation_descrs, signature, + NULL, casting) < 0) { goto finish; } diff --git a/numpy/_core/src/umath/umathmodule.c b/numpy/_core/src/umath/umathmodule.c index 07a9159b0dcc..1be81a7f1a24 100644 --- a/numpy/_core/src/umath/umathmodule.c +++ b/numpy/_core/src/umath/umathmodule.c @@ -27,6 +27,7 @@ #include "number.h" #include "dispatching.h" #include "string_ufuncs.h" +#include "special_integer_comparisons.h" #include 
"extobj.h" /* for _extobject_contextvar exposure */ /* Automatically generated code to define all ufuncs: */ @@ -334,5 +335,9 @@ int initumath(PyObject *m) return -1; } + if (init_special_int_comparisons(d) < 0) { + return -1; + } + return 0; } diff --git a/numpy/_core/tests/test_half.py b/numpy/_core/tests/test_half.py index 89bed2215357..954ba5987689 100644 --- a/numpy/_core/tests/test_half.py +++ b/numpy/_core/tests/test_half.py @@ -266,8 +266,8 @@ def test_half_correctness(self): if len(a32_fail) != 0: bad_index = a32_fail[0] assert_equal(self.finite_f32, a_manual, - "First non-equal is half value %x -> %g != %g" % - (self.finite_f16[bad_index], + "First non-equal is half value 0x%x -> %g != %g" % + (a_bits[bad_index], self.finite_f32[bad_index], a_manual[bad_index])) @@ -275,8 +275,8 @@ def test_half_correctness(self): if len(a64_fail) != 0: bad_index = a64_fail[0] assert_equal(self.finite_f64, a_manual, - "First non-equal is half value %x -> %g != %g" % - (self.finite_f16[bad_index], + "First non-equal is half value 0x%x -> %g != %g" % + (a_bits[bad_index], self.finite_f64[bad_index], a_manual[bad_index])) diff --git a/numpy/_core/tests/test_nep50_promotions.py b/numpy/_core/tests/test_nep50_promotions.py index 5e2068762eeb..ca2874f84156 100644 --- a/numpy/_core/tests/test_nep50_promotions.py +++ b/numpy/_core/tests/test_nep50_promotions.py @@ -12,7 +12,7 @@ import hypothesis from hypothesis import strategies -from numpy.testing import IS_WASM +from numpy.testing import assert_array_equal, IS_WASM @pytest.fixture(scope="module", autouse=True) @@ -132,7 +132,7 @@ def test_nep50_weak_integers_with_inexact(dtype): assert res == np.inf -@pytest.mark.parametrize("op", [operator.add, operator.pow, operator.eq]) +@pytest.mark.parametrize("op", [operator.add, operator.pow]) def test_weak_promotion_scalar_path(op): # Some additional paths exercising the weak scalars. 
np._set_promotion_state("weak") @@ -271,3 +271,41 @@ def test_expected_promotion(expected, dtypes, optional_dtypes, data): res = np.result_type(*dtypes_sample) assert res == expected + + +@pytest.mark.parametrize("sctype", + [np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64]) +@pytest.mark.parametrize("other_val", + [-2*100, -1, 0, 9, 10, 11, 2**63, 2*100]) +@pytest.mark.parametrize("comp", + [operator.eq, operator.ne, operator.le, operator.lt, + operator.ge, operator.gt]) +def test_integer_comparison(sctype, other_val, comp): + np._set_promotion_state("weak") + + # Test that comparisons with integers (especially out-of-bound ones) + # work correctly. + val_obj = 10 + val = sctype(val_obj) + # Check that the scalar behaves the same as the python int: + assert comp(10, other_val) == comp(val, other_val) + assert comp(val, other_val) == comp(10, other_val) + # Except for the result type: + assert type(comp(val, other_val)) is np.bool_ + + # Check that the integer array and object array behave the same: + val_obj = np.array([10, 10], dtype=object) + val = val_obj.astype(sctype) + assert_array_equal(comp(val_obj, other_val), comp(val, other_val)) + assert_array_equal(comp(other_val, val_obj), comp(other_val, val)) + + +@pytest.mark.parametrize("comp", + [np.equal, np.not_equal, np.less_equal, np.less, + np.greater_equal, np.greater]) +def test_integer_integer_comparison(comp): + np._set_promotion_state("weak") + + # Test that the NumPy comparison ufuncs work with large Python integers + assert comp(2**200, -2**200) == comp(2**200, -2**200, dtype=object)