-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: Fast paths for richcompare #17970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bf4d137
85f5414
6a61d98
ce141fb
1fb7a6c
dcd1924
9bfd81f
f5a46b1
f908bb2
708dd75
2e016b7
4bdfda8
3b6ace3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,8 @@ | |
#include "numpy/ufuncobject.h" | ||
#include "numpy/arrayscalars.h" | ||
|
||
#include "scalartypes.h" | ||
|
||
#include "npy_import.h" | ||
#include "npy_pycompat.h" | ||
|
||
|
@@ -569,6 +571,117 @@ static void | |
|
||
/*** END OF BASIC CODE **/ | ||
|
||
/** | ||
* Find the descriptor for a builtin or Python numerical scalar. | ||
* This function should only be used for scalar math fast paths. | ||
* | ||
* @param obj The object for which to find the descriptor | ||
* @param descr The descriptor that was found. | ||
* @returns 1 if the other object is a Python type and thus must not | ||
* handle the operation, otherwise 0. -1 _without_ an error set | ||
* if no descriptor was found. | ||
*/ | ||
static int | ||
descr_from_basic_scalar(PyObject *obj, PyArray_Descr **descr) | ||
{ | ||
/* TODO: We could try giving defined scalars a chance... */ | ||
int type_num = PyArray_TypeNumFromNumPyScalarType( | ||
(PyObject *)Py_TYPE(obj), 0); | ||
if (type_num != NPY_NOTYPE) { | ||
*descr = PyArray_DescrFromType(type_num); | ||
return 0; | ||
} | ||
else if (PyFloat_CheckExact(obj)) { | ||
ganesh-k13 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
*descr = PyArray_DescrFromType(NPY_DOUBLE); | ||
return 1; | ||
} | ||
else if (PyBool_Check(obj)) { | ||
*descr = PyArray_DescrFromType(NPY_BOOL); | ||
return 1; | ||
} | ||
else if (PyLong_CheckExact(obj)) { | ||
if (PyLong_AsLong(obj) == -1 && PyErr_Occurred()) { | ||
PyErr_Clear(); | ||
return -1; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess codecov is right, it would be nice to test this path by passing in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @seberg , any idea how we can hit this reliably? We don't reach case 2 in case of raw int, at the same time, we cannot overflow a Numpy scalar while creating it. |
||
} | ||
*descr = PyArray_DescrFromType(NPY_LONG); | ||
return 1; | ||
} | ||
else if (PyComplex_CheckExact(obj)) { | ||
*descr = PyArray_DescrFromType(NPY_CDOUBLE); | ||
return 1; | ||
} | ||
return -1; | ||
} | ||
|
||
/*
 * Evaluates to true when the descriptor's type number is below
 * NPY_LONGDOUBLE, or is NPY_CFLOAT or NPY_CDOUBLE.
 * The argument and the whole expansion are parenthesized so the macro can be
 * used with arbitrary pointer expressions and inside larger expressions.
 * Note that `d` is evaluated more than once; do not pass an expression with
 * side effects.
 */
#define PyDescr_PythonRepresentable(d) \
    ((d)->type_num < NPY_LONGDOUBLE || \
     (d)->type_num == NPY_CFLOAT || (d)->type_num == NPY_CDOUBLE)
