8000 BUG,ENH: Fix generic scalar infinite recursion issues by seberg · Pull Request #26904 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG,ENH: Fix generic scalar infinite recursion issues #26904

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions numpy/_core/src/multiarray/npy_static_data.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ intern_strings(void)
INTERN_STRING(implementation, "_implementation");
INTERN_STRING(axis1, "axis1");
INTERN_STRING(axis2, "axis2");
INTERN_STRING(item, "item");
INTERN_STRING(like, "like");
INTERN_STRING(numpy, "numpy");
INTERN_STRING(where, "where");
Expand Down
1 change: 1 addition & 0 deletions numpy/_core/src/multiarray/npy_static_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ typedef struct npy_interned_str_struct {
PyObject *implementation;
PyObject *axis1;
PyObject *axis2;
PyObject *item;
PyObject *like;
PyObject *numpy;
PyObject *where;
Expand Down
311 changes: 270 additions & 41 deletions numpy/_core/src/multiarray/scalartypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "ctors.h"
#include "dtypemeta.h"
#include "usertypes.h"
#include "number.h"
#include "numpyos.h"
#include "can_cast_table.h"
#include "common.h"
Expand Down Expand Up @@ -120,19 +121,6 @@ gentype_free(PyObject *v)
}


static PyObject *
gentype_power(PyObject *m1, PyObject *m2, PyObject *modulo)
{
if (modulo != Py_None) {
/* modular exponentiation is not implemented (gh-8804) */
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_power, gentype_power);
return PyArray_Type.tp_as_number->nb_power(m1, m2, Py_None);
}

static PyObject *
gentype_generic_method(PyObject *self, PyObject *args, PyObject *kwds,
char *str)
Expand Down Expand Up @@ -164,33 +152,194 @@ gentype_generic_method(PyObject *self, PyObject *args, PyObject *kwds,
}
}

