8000 Merge pull request #9021 from charris/fix-inplace-operators · numpy/numpy@af54fbd · GitHub
[go: up one dir, main page]

Skip to content

Commit af54fbd

Browse files
authored
Merge pull request #9021 from charris/fix-inplace-operators
BUG: Make ndarray inplace operators forward calls when needed.
2 parents ef2cfe1 + 105557a commit af54fbd

File tree

3 files changed

+124
-69
lines changed

3 files changed

+124
-69
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,38 @@
2222

2323
NPY_NO_EXPORT NumericOps n_ops; /* NB: static objects initialized to zero */
2424

25+
/*
26+
* Forward declarations. Might want to move functions around instead
27+
*/
28+
static PyObject *
29+
array_inplace_add(PyArrayObject *m1, PyObject *m2);
30+
static PyObject *
31+
array_inplace_subtract(PyArrayObject *m1, PyObject *m2);
32+
static PyObject *
33+
array_inplace_multiply(PyArrayObject *m1, PyObject *m2);
34+
#if !defined(NPY_PY3K)
35+
static PyObject *
36+
array_inplace_divide(PyArrayObject *m1, PyObject *m2);
37+
#endif
38+
static PyObject *
39+
array_inplace_true_divide(PyArrayObject *m1, PyObject *m2);
40+
static PyObject *
41+
array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2);
42+
static PyObject *
43+
array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2);
44+
static PyObject *
45+
array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2);
46+
static PyObject *
47+
array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2);
48+
static PyObject *
49+
array_inplace_left_shift(PyArrayObject *m1, PyObject *m2);
50+
static PyObject *
51+
array_inplace_right_shift(PyArrayObject *m1, PyObject *m2);
52+
static PyObject *
53+
array_inplace_remainder(PyArrayObject *m1, PyObject *m2);
54+
static PyObject *
55+
array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo));
56+
2557
/*
2658
* Dictionary can contain any of the numeric operations, by name.
2759
* Those not present will not be changed
@@ -255,31 +287,6 @@ PyArray_GenericInplaceUnaryFunction(PyArrayObject *m1, PyObject *op)
255287
return PyObject_CallFunctionObjArgs(op, m1, m1, NULL);
256288
}
257289

258-
static PyObject *
259-
array_inplace_add(PyArrayObject *m1, PyObject *m2);
260-
static PyObject *
261-
array_inplace_subtract(PyArrayObject *m1, PyObject *m2);
262-
static PyObject *
263-
array_inplace_multiply(PyArrayObject *m1, PyObject *m2);
264-
#if !defined(NPY_PY3K)
265-
static PyObject *
266-
array_inplace_divide(PyArrayObject *m1, PyObject *m2);
267-
#endif
268-
static PyObject *
269-
array_inplace_true_divide(PyArrayObject *m1, PyObject *m2);
270-
static PyObject *
271-
array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2);
272-
static PyObject *
273-
array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2);
274-
static PyObject *
275-
array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2);
276-
static PyObject *
277-
array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2);
278-
static PyObject *
279-
array_inplace_left_shift(PyArrayObject *m1, PyObject *m2);
280-
static PyObject *
281-
array_inplace_right_shift(PyArrayObject *m1, PyObject *m2);
282-
283290
static PyObject *
284291
array_add(PyArrayObject *m1, PyObject *m2)
285292
{
@@ -628,32 +635,42 @@ array_bitwise_xor(PyArrayObject *m1, PyObject *m2)
628635
static PyObject *
629636
array_inplace_add(PyArrayObject *m1, PyObject *m2)
630637
{
638+
INPLACE_GIVE_UP_IF_NEEDED(
639+
m1, m2, nb_inplace_add, array_inplace_add);
631640
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.add);
632641
}
633642

634643
static PyObject *
635644
array_inplace_subtract(PyArrayObject *m1, PyObject *m2)
636645
{
646+
INPLACE_GIVE_UP_IF_NEEDED(
647+
m1, m2, nb_inplace_subtract, array_inplace_subtract);
637648
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.subtract);
638649
}
639650

640651
static PyObject *
641652
array_inplace_multiply(PyArrayObject *m1, PyObject *m2)
642653
{
654+
INPLACE_GIVE_UP_IF_NEEDED(
655+
m1, m2, nb_inplace_multiply, array_inplace_multiply);
643656
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.multiply);
644657
}
645658

646659
#if !defined(NPY_PY3K)
647660
static PyObject *
648661
array_inplace_divide(PyArrayObject *m1, PyObject *m2)
649662
{
663+
INPLACE_GIVE_UP_IF_NEEDED(
664+
m1, m2, nb_inplace_divide, array_inplace_divide);
650665
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.divide);
651666
}
652667
#endif
653668

654669
static PyObject *
655670
array_inplace_remainder(PyArrayObject *m1, PyObject *m2)
656671
{
672+
INPLACE_GIVE_UP_IF_NEEDED(
673+
m1, m2, nb_inplace_remainder, array_inplace_remainder);
657674
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.remainder);
658675
}
659676

@@ -662,6 +679,9 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
662679
{
663680
/* modulo is ignored! */
664681
PyObject *value;
682+
683+
INPLACE_GIVE_UP_IF_NEEDED(
684+
a1, o2, nb_inplace_power, array_inplace_power);
665685
value = fast_scalar_power(a1, o2, 1);
666686
if (!value) {
667687
value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
@@ -672,30 +692,40 @@ array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo
672692
static PyObject *
673693
array_inplace_left_shift(PyArrayObject *m1, PyObject *m2)
674694
{
695+
INPLACE_GIVE_UP_IF_NEEDED(
696+
m1, m2, nb_inplace_lshift, array_inplace_left_shift);
675697
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.left_shift);
676698
}
677699

