8000 ENH: Fast paths for richcompare by ganesh-k13 · Pull Request #17970 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Fast paths for richcompare #17970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions benchmarks/benchmarks/bench_scalar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from random import randint
from .common import Benchmark, TYPES1

import numpy as np
Expand Down Expand Up @@ -31,3 +32,12 @@ def time_abs(self, typename):
n = self.num
res = abs(abs(abs(abs(abs(abs(abs(abs(abs(abs(n))))))))))

def time_compare(self, typename):
    """Time ten equality comparisons of the scalar with random Python ints."""
    scalar = self.num
    outcomes = []
    for _ in range(10):
        # A fresh random int in [-128, 127] is drawn for every comparison.
        outcomes.append(scalar == randint(-128, 127))
    res = outcomes

def time_compare_types(self, typename):
    """Time equality comparisons of the scalar with a scalar of each TYPES1 type."""
    lhs = self.num
    for type_name in TYPES1:
        # Build a scalar of the current type from a random int in [-128, 127].
        scalar_type = np.dtype(type_name).type
        rhs = scalar_type(randint(-128, 127))
        res = []
        for _ in range(10):
            res.append(lhs == rhs)
4 changes: 2 additions & 2 deletions numpy/core/src/multiarray/convert_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,8 @@ PyArray_CanCastScalar(PyTypeObject *from, PyTypeObject *to)
int fromtype;
int totype;

