8000 ENH: Fast paths for richcompare by shubham11941140 · Pull Request #19720 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Fast paths for richcompare #19720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions benchmarks/benchmarks/bench_scalar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from random import randint
from .common import Benchmark, TYPES1

import numpy as np
Expand Down Expand Up @@ -31,3 +32,12 @@ def time_abs(self, typename):
n = self.num
res = abs(abs(abs(abs(abs(abs(abs(abs(abs(abs(n))))))))))

def time_compare(self, typename):
n = self.num
res = [n == randint(-128, 127) for _ in range(10)]

def time_compare_types(self, typename):
n1 = self.num
for type_lhs in TYPES1:
n2 = np.dtype(type_lhs).type(randint(-128, 127))
res = [n1 == n2 for _ in range(10)]
4 changes: 2 additions & 2 deletions numpy/core/src/multiarray/convert_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,8 @@ PyArray_CanCastScalar(PyTypeObject *from, PyTypeObject *to)
int fromtype;
int totype;

fromtype = _typenum_fromtypeobj((PyObject *)from, 0);
totype = _typenum_fromtypeobj((PyObject *)to, 0);
fromtype = PyArray_TypeNumFromNumPyScalarType((PyObject *)from, 0);
totype = PyArray_TypeNumFromNumPyScalarType((PyObject *)to, 0);
if (fromtype == NPY_NOTYPE || totype == NPY_NOTYPE) {
return NPY_FALSE;
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalarapi.c
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ NPY_NO_EXPORT PyArray_Descr *
PyArray_DescrFromTypeObject(PyObject *type)
{
/* if it's a builtin type, then use the typenumber */
int typenum = _typenum_fromtypeobj(type,1);
int typenum = PyArray_TypeNumFromNumPyScalarType(type,1);
if (typenum != NPY_NOTYPE) {
return PyArray_DescrFromType(typenum);
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -4145,7 +4145,7 @@ is_anyscalar_exact(PyObject *obj)
}

NPY_NO_EXPORT int
_typenum_fromtypeobj(PyObject *type, int user)
PyArray_TypeNumFromNumPyScalarType(PyObject *type, int user)
{
int typenum, i;

Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/scalartypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ NPY_NO_EXPORT int
is_anyscalar_exact(PyObject *obj);

NPY_NO_EXPORT int
_typenum_fromtypeobj(PyObject *type, int user);
PyArray_TypeNumFromNumPyScalarType(PyObject *type, int user);

NPY_NO_EXPORT void *
scalar_value(PyObject *scalar, PyArray_Descr *descr);
Expand Down
206 changes: 140 additions & 66 deletions numpy/core/src/umath/scalarmath.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include "numpy/ufuncobject.h"
#include "numpy/arrayscalars.h"

#include "scalartypes.h"

#include "npy_import.h"
#include "npy_pycompat.h"

Expand Down Expand Up @@ -563,6 +565,117 @@ static void

/*** END OF BASIC CODE **/

/**
* Find the descriptor for a builtin or Python numerical scalar.
* This function should only be used for scalar math fast paths.
*
* @param obj The object for which to find the descriptor
* @param descr The descriptor that was found.
* @returns 1 if the other object is a Python type and thus must not
* handle the operation, otherwise 0. -1 _without_ an error set
* if no descriptor was found.
*/
static int
descr_from_basic_scalar(PyObject *obj, PyArray_Descr **descr)
{
/* TODO: We could try giving defined scalars a chance... */
int type_num = PyArray_TypeNumFromNumPyScalarType(
(PyObject *)Py_TYPE(obj), 0);
if (type_num != NPY_NOTYPE) {
*descr = PyArray_DescrFromType(type_num);
return 0;
}
else if (PyFloat_CheckExact(obj)) {
*descr = PyArray_DescrFromType(NPY_DOUBLE);
return 1;
}
else if (PyBool_Check(obj)) {
*descr = PyArray_DescrFromType(NPY_BOOL);
return 1;
}
else if (PyLong_CheckExact(obj)) {
if (PyLong_AsLong(obj) == -1 && PyErr_Occurred()) {
PyErr_Clear();
return -1;
}
*descr = PyArray_DescrFromType(NPY_LONG);
return 1;
}
else if (PyComplex_CheckExact(obj)) {
*descr = PyArray_DescrFromType(NPY_CDOUBLE);
return 1;
}
return -1;
}

#define PyDescr_PythonRepresentable(d) (d->type_num < NPY_LONGDOUBLE || \
d->type_num == NPY_CFLOAT || d->type_num == NPY_CDOUBLE)

/**
* This function attempts to compare NumPy or Python scalars
* The operation is done by getting value of both scalars
* and calling richcompare on them.
*
* @param self The first object which will be converted to scalar item
* @param other The second object to compare to
* @param cmp_op The value of comparison operator
* @returns Py_False or Py_True if richcompare is successfull
* , otherwise NULL.
*/
static PyObject*
do_richcompare_on_scalars(PyObject *self, PyObject *other, int cmp_op) {
PyObject *cmp_item_self, *cmp_item_other, *ret=NULL;
PyArray_Descr *self_descr=NULL, *other_descr=NULL;
void *data_self, *data_other;
int pyscalar_other, pyscalar_self;
int is_complex_operands, is_equality_operator, is_python_representable;

pyscalar_self = descr_from_basic_scalar(self, &self_descr);
pyscalar_other = descr_from_basic_scalar(other, &other_descr);

if (pyscalar_self >= 0 && pyscalar_other >= 0) {
/*
* If either of the operands are complex and operator is not equality,
* or the operands are not representable as a native type,
* python's built-in richcompare cannot be used as it is not supported.
*/
is_complex_operands = (PyTypeNum_ISCOMPLEX(self_descr->type_num) ||
PyTypeNum_ISCOMPLEX(other_descr->type_num));
is_equality_operator = (cmp_op != Py_EQ && cmp_op != Py_NE);
is_python_representable = PyDescr_PythonRepresentable(self_descr) &&
PyDescr_PythonRepresentable(other_descr);

if (!(is_complex_operands && is_equality_operator) && is_python_representable) {
/*
* If the scalar is a python built-in, we can use the object as is.
* Else we need to obtain the value from the operand.
* Note: If we reach this point, one of the scalars must be built-in.
*/
data_self = scalar_value(self, NULL);
data_other = scalar_value(other, NULL);
cmp_item_self = pyscalar_self == 0 ? self_descr->f->getitem(data_self, NULL):
self;
cmp_item_other = pyscalar_other == 0 ? other_descr->f->getitem(data_other, NULL):
other;

if (cmp_item_self != NULL && cmp_item_other != NULL) {
ret = PyObject_RichCompare(cmp_item_self, cmp_item_other, cmp_op);
}

if (pyscalar_self == 0) {
Py_XDECREF(cmp_item_self);
}
if (pyscalar_other == 0) {
Py_XDECREF(cmp_item_other);
}
}
}
Py_XDECREF(self_descr);
Py_XDECREF(other_descr);

return ret;
}


/* The general strategy for commutative binary operators is to
*
Expand All @@ -581,86 +694,39 @@ static void
/**begin repeat
* #name = byte, ubyte, short, ushort, int, uint,
* long, ulong, longlong, ulonglong,
* half, float, longdouble,
* half, float, double, longdouble,
* cfloat, cdouble, clongdouble#
* #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_half, npy_float, npy_longdouble,
* npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble#
* #Name = Byte, UByte, Short, UShort, Int, UInt,
* Long, ULong, LongLong, ULongLong,
* Half, Float, LongDouble,
* Half, Float, Double, LongDouble,
* CFloat, CDouble, CLongDouble#
* #TYPE = NPY_BYTE, NPY_UBYTE, NPY_SHORT, NPY_USHORT, NPY_INT, NPY_UINT,
* NPY_LONG, NPY_ULONG, NPY_LONGLONG, NPY_ULONGLONG,
* NPY_HALF, NPY_FLOAT, NPY_LONGDOUBLE,
* NPY_HALF, NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
* NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE#
*/

static int
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
{
PyObject *temp;

if (PyArray_IsScalar(a, @Name@)) {
*arg1 = PyArrayScalar_VAL(a, @Name@);
return 0;
}
else if (PyArray_IsScalar(a, Generic)) {
PyArray_Descr *descr1;

if (!PyArray_IsScalar(a, Number)) {
return -1;
}
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a));
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) {
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@);
Py_DECREF(descr1);
return 0;
}
else {
Py_DECREF(descr1);
return -1;
}
}
else if (PyArray_GetPriority(a, NPY_PRIORITY) > NPY_PRIORITY) {
return -2;
}
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) {
int retval = _@name@_convert_to_ctype(temp, arg1);

Py_DECREF(temp);
return retval;
}
return -2;
}