678700
static PyObject *
679701
array_inplace_right_shift(PyArrayObject *m1, PyObject *m2)
680702
{
703+
INPLACE_GIVE_UP_IF_NEEDED(
704+
m1, m2, nb_inplace_rshift, array_inplace_right_shift);
681705
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.right_shift);
682706
}
683707

684708
static PyObject *
685709
array_inplace_bitwise_and(PyArrayObject *m1, PyObject *m2)
686710
{
711+
INPLACE_GIVE_UP_IF_NEEDED(
712+
m1, m2, nb_inplace_and, array_inplace_bitwise_and);
687713
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_and);
688714
}
689715

690716
static PyObject *
691717
array_inplace_bitwise_or(PyArrayObject *m1, PyObject *m2)
692718
{
719+
INPLACE_GIVE_UP_IF_NEEDED(
720+
m1, m2, nb_inplace_or, array_inplace_bitwise_or);
693721
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_or);
694722
}
695723

696724
static PyObject *
697725
array_inplace_bitwise_xor(PyArrayObject *m1, PyObject *m2)
698726
{
727+
INPLACE_GIVE_UP_IF_NEEDED(
728+
m1, m2, nb_inplace_xor, array_inplace_bitwise_xor);
699729
return PyArray_GenericInplaceBinaryFunction(m1, m2, n_ops.bitwise_xor);
700730
}
701731

@@ -728,13 +758,17 @@ array_true_divide(PyArrayObject *m1, PyObject *m2)
728758
static PyObject *
729759
array_inplace_floor_divide(PyArrayObject *m1, PyObject *m2)
730760
{
761+
INPLACE_GIVE_UP_IF_NEEDED(
762+
m1, m2, nb_inplace_floor_divide, array_inplace_floor_divide);
731763
return PyArray_GenericInplaceBinaryFunction(m1, m2,
732764
n_ops.floor_divide);
733765
}
734766