fromtype = _typenum_fromtypeobj((PyObject *)from, 0);
totype = _typenum_fromtypeobj((PyObject *)to, 0);
fromtype = PyArray_TypeNumFromNumPyScalarType((PyObject *)from, 0);
totype = PyArray_TypeNumFromNumPyScalarType((PyObject *)to, 0);
if (fromtype == NPY_NOTYPE || totype == NPY_NOTYPE) {
return NPY_FALSE;
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalarapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ NPY_NO_EXPORT PyArray_Descr *
PyArray_DescrFromTypeObject(PyObject *type)
{
/* if it's a builtin type, then use the typenumber */
int typenum = _typenum_fromtypeobj(type,1);
int typenum = PyArray_TypeNumFromNumPyScalarType(type,1);
if (typenum != NPY_NOTYPE) {
return PyArray_DescrFromType(typenum);
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -4144,7 +4144,7 @@ is_anyscalar_exact(PyObject *obj)
}

NPY_NO_EXPORT int
_typenum_fromtypeobj(PyObject *type, int user)
PyArray_TypeNumFromNumPyScalarType(PyObject *type, int user)
{
int typenum, i;

Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ NPY_NO_EXPORT int
is_anyscalar_exact(PyObject *obj);

NPY_NO_EXPORT int
_typenum_fromtypeobj(PyObject *type, int user);
PyArray_TypeNumFromNumPyScalarType(PyObject *type, int user);

NPY_NO_EXPORT void *
scalar_value(PyObject *scalar, PyArray_Descr *descr);
Expand Down
197 changes: 135 additions & 62 deletions numpy/core/src/umath/scalarmath.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "numpy/ufuncobject.h"
#include "numpy/arrayscalars.h"

#include "scalartypes.h"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right place to add it? Is there any logical grouping of the headers here? I noticed there are groups.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, the header includes are probably just a mess all over the place, and I am not sure there is much logic to them... I would probably group them by folder (which may mean putting this one by itself here). I doubt the rest follows any such logic though. I am willing to bet that most of the order is simply the historical order in which it became necessary to add them...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. Yeah, I'll make it separate.


#include "npy_import.h"
#include "npy_pycompat.h"

Expand Down Expand Up @@ -569,6 +571,117 @@ static void

/*** END OF BASIC CODE **/

/**
* Find the descriptor for a builtin or Python numerical scalar.
* This function should only be used for scalar math fast paths.
*
* @param obj The object for which to find the descriptor
* @param descr The descriptor that was found.
* @returns 0 if obj is a NumPy scalar, 1 if obj is a Python type and thus
* must not handle the operation, or -1 _without_ an error set
* if no descriptor was found.
*/
static int
descr_from_basic_scalar(PyObject *obj, PyArray_Descr **descr)
{
/* TODO: We could try giving defined scalars a chance... */
int type_num = PyArray_TypeNumFromNumPyScalarType(
(PyObject *)Py_TYPE(obj), 0);
if (type_num != NPY_NOTYPE) {
*descr = PyArray_DescrFromType(type_num);
return 0;
}
else if (PyFloat_CheckExact(obj)) {
*descr = PyArray_DescrFromType(NPY_DOUBLE);
return 1;
}
else if (PyBool_Check(obj)) {
*descr = PyArray_DescrFromType(NPY_BOOL);
return 1;
}
else if (PyLong_CheckExact(obj)) {
if (PyLong_AsLong(obj) == -1 && PyErr_Occurred()) {
PyErr_Clear();
return -1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess codecov is right, it would be nice to test this path by passing in a 2**70 python integer or so.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @seberg , any idea how we can hit this reliably? We don't reach case 2 in case of raw int, at the same time, we cannot overflow a Numpy scalar while creating it.

}
*descr = PyArray_DescrFromType(NPY_LONG);
return 1;
}
else if (PyComplex_CheckExact(obj)) {
*descr = PyArray_DescrFromType(NPY_CDOUBLE);
return 1;
}
return -1;
}

/*
 * Evaluates to nonzero when the descriptor's type number is below
 * NPY_LONGDOUBLE, or is NPY_CFLOAT or NPY_CDOUBLE.
 *
 * The argument is parenthesized so that any pointer expression (for example
 * `&descr` or `descrs + 1`) expands correctly.  It is evaluated more than
 * once, so it must not have side effects.
 */
#define PyDescr_PythonRepresentable(d) ((d)->type_num < NPY_LONGDOUBLE || \
        (d)->type_num == NPY_CFLOAT || (d)->type_num == NPY_CDOUBLE)

/**
 * Fast path that compares two NumPy and/or Python numerical scalars through
 * Python's built-in rich comparison.  A NumPy scalar is first unpacked into
 * the Python object produced by its descriptor's getitem; a Python scalar
 * is compared as it is.
 *
 * @param self The first operand
 * @param other The second operand
 * @param cmp_op The comparison operator, forwarded to PyObject_RichCompare
 * @returns The result of PyObject_RichCompare if the fast path was taken
 *          and successful, otherwise NULL.  NULL is also returned when the
 *          fast path does not apply: no descriptor was found for an
 *          operand, an ordering comparison involves a complex operand, or
 *          an operand is not Python representable.
 */
static PyObject*
do_richcompare_on_scalars(PyObject *self, PyObject *other, int cmp_op) {
    PyObject *cmp_item_self = NULL, *cmp_item_other = NULL, *ret = NULL;
    PyArray_Descr *self_descr = NULL, *other_descr = NULL;
    int pyscalar_self, pyscalar_other;
    int has_complex_operand, is_ordering_operator;

    pyscalar_self = descr_from_basic_scalar(self, &self_descr);
    pyscalar_other = descr_from_basic_scalar(other, &other_descr);
    if (pyscalar_self < 0 || pyscalar_other < 0) {
        /* At least one operand is not a basic scalar. */
        goto finish;
    }

    /*
     * If either of the operands is complex and the operator is not an
     * (in)equality, or the operands are not representable as a native type,
     * python's built-in richcompare cannot be used as it is not supported.
     */
    has_complex_operand = (PyTypeNum_ISCOMPLEX(self_descr->type_num) ||
                           PyTypeNum_ISCOMPLEX(other_descr->type_num));
    is_ordering_operator = (cmp_op != Py_EQ && cmp_op != Py_NE);
    if (has_complex_operand && is_ordering_operator) {
        goto finish;
    }
    if (!PyDescr_PythonRepresentable(self_descr) ||
            !PyDescr_PythonRepresentable(other_descr)) {
        goto finish;
    }

    /*
     * A Python built-in (descr_from_basic_scalar returned 1) is used as is;
     * only a NumPy scalar (returned 0) has its value extracted.  Both items
     * are held as owned references so the cleanup below is uniform.
     */
    if (pyscalar_self == 0) {
        cmp_item_self = self_descr->f->getitem(scalar_value(self, NULL), NULL);
        if (cmp_item_self == NULL) {
            /* Do not attempt the second conversion after a failure. */
            goto finish;
        }
    }
    else {
        Py_INCREF(self);
        cmp_item_self = self;
    }

    if (pyscalar_other == 0) {
        cmp_item_other = other_descr->f->getitem(
                scalar_value(other, NULL), NULL);
        if (cmp_item_other == NULL) {
            goto finish;
        }
    }
    else {
        Py_INCREF(other);
        cmp_item_other = other;
    }

    ret = PyObject_RichCompare(cmp_item_self, cmp_item_other, cmp_op);

finish:
    Py_XDECREF(cmp_item_self);
    Py_XDECREF(cmp_item_other);
    Py_XDECREF(self_descr);
    Py_XDECREF(other_descr);

    return ret;
}


/* The general strategy for commutative binary operators is to
*
Expand All @@ -587,86 +700,38 @@ static void
/**begin repeat
* #name = byte, ubyte, short, ushort, int, uint,
* long, ulong, longlong, ulonglong,
* half, float, longdouble,
* half, float, double, longdouble,
* cfloat, cdouble, clongdouble#
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_half, npy_float, npy_longdouble,
* npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble#
* #Name = Byte, UByte, Short, UShort, Int, UInt,
* Long, ULong, LongLong, ULongLong,
* Half, Float, LongDouble,
* Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble#
* #TYPE = NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT, NPY_INT, NPY_UINT,
* NPY_LONG, NPY_ULONG, NPY_LONGLONG, NPY_ULONGLONG,
* NPY_HALF, NPY_FLOAT, NPY_LONGDOUBLE,
* NPY_HALF, NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
* NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE#
*/

static int
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
{
PyObject *temp;

if (PyArray_IsScalar(a, @Name@)) {
*arg1 = PyArrayScalar_VAL(a, @Name@);
return 0;
}
else if (PyArray_IsScalar(a, Generic)) {
PyArray_Descr *descr1;

if (!PyArray_IsScalar(a, Number)) {
return -1;
}
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a));
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) {
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@);
Py_DECREF(descr1);
return 0;
}
else {
Py_DECREF(descr1);
return -1;
}
}
else if (PyArray_GetPriority(a, NPY_PRIORITY) > NPY_PRIORITY) {
return -2;
}
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) {
int retval = _@name@_convert_to_ctype(temp, arg1);

Py_DECREF(temp);
return retval;
}
return -2;
}