static PyObject *
gentype_add(PyObject *m1, PyObject* m2)
{
/* special case str.__radd__, which should not call array_add */
if (PyBytes_Check(m1) || PyUnicode_Check(m1)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;

/*
* Helper function to deal with binary operator deferral. Must be passed a
* valid self (a generic scalar) and an other item.
* May fill self_item and/or other_arr (but not both) with non-NULL values.
*
* Why this dance? When the other object is a exactly Python scalar something
* awkward happens historically in NumPy.
* NumPy doesn't define a result, but the ufunc would cast to `astype(object)`
* which is the same as `scalar.item()`. And that operation converts e.g.
* float32 or float64 to Python floats.
* It then retries. And because it is a builtin type now the operation may
* succeed.
*
* This retrying pass only makes sense if the other object is a Python
* scalar (otherwise we fill in `other_arr` which can be used to call the
* ufunc).
* Additionally, if `self.item()` has the same type as `self` we would end up
* in an infinite recursion.
*
* So the result of this function means the following:
* - < 0 error return.
* - self_op is filled in: Retry the Python operator.
* - other_op is filled in: Use the array operator (goes into ufuncs)
* (This may be the original generic if it is one.)
* - neither is filled in: Return NotImplemented.
*
* It is not possible for both to be filled. If `other` is also a generics,
* it is returned.
*/
static inline int
find_binary_operation_path(
PyObject *self, PyObject *other, PyObject **self_op, PyObject **other_op)
{
*other_op = NULL;
*self_op = NULL;

if (PyArray_IsScalar(other, Generic) ||
PyLong_Check(other) ||
PyFloat_Check(other) ||
PyComplex_Check(other) ||
PyBool_Check(other)) {
/*
* The other operand is ready for the operation already. Must pass on
* on float/long/complex mainly for weak promotion (NEP 50).
*/
Py_INCREF(other);
*other_op = other;
return 0;
}
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_add, gentype_add);
return PyArray_Type.tp_as_number->nb_add(m1, m2);

/*
* Now check `other`. We want to know whether it is an object scalar
* and the easiest way is by converting to an array here.
*/
int was_scalar;
PyArrayObject *arr = (PyArrayObject *)PyArray_FromAny_int(
other, NULL, NULL, 0, 0, 0, NULL, &was_scalar);
if (arr == NULL) {
return -1;
}

if (!was_scalar || PyArray_DESCR(arr)->type_num != NPY_OBJECT) {
/* The array is OK for usage and we can simply forward it
*
* NOTE: Future NumPy may need to distinguish scalars here, one option
* could be marking the array.
*/
*other_op = (PyObject *)arr;
return 0;
}
Py_DECREF(arr);

/*
* If we are here, we need to operate on Python scalars. In general
* that would just fails since NumPy doesn't know the other object!
*
* However, NumPy (historically) often makes this work magically because
* it object ufuncs end up casting to object with `.item()` and that may
* returns Python type often (e.g. float for float32, float64)!
* Retrying then succeeds. So if (and only if) `self.item()` returns a new
* type, we can safely attempt the operation (again) with that.
*/
PyObject *self_item = PyObject_CallMethodNoArgs(self, npy_interned_str.item);
if (self_item == NULL) {
return -1;
}
if (Py_TYPE(self_item) != Py_TYPE(self)) {
/* self_item can be used to retry the operation */
*self_op = self_item;
return 0;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't you leak self_item if you get to this last return 0? Maybe also in the case where you assign it to self_op, I haven't looked at how reference counting works in the caller.

/* The operation can't work and we will return NotImplemented */
return 0;
}


/*
* These are defined below as they require special handling, we still define
* a _gen version here. `power` is special as it has three arguments.
*/
static PyObject *
gentype_add(PyObject *m1, PyObject *m2);

static PyObject *
gentype_multiply(PyObject *m1, PyObject *m2);


/**begin repeat
*
* #name = subtract, remainder, divmod, lshift, rshift,
* and, xor, or, floor_divide, true_divide#
* #name = add, multiply, subtract, remainder, divmod,
* lshift, rshift, and, xor, or, floor_divide, true_divide#
* #ufunc = add, multiply, subtract, remainder, divmod,
* left_shift, right_shift, bitwise_and, bitwise_xor, bitwise_or,
* floor_divide, true_divide#
* #func = Add, Multiply, Subtract, Remainder, Divmod,
* Lshift, Rshift, And, Xor, Or, FloorDivide, TrueDivide#
* #suff = _gen, _gen,,,,,,,,,,#
*/
/* NOTE: We suffix the name for functions requiring special handling first. */
static PyObject *
gentype_@name@(PyObject *m1, PyObject *m2)
gentype_@name@@suff@(PyObject *m1, PyObject *m2)
{
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_@name@, gentype_@name@);
return PyArray_Type.tp_as_number->nb_@name@(m1, m2);

PyObject *self = NULL;
PyObject *other = NULL;
PyObject *self_op, *other_op;

if (!PyArray_IsScalar(m2, Generic)) {
self = m1;
other = m2;
}
else {
self = m2;
other = m1;
}
if (find_binary_operation_path(self, other, &self_op, &other_op) < 0) {
return NULL;
}
if (self_op != NULL) {
PyObject *res;
if (self == m1) {
res = PyNumber_@func@(self_op, m2);
}
else {
res = PyNumber_@func@(m1, self_op);
}
Py_DECREF(self_op);
return res;
}
else if (other_op != NULL) {
/* Call the corresponding ufunc (with the array) */
PyObject *res;
if (self == m1) {
res = PyArray_GenericBinaryFunction(m1, other_op, n_ops.@ufunc@);
}
else {
res = PyArray_GenericBinaryFunction(other_op, m2, n_ops.@ufunc@);
}
Py_DECREF(other_op);
return res;
}
else {
assert(other_op == NULL);
Py_RETURN_NOTIMPLEMENTED;
}
}

/**end repeat**/

/* Get a nested slot, or NULL if absent */
/*
* The following operators use the above, but require specialization.
*/

static PyObject *
gentype_add(PyObject *m1, PyObject *m2)
{
/* special case str.__radd__, which should not call array_add */
if (PyBytes_Check(m1) || PyUnicode_Check(m1)) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

return gentype_add_gen(m1, m2);
}

/* Get a nested slot, or NULL if absent (for multiply implementation) */
#define GET_NESTED_SLOT(type, group, slot) \
((type)->group == NULL ? NULL : (type)->group->slot)

Expand Down Expand Up @@ -219,11 +368,75 @@ gentype_multiply(PyObject *m1, PyObject *m2)
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
/* All normal cases are handled by PyArray's multiply */
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_multiply, gentype_multiply);
return PyArray_Type.tp_as_number->nb_multiply(m1, m2);