/**end repeat**/


/* Same as above but added exact checks against known python types for speed */

/**begin repeat
* #name = double#
* #type = npy_double#
* #Name = Double#
* #TYPE = NPY_DOUBLE#
* #PYCHECKEXACT = PyFloat_CheckExact#
* #PYEXTRACTCTYPE = PyFloat_AS_DOUBLE#
*/
#define _IS_@Name@

static int
_@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
_@name@_convert_to_ctype(PyObject *a, @type@ *arg)
{
PyObject *temp;

if (@PYCHECKEXACT@(a)){
*arg1 = @PYEXTRACTCTYPE@(a);
#if defined(_IS_Double) || defined(_IS_LongDouble)
if (PyFloat_CheckExact(a)) {
*arg = (@type@)PyFloat_AS_DOUBLE(a);
return 0;
}
#endif

if (PyArray_IsScalar(a, @Name@)) {
*arg1 = PyArrayScalar_VAL(a, @Name@);
*arg = PyArrayScalar_VAL(a, @Name@);
return 0;
}
else if (PyArray_IsScalar(a, Generic)) {
Expand All @@ -671,7 +737,7 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
}
descr1 = PyArray_DescrFromTypeObject((PyObject *)Py_TYPE(a));
if (PyArray_CanCastSafely(descr1->type_num, @TYPE@)) {
PyArray_CastScalarDirect(a, descr1, arg1, @TYPE@);
PyArray_CastScalarDirect(a, descr1, arg, @TYPE@);
Py_DECREF(descr1);
return 0;
}
Expand All @@ -684,14 +750,16 @@ _@name@_convert_to_ctype(PyObject *a, @type@ *arg1)
return -2;
}
else if ((temp = PyArray_ScalarFromObject(a)) != NULL) {
int retval = _@name@_convert_to_ctype(temp, arg1);
int retval = _@name@_convert_to_ctype(temp, arg);

Py_DECREF(temp);
return retval;
}
return -2;
}