/**end repeat**/


/* Same as above but added exact checks against known python types for speed */

/**begin repeat
* #name = double#
* #type = npy_double#
* #Name = Double#
* #TYPE = NPY_DOUBLE#
* #PYCHECKEXACT = PyFloat_CheckExact#
* #PYEXTRACTCTYPE = PyFloat_AS_DOUBLE#
*/
#define _IS_@Name@

static int
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
_@name@_convert_to_ctype(PyObject *a, @type@ *arg)
{
PyObject *temp;

if (@PYCHECKEXACT@(a)){
*arg1 = @PYEXTRACTCTYPE@(a);
#if defined(_IS_Double) || defined(_IS_LongDouble)
if (PyFloat_CheckExact(a)) {
*arg = (@type@)PyFloat_AS_DOUBLE(a);
return 0;
}
#endif

if (PyArray_IsScalar(a, @Name@)) {
*arg1 = PyArrayScalar_VAL(a, @Name@);
*arg = PyArrayScalar_VAL(a, @Name@);
return 0;
}
else if (PyArray_IsScalar(a, Generic)) {
Expand All @@ -677,7 +742,7 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
}
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a));
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) {
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@);
PyArray_CastScalarDirect(a, descr1, arg, @TYPE@);
Py_DECREF(descr1);
return 0;
}
Expand All @@ -690,14 +755,16 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
return -2;
}
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) {
int retval = _@name@_convert_to_ctype(temp, arg1);
int retval = _@name@_convert_to_ctype(temp, arg);

Py_DECREF(temp);
return retval;
}
return -2;
}

#undef _IS_@Name@

/**end repeat**/


Expand Down Expand Up @@ -1308,14 +1375,20 @@ static PyObject*
{
npy_@name@ arg1, arg2;
int out=0;
PyObject *ret;

RICHCMP_GIVE_UP_IF_NEEDED(self, other);

switch(_@name@_convert2_to_ctypes(self, &arg1, other, &arg2)) {
case 0:
break;
case -1:
/* can't cast both safely use different add function */
/* can't cast both safely to same type.
* Try fastpath else use ufuncs */
ret = do_richcompare_on_scalars(self, other, cmp_op);
if (ret != NULL) {
return ret;
}
case -2:
/* use ufunc */
if (PyErr_Occurred()) {
Expand Down
12 changes: 12 additions & 0 deletions numpy/core/tests/test_scalarmath.py
7297
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
np.int_, np.uint, np.longlong, np.ulonglong,
np.single, np.double, np.longdouble, np.csingle,
np.cdouble, np.clongdouble]
# Type codes of every integer, floating-point and complex dtype.
# np.typecodes['AllFloat'] already includes the complex codes ('F', 'D', 'G'),
# so np.typecodes['Complex'] is not appended again; doing so would make every
# parametrized test run twice for the complex types.
all_numbers_dtypes = np.typecodes['AllInteger'] + np.typecodes['AllFloat']

floating_types = np.floating.__subclasses__()
complex_floating_types = np.complexfloating.__subclasses__()
Expand Down Expand Up @@ -707,3 +709,13 @@ def test_shift_all_bits(self, type_code, op):
shift_arr = np.array([shift]*32, dtype=dt)
res_arr = op(val_arr, shift_arr)
assert_equal(res_arr, res_scl)

class TestComparison:
    """Equal-valued scalars of any two numeric dtypes must compare equal."""

    @pytest.mark.parametrize('type_code_rhs', all_numbers_dtypes)
    @pytest.mark.parametrize('type_code_lhs', all_numbers_dtypes)
    def test_numbers_compare(self, type_code_rhs, type_code_lhs):
        # Integers in [0, 127) fit in the smallest integer dtype and are
        # exact in the smallest float dtype, so both scalars hold the
        # same value whatever the dtype pair is.
        rand_num = np.random.randint(0, 127)
        a = np.dtype(type_code_rhs).type(rand_num)
        b = np.dtype(type_code_lhs).type(rand_num)

        assert_almost_equal(a, b)
        # Invoke the scalar rich comparison operators directly as well.
        assert a == b
        assert not (a != b)
0