735767
static PyObject *
736768
array_inplace_true_divide(PyArrayObject *m1, PyObject *m2)
737769
{
770+
INPLACE_GIVE_UP_IF_NEEDED(
771+
m1, m2, nb_inplace_true_divide, array_inplace_true_divide);
738772
return PyArray_GenericInplaceBinaryFunction(m1, m2,
739773
n_ops.true_divide);
740774
}
@@ -754,8 +788,8 @@ _array_nonzero(PyArrayObject *mp)
754788
}
755789
else {
756790
PyErr_SetString(PyExc_ValueError,
757-
"The truth value of an array " \
758-
"with more than one element is ambiguous. " \
791+
"The truth value of an array "
792+
"with more than one element is ambiguous. "
759793
"Use a.any() or a.all()");
760794
return -1;
761795
}
@@ -1060,19 +1094,19 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
10601094
* This code adds augmented assignment functionality
10611095
* that was made available in Python 2.0
10621096
*/
1063-
(binaryfunc)array_inplace_add, /*inplace_add*/
1064-
(binaryfunc)array_inplace_subtract, /*inplace_subtract*/
1065-
(binaryfunc)array_inplace_multiply, /*inplace_multiply*/
1097+
(binaryfunc)array_inplace_add, /*nb_inplace_add*/
1098+
(binaryfunc)array_inplace_subtract, /*nb_inplace_subtract*/
1099+
(binaryfunc)array_inplace_multiply, /*nb_inplace_multiply*/
10661100
#if !defined(NPY_PY3K)
1067-
(binaryfunc)array_inplace_divide, /*inplace_divide*/
1101+
(binaryfunc)array_inplace_divide, /*nb_inplace_divide*/
10681102
#endif
1069-
(binaryfunc)array_inplace_remainder, /*inplace_remainder*/
1070-
(ternaryfunc)array_inplace_power, /*inplace_power*/
1071-
(binaryfunc)array_inplace_left_shift, /*inplace_lshift*/
1072-
(binaryfunc)array_inplace_right_shift, /*inplace_rshift*/
1073-
(binaryfunc)array_inplace_bitwise_and, /*inplace_and*/
1074-
(binaryfunc)array_inplace_bitwise_xor, /*inplace_xor*/
1075-
(binaryfunc)array_inplace_bitwise_or, /*inplace_or*/
1103+
(binaryfunc)array_inplace_remainder, /*nb_inplace_remainder*/
1104+
(ternaryfunc)array_inplace_power, /*nb_inplace_power*/
1105+
(binaryfunc)array_inplace_left_shift, /*nb_inplace_lshift*/
1106+
(binaryfunc)array_inplace_right_shift, /*nb_inplace_rshift*/
1107+
(binaryfunc)array_inplace_bitwise_and, /*nb_inplace_and*/
1108+
(binaryfunc)array_inplace_bitwise_xor, /*nb_inplace_xor*/
1109+
(binaryfunc)array_inplace_bitwise_or, /*nb_inplace_or*/
10761110

10771111
(binaryfunc)array_floor_divide, /*nb_floor_divide*/
10781112
(binaryfunc)array_true_divide, /*nb_true_divide*/