return gentype_multiply_gen(m1, m2);
}


/*
* NOTE: The three argument nature of power requires code duplication here.
*/
static PyObject *
gentype_power(PyObject *m1, PyObject *m2, PyObject *modulo)
{
if (modulo != Py_None) {
/* modular exponentiation is not implemented (gh-8804) */
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_power, gentype_power);

PyObject *self = NULL;
PyObject *other = NULL;
PyObject *self_op, *other_op;

if (!PyArray_IsScalar(m2, Generic)) {
self = m1;
other = m2;
}
else {
self = m2;
other = m1;
}
if (find_binary_operation_path(self, other, &self_op, &other_op) < 0) {
return NULL;
}
if (self_op != NULL) {
PyObject *res;
if (self == m1) {
res = PyNumber_Power(self_op, m2, Py_None);
}
else {
res = PyNumber_Power(m1, self_op, Py_None);
}
Py_DECREF(self_op);
return res;
}
else if (other_op != NULL) {
/* Call the corresponding ufunc (with the array)
* NOTE: As of NumPy 2.0 there are inconsistencies in array_power
* calling it would fail a (niche) test because an array is
* returned in one of the fast-paths.
* (once NumPy propagates 0-D arrays, this is irrelevant)
*/
PyObject *res;
if (self == m1) {
res = PyArray_GenericBinaryFunction(m1, other_op, n_ops.power);
}
else {
res = PyArray_GenericBinaryFunction(other_op, m2, n_ops.power);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I started calling the ufunc directly now. This has advantages and disadvantages, but overall: I don't think it should matter much and besides working around power oddities felt pretty sane.

(Not done yet for the comparison, where it might also make sense eventually)

}
Py_DECREF(other_op);
return res;
}
else {
assert(other_op == NULL);
Py_RETURN_NOTIMPLEMENTED;
}
}


/**begin repeat
* #TYPE = BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG#
Expand Down Expand Up @@ -1265,8 +1478,6 @@ static PyNumberMethods gentype_as_number = {
static PyObject *
gentype_richcompare(PyObject *self, PyObject *other, int cmp_op)
{
PyObject *arr, *ret;

/*
* If the other object is None, False is always right. This avoids
* the array None comparison, at least until deprecation it is fixed.
Expand All @@ -1287,17 +1498,35 @@ gentype_richcompare(PyObject *self, PyObject *other, int cmp_op)

RICHCMP_GIVE_UP_IF_NEEDED(self, other);

arr = PyArray_FromScalar(self, NULL);
if (arr == NULL) {
PyObject *self_op;
PyObject *other_op;
if (find_binary_operation_path(self, other, &self_op, &other_op) < 0) {
return NULL;
}
/*
* Call via PyObject_RichCompare to ensure that other.__eq__
* has a chance to run when necessary
*/
ret = PyObject_RichCompare(arr, other, cmp_op);
Py_DECREF(arr);
return ret;

/* We can always just call RichCompare again */
if (other_op != NULL) {
/* If we use richcompare again, need to ensure that one op is array */
self_op = PyArray_FromScalar(self, NULL);
if (self_op == NULL) {
Py_DECREF(other_op);
return NULL;
}
PyObject *res = PyObject_RichCompare(self_op, other_op, cmp_op);
Py_DECREF(self_op);
Py_DECREF(other_op);
return res;
}
else if (self_op != NULL) {
/* Try again, since other is an object scalar and this one mutated */
PyObject *res = PyObject_RichCompare(self_op, other, cmp_op);
Py_DECREF(self_op);
return res;
}
else {
/* Comparison with arbitrary objects cannot be defined. */
Py_RETURN_NOTIMPLEMENTED;
}
}

static PyObject *
Expand Down
Loading
Loading
0