#undef _IS_@Name@

/**end repeat**/


Expand Down Expand Up @@ -1048,11 +1116,11 @@ static PyObject *
*
*/

/*
Complex numbers do not support remainder operations. Unfortunately,
the type inference for long doubles is complicated, and if a remainder
operation is not defined - if the relevant field is left NULL - then
operations between long doubles and objects lead to an infinite recursion
/*
Complex numbers do not support remainder operations. Unfortunately,
the type inference for long doubles is complicated, and if a remainder
operation is not defined - if the relevant field is left NULL - then
operations between long doubles and objects lead to an infinite recursion
instead of a TypeError. This should ensure that once everything gets
converted to complex long doubles you correctly get a reasonably
informative TypeError. This fixes the last part of bug gh-18548.
Expand Down Expand Up @@ -1359,14 +1427,20 @@ static PyObject*
{
npy_@name@ arg1, arg2;
int out=0;
PyObject *ret;

RICHCMP_GIVE_UP_IF_NEEDED(self, other);

switch(_@name@_convert2_to_ctypes(self, &arg1, other, &arg2)) {
case 0:
break;
case -1:
/* can't cast both safely use different add function */
/* can't cast both safely to same type.
* Try fastpath else use ufuncs */
ret = do_richcompare_on_scalars(self, other, cmp_op);
if (ret != NULL) {
return ret;
}
case -2:
/* use ufunc */
if (PyErr_Occurred()) {
Expand Down
0