8000 Merge pull request #6000 from njsmith/inplace-matmul-error · githubmlai/numpy@23e10e1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 23e10e1

Browse files
committed
Merge pull request numpy#6000 from njsmith/inplace-matmul-error
BUG: Make a @= b error out
2 parents 38d6f09 + 1adcdf7 commit 23e10e1

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,15 @@ array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
405405
0, nb_matrix_multiply);
406406
return PyArray_GenericBinaryFunction(m1, m2, matmul);
407407
}
408+
409+
static PyObject *
410+
array_inplace_matrix_multiply(PyArrayObject *m1, PyObject *m2)
411+
{
412+
PyErr_SetString(PyExc_TypeError,
413+
"In-place matrix multiplication is not (yet) supported. "
414+
"Use 'a = a @ b' instead of 'a @= b'.");
415+
return NULL;
416+
}
408417
#endif
409418

410419
/* Determine if object is a scalar and if so, convert the object
@@ -1092,6 +1101,6 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
10921101
(unaryfunc)array_index, /*nb_index */
10931102
#if PY_VERSION_HEX >= 0x03050000
10941103
(binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/
1095-
(binaryfunc)NULL, /*nb_inplacematrix_multiply*/
1104+
(binaryfunc)array_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/
10961105
#endif
10971106
};

numpy/core/tests/test_multiarray.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4322,6 +4322,19 @@ def __rmatmul__(self, other):
43224322
assert_equal(self.matmul(a, b), "A")
43234323
assert_equal(self.matmul(b, a), "A")
43244324

4325+
def test_matmul_inplace():
4326+
# It would be nice to support in-place matmul eventually, but for now
4327+
# we don't have a working implementation, so better just to error out
4328+
# and nudge people to writing "a = a @ b".
4329+
a = np.eye(3)
4330+
b = np.eye(3)
4331+
assert_raises(TypeError, a.__imatmul__, b)
4332+
import operator
4333+
assert_raises(TypeError, operator.imatmul, a, b)
4334+
# we avoid writing the token `exec` so as not to crash python 2's
4335+
# parser
4336+
exec_ = getattr(builtins, "exec")
4337+
assert_raises(TypeError, exec_, "a @= b", globals(), locals())
43254338

43264339
class TestInner(TestCase):
43274340

0 commit comments

Comments
 (0)
0