numpy/core/src/private/binop_override.h

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,22 @@
5050
* where setting a special-method name to None is a signal that that method
5151
* cannot be used.
5252
*
53-
* So for 1.13, we are going to try the following rules. a.__add__(b) will
54-
* be implemented as follows:
53+
* So for 1.13, we are going to try the following rules.
54+
*
55+
* For binops like a.__add__(b):
5556
* - If b does not define __array_ufunc__, apply the legacy rule:
5657
* - If not isinstance(b, a.__class__), and b.__array_priority__ is higher
5758
* than a.__array_priority__, return NotImplemented
5859
* - If b does define __array_ufunc__ but it is None, return NotImplemented
5960
* - Otherwise, call the corresponding ufunc.
6061
*
61-
* For reversed operations like b.__radd__(a), and for in-place operations
62-
* like a.__iadd__(b), we:
63-
* - Call the corresponding ufunc
62+
* For in-place operations like a.__iadd__(b)
63+
* - If b does not define __array_ufunc__, apply the legacy rule:
64+
* - If not isinstance(b, a.__class__), and b.__array_priority__ is higher
65+
* than a.__array_priority__, return NotImplemented
66+
* - Otherwise, call the corresponding ufunc.
67+
*
68+
* For reversed operations like b.__radd__(a) we call the corresponding ufunc.
6469
*
6570
* Rationale for __radd__: This is because by the time the reversed operation
6671
* is called, there are only two possibilities: The first possibility is that
@@ -77,16 +82,19 @@
7782
* above, because if __iadd__ returns NotImplemented then Python will silently
7883
* convert the operation into an out-of-place operation, i.e. 'a += b' will
7984
* silently become 'a = a + b'. We don't want to allow this for arrays,
80-
* because it will create unexpected memory allocations, break views,
81-
* etc.
85+
* because it will create unexpected memory allocations, break views, etc.
86+
* However, backwards compatibility requires that we follow the rules of
87+
* __array_priority__ for arrays that define it. For classes that use the new
88+
* __array_ufunc__ mechanism we simply defer to the ufunc. That has the effect
89+
* that when the other array has__array_ufunc = None a TypeError will be raised.
8290
*
8391
* In the future we might change these rules further. For example, we plan to
8492
* eventually deprecate __array_priority__ in cases where __array_ufunc__ is
8593
* not present.
8694
*/
8795

8896
static int
89-
binop_override_forward_binop_should_defer(PyObject *self, PyObject *other)
97+
binop_should_defer(PyObject *self, PyObject *other, int inplace)
9098
{
9199
/*
92100
* This function assumes that self.__binop__(other) is underway and
@@ -123,7 +131,7 @@ binop_override_forward_binop_should_defer(PyObject *self, PyObject *other)
123131
*/
124132
attr = PyArray_GetAttrString_SuppressException(other, "__array_ufunc__");
125133
if (attr) {
126-
defer = (attr == Py_None);
134+
defer = !inplace && (attr == Py_None);
127135
Py_DECREF(attr);
128136
return defer;
129137
}
@@ -171,7 +179,16 @@ binop_override_forward_binop_should_defer(PyObject *self, PyObject *other)
171179
#define BINOP_GIVE_UP_IF_NEEDED(m1, m2, slot_expr, test_func) \
172180
do { \
173181
if (BINOP_IS_FORWARD(m1, m2, slot_expr, test_func) && \
174-
binop_override_forward_binop_should_defer((PyObject*)m1, (PyObject*)m2)) { \
182+
binop_should_defer((PyObject*)m1, (PyObject*)m2, 0)) { \
183+
Py_INCREF(Py_NotImplemented); \
184+
return Py_NotImplemented; \
185+
} \
186+
} while (0)
187+
188+
#define INPLACE_GIVE_UP_IF_NEEDED(m1, m2, slot_expr, test_func) \
189+
do { \
190+
if (BINOP_IS_FORWARD(m1, m2, slot_expr, test_func) && \
191+
binop_should_defer((PyObject*)m1, (PyObject*)m2, 1)) { \
175192
Py_INCREF(Py_NotImplemented); \
176193
return Py_NotImplemented; \
177194
} \
@@ -187,7 +204,7 @@ binop_override_forward_binop_should_defer(PyObject *self, PyObject *other)
187204
*/
188205
#define RICHCMP_GIVE_UP_IF_NEEDED(m1, m2) \
189206
do { \
190-
if (binop_override_forward_binop_should_defer((PyObject*)m1, (PyObject*)m2)) { \
207+
if (binop_should_defer((PyObject*)m1, (PyObject*)m2, 0)) { \
191208
Py_INCREF(Py_NotImplemented); \
192209
return Py_NotImplemented; \
193210
} \

0 commit comments

Comments
 (0)
0