|
||
/** | ||
* This function attempts to compare NumPy or Python scalars | ||
* The operation is done by getting value of both scalars | ||
* and calling richcompare on them. | ||
ganesh-k13 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* @param self The first object which will be converted to scalar item | ||
* @param other The second object to compare to | ||
* @param cmp_op The value of comparison operator | ||
* @returns Py_False or Py_True if richcompare is successfull | ||
* , otherwise NULL. | ||
*/ | ||
static PyObject* | ||
do_richcompare_on_scalars(PyObject *self, PyObject *other, int cmp_op) { | ||
PyObject *cmp_item_self, *cmp_item_other, *ret=NULL; | ||
PyArray_Descr *self_descr=NULL, *other_descr=NULL; | ||
void *data_self, *data_other; | ||
int pyscalar_other, pyscalar_self; | ||
int is_complex_operands, is_equality_operator, is_python_representable; | ||
|
||
pyscalar_self = descr_from_basic_scalar(self, &self_descr); | ||
pyscalar_other = descr_from_basic_scalar(other, &other_descr); | ||
|
||
if (pyscalar_self >= 0 && pyscalar_other >= 0) { | ||
/* | ||
* If either of the operands are complex and operator is not equality, | ||
* or the operands are not representable as a native type, | ||
* python's built-in richcompare cannot be used as it is not supported. | ||
*/ | ||
is_complex_operands = (PyTypeNum_ISCOMPLEX(self_descr->type_num) || | ||
PyTypeNum_ISCOMPLEX(other_descr->type_num)); | ||
is_equality_operator = (cmp_op != Py_EQ && cmp_op != Py_NE); | ||
is_python_representable = PyDescr_PythonRepresentable(self_descr) && | ||
PyDescr_PythonRepresentable(other_descr); | ||
|
||
if (!(is_complex_operands && is_equality_operator) && is_python_representable) { | ||
/* | ||
* If the scalar is a python built-in, we can use the object as is. | ||
* Else we need to obtain the value from the operand. | ||
* Note: If we reach this point, one of the scalars must be built-in. | ||
*/ | ||
data_self = scalar_value(self, NULL); | ||
data_other = scalar_value(other, NULL); | ||
cmp_item_self = pyscalar_self == 0 ? self_descr->f->getitem(data_self, NULL): | ||
self; | ||
cmp_item_other = pyscalar_other == 0 ? other_descr->f->getitem(data_other, NULL): | ||
other; | ||
|
||
if (cmp_item_self != NULL && cmp_item_other != NULL) { | ||
ret = PyObject_RichCompare(cmp_item_self, cmp_item_other, cmp_op); | ||
} | ||
|
||
if (pyscalar_self == 0) { | ||
Py_XDECREF(cmp_item_self); | ||
} | ||
if (pyscalar_other == 0) { | ||
Py_XDECREF(cmp_item_other); | ||
} | ||
} | ||
} | ||
Py_XDECREF(self_descr); | ||
Py_XDECREF(other_descr); | ||
|
||
return ret; | ||
} | ||
|
||
|
||
/* The general strategy for commutative binary operators is to | ||
* | ||
|
@@ -587,86 +700,38 @@ static void | |
/**begin repeat | ||
* #name = byte, ubyte, short, ushort, int, uint, | ||
* long, ulong, longlong, ulonglong, | ||
* half, float, longdouble, | ||
* half, float, double, longdouble, | ||
* cfloat, cdouble, clongdouble# | ||
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint, | ||
* npy_long, npy_ulong, npy_longlong, npy_ulonglong, | ||
* npy_half, npy_float, npy_longdouble, | ||
* npy_half, npy_float, npy_double, npy_longdouble, | ||
* npy_cfloat, npy_cdouble, npy_clongdouble# | ||
* #Name = Byte, UByte, Short, UShort, Int, UInt, | ||
* Long, ULong, LongLong, ULongLong, | ||
* Half, Float, LongDouble, | ||
* Half, Float, Double, LongDouble, | ||
* CFloat, CDouble, CLongDouble# | ||
* #TYPE = NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT, NPY_INT, NPY_UINT, | ||
* NPY_LONG, NPY_ULONG, NPY_LONGLONG, NPY_ULONGLONG, | ||
* NPY_HALF, NPY_FLOAT, NPY_LONGDOUBLE, | ||
* NPY_HALF, NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, | ||
* NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE# | ||
*/ | ||
|
||
static int | ||
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1) | ||
{ | ||
PyObject *temp; | ||
|
||
if (PyArray_IsScalar(a, @Name@)) { | ||
*arg1 = PyArrayScalar_VAL(a, @Name@); | ||
return 0; | ||
} | ||
else if (PyArray_IsScalar(a, Generic)) { | ||
PyArray_Descr *descr1; | ||
|
||
if (!PyArray_IsScalar(a, Number)) { | ||
return -1; | ||
} | ||
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a)); | ||
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) { | ||
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@); | ||
Py_DECREF(descr1); | ||
return 0; | ||
} | ||
else { | ||
Py_DECREF(descr1); | ||
return -1; | ||
} | ||
} | ||
else if (PyArray_GetPriority(a, NPY_PRIORITY) > NPY_PRIORITY) { | ||
return -2; | ||
} | ||
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) { | ||
int retval = _@name@_convert_to_ctype(temp, arg1); | ||
|
||
Py_DECREF(temp); | ||
return retval; | ||
} | ||
return -2; | ||
} | ||
|
||
/**end repeat**/ | ||
|
||
|
||
/* Same as above but added exact checks against known python types for speed */ | ||
|
||
/**begin repeat | ||
* #name = double# | ||
* #type = npy_double# | ||
* #Name = Double# | ||
* #TYPE = NPY_DOUBLE# | ||
* #PYCHECKEXACT = PyFloat_CheckExact# | ||
* #PYEXTRACTCTYPE = PyFloat_AS_DOUBLE# | ||
*/ | ||
#define _IS_@Name@ | ||
|
||
static int | ||
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1) | ||
_@name@_convert_to_ctype(PyObject *a, @type@ *arg) | ||
{ | ||
PyObject *temp; | ||
|
||
if (@PYCHECKEXACT@(a)){ | ||
*arg1 = @PYEXTRACTCTYPE@(a); | ||
#if defined(_IS_Double) || defined(_IS_LongDouble) | ||
if (PyFloat_CheckExact(a)) { | ||
*arg = (@type@)PyFloat_AS_DOUBLE(a); | ||
return 0; | ||
} | ||
#endif | ||
|
||
if (PyArray_IsScalar(a, @Name@)) { | ||
*arg1 = PyArrayScalar_VAL(a, @Name@); | ||
*arg = PyArrayScalar_VAL(a, @Name@); | ||
return 0; | ||
} | ||
else if (PyArray_IsScalar(a, Generic)) { | ||
|
@@ -677,7 +742,7 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1) | |
} | ||
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a)); | ||
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) { | ||
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@); | ||
PyArray_CastScalarDirect(a, descr1, arg, @TYPE@); | ||
Py_DECREF(descr1); | ||
return 0; | ||
} | ||
|
@@ -690,14 +755,16 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1) | |
return -2; | ||
} | ||
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) { | ||
int retval = _@name@_convert_to_ctype(temp, arg1); | ||
int retval = _@name@_convert_to_ctype(temp, arg); | ||
|
||
Py_DECREF(temp); | ||
return retval; | ||
} | ||
return -2; | ||
} | ||
|
||
#undef _IS_@Name@ | ||
|
||
/**end repeat**/ | ||
|
||
|
||
|
@@ -1308,14 +1375,20 @@ static PyObject* | |
{ | ||
npy_@name@ arg1, arg2; | ||
int out=0; | ||
PyObject *ret; | ||
|
||
RICHCMP_GIVE_UP_IF_NEEDED(self, other); | ||
|
||
switch(_@name@_convert2_to_ctypes(self, &arg1, other, &arg2)) { | ||
case 0: | ||
break; | ||
case -1: | ||
/* can't cast both safely use different add function */ | ||
/* can't cast both safely to same type. | ||
* Try fastpath else use ufuncs */ | ||
ret = do_richcompare_on_scalars(self, other, cmp_op); | ||
if (ret != NULL) { | ||
return ret; | ||
} | ||
case -2: | ||
/* use ufunc */ | ||
if (PyErr_Occurred()) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the right place to add it? Is there any logical grouping of the headers here? I noticed there are groups
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly, the header includes are probably just a mess all over the place, and I am not sure there is much logic... I would probably group them by folder (which may mean putting it by itself here). I doubt the rest follows any such logic though. I am willing to bet that most of the order is simply the historical order in which it became necessary to add them...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see. Yeah, I'll make it separate.