10000 ENH: Add support for the "@" matrix multiply operator to ndarray. · numpy/numpy@f347bae · GitHub
[go: up one dir, main page]

Skip to content

Commit f347bae

Browse files
committed
ENH: Add support for the "@" matrix multiply operator to ndarray.
Also simplify the code for the ndarray.dot method.
1 parent 68304eb commit f347bae

File tree

4 files changed

+66
-77
lines changed

4 files changed

+66
-77
lines changed

numpy/core/src/multiarray/methods.c

Lines changed: 24 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "item_selection.h"
1919
#include "conversion_utils.h"
2020
#include "shape.h"
21+
#include "multiarraymodule.h"
2122

2223
#include "methods.h"
2324

@@ -1996,58 +1997,6 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
19961997
}
19971998

19981999

1999-
static PyObject *
2000-
array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
2001-
{
2002-
static PyUFuncObject *cached_npy_dot = NULL;
2003-
int errval;
2004-
PyObject *override = NULL;
2005-
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2006-
PyObject *newargs;
2007-
PyArrayObject *ret;
2008-
char* kwlist[] = {"b", "out", NULL };
2009-
2010-
2011-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2012-
return NULL;
2013-
}
2014-
2015-
if (cached_npy_dot == NULL) {
2016-
PyObject *module = PyImport_ImportModule("numpy.core.multiarray");
2017-
cached_npy_dot = (PyUFuncObject*)PyDict_GetItemString(
2018-
PyModule_GetDict(module), "dot");
2019-
2020-
Py_INCREF(cached_npy_dot);
2021-
Py_DECREF(module);
2022-
}
2023-
2024-
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2025-
return NULL;
2026-
}
2027-
errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__",
2028-
newargs, NULL, &override, 2);
2029-
Py_DECREF(newargs);
2030-
2031-
if (errval) {
2032-
return NULL;
2033-
}
2034-
else if (override) {
2035-
return override;
2036-
}
2037-
2038-
if (o == Py_None) {
2039-
o = NULL;
2040-
}
2041-
if (o != NULL && !PyArray_Check(o)) {
2042-
PyErr_SetString(PyExc_TypeError,
2043-
"'out' must be an array");
2044-
return NULL;
2045-
}
2046-
ret = (PyArrayObject *)PyArray_MatrixProduct2(a, b, (PyArrayObject *)o);
2047-
return PyArray_Return(ret);
2048-
}
2049-
2050-
20512000
static PyObject *
20522001
array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
20532002
{
@@ -2317,18 +2266,31 @@ array_newbyteorder(PyArrayObject *self, PyObject *args)
23172266
}
23182267

23192268

2320-
/*static PyObject **/
2321-
/*array_matmul(PyArrayObject *self, PyObject *rhs)*/
2322-
/*{*/
2323-
/*return PyArray_Matmul(self, rhs, out=NULL);*/
2324-
/*}*/
2269+
static PyObject *
2270+
array_dot(PyObject *self, PyObject *args, PyObject *kwds)
2271+
{
2272+
static PyObject *newkwargs = NULL;
2273+
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2274+
PyArrayObject *ret;
2275+
PyObject *newargs;
2276+
char* kwlist[] = {"b", "out", NULL };
23252277

23262278

2327-
/*static PyObject **/
2328-
/*array_rmatmul(PyArrayObject *self, PyObject *lhs)*/
2329-
/*{*/
2330-
/*return PyArray_Matmul(lhs, self, out=NULL);*/
2331-
/*}*/
2279+
if (newkwargs == NULL) {
2280+
if ((newkwargs = PyDict_New()) == NULL) {
2281+
return NULL;
2282+
}
2283+
}
2284+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2285+
return NULL;
2286+
}
2287+
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2288+
return NULL;
2289+
}
2290+
ret = array_matrixproduct(NULL, newargs, newkwargs);
2291+
Py_DECREF(newargs);
2292+
return ret;
2293+
}
23322294

23332295

23342296
NPY_NO_EXPORT PyMethodDef array_methods[] = {
@@ -2357,14 +2319,6 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
23572319
(PyCFunction)array_deepcopy,
23582320
METH_VARARGS, NULL},
23592321

2360-
/* for '@' operator in Python >= 3.5 */
2361-
/*{"__matmul__",*/
2362-
/*(PyCFunction)array_matmul,*/
2363-
/*METH_O, NULL},*/
2364-
/*{"__rmatmul__",*/
2365-
/*(PyCFunction)array_rmatmul,*/
2366-
/*METH_O, NULL},*/
2367-
23682322
/* for Pickling */
23692323
{"__reduce__",
23702324
(PyCFunction) array_reduce,

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,7 @@ PyArray_Matmul(PyObject *in1, PyObject *in2, PyArrayObject *out)
11401140

11411141
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
11421142
/* Scalars are rejected */
1143-
PyErr_SetString(PyExc_TypeError,
1143+
PyErr_SetString(PyExc_ValueError,
11441144
"Scalar operands are not allowed, use '*' instead");
11451145
return NULL;
11461146
}
@@ -2318,7 +2318,7 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
23182318
return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(a0, b0));
23192319
}
23202320

2321-
static PyObject *
2321+
NPY_NO_EXPORT PyObject *
23222322
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds)
23232323
{
23242324
static PyUFuncObject *cached_npy_dot = NULL;
@@ -2471,7 +2471,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
24712471
}
24722472

24732473

2474-
static PyObject *
2474+
NPY_NO_EXPORT PyObject *
24752475
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
24762476
{
24772477
static PyUFuncObject *cached_npy_matmul = NULL;

numpy/core/src/multiarray/multiarraymodule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy;
1212
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
1313
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ndmin;
1414

15+
NPY_NO_EXPORT PyObject *
16+
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds);
17+
18+
NPY_NO_EXPORT PyObject *
19+
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds);
20+
1521
#endif

numpy/core/src/multiarray/number.c

Lines changed: 33 additions & 4 deletions
E377
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
#include "numpy/arrayobject.h"
99

1010
#include "npy_config.h"
11-
1211
#include "npy_pycompat.h"
13-
14-
#include "number.h"
1512
#include "common.h"
13+
#include "multiarraymodule.h"
14+
#include "number.h"
1615

1716
/*************************************************************************
1817
**************** Implement Number Protocol ****************************
@@ -386,6 +385,31 @@ array_remainder(PyArrayObject *m1, PyObject *m2)
386385
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
387386
}
388387

388+
389+
#if PY_VERSION_HEX >= 0x03050000
390+
/* Need this to be version dependent on account of the slot check */
391+
static PyObject *
392+
array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
393+
{
394+
static PyObject *newkwargs = NULL;
395+
PyObject *newargs;
396+
PyObject *ret;
397+
398+
if (newkwargs == NULL) {
399+
if ((newkwargs = PyDict_New()) == NULL) {
400+
return NULL;
401+
}
402+
}
403+
GIVE_UP_IF_HAS_RIGHT_BINOP(m1, m2, "__matmul__", "__rmatmul__", 0, nb_matrix_multiply);
404+
if ((newargs = PyTuple_Pack(2, m1, m2)) == NULL) {
405+
return NULL;
406+
}
407+
ret = array_matmul(NULL, newargs, newkwargs);
408+
Py_DECREF(newargs);
409+
return ret;
410+
}
411+
#endif
412+
389413
/* Determine if object is a scalar and if so, convert the object
390414
* to a double and place it in the out_exponent argument
391415
* and return the "scalar kind" as a result. If the object is
@@ -723,6 +747,7 @@ array_inplace_true_divide(PyArrayObject *m1, PyObject *m2)
723747
n_ops.true_divide);
724748
}
725749

750+
726751
static int
727752
_array_nonzero(PyArrayObject *mp)
728753
{
@@ -1066,5 +1091,9 @@ NPY_NO_EXPORT PyNumberMethods array_as_number = {
10661091
(binaryfunc)array_true_divide, /*nb_true_divide*/
10671092
(binaryfunc)array_inplace_floor_divide, /*nb_inplace_floor_divide*/
10681093
(binaryfunc)array_inplace_true_divide, /*nb_inplace_true_divide*/
1069-
(unaryfunc)array_index, /* nb_index */
1094+
(unaryfunc)array_index, /*nb_index */
1095+
#if PY_VERSION_HEX >= 0x03050000
1096+
(binaryfunc)array_matrix_multiply, /*nb_matrix_multiply*/
1097+
(binaryfunc)NULL, /*nb_inplacematrix_multiply*/
1098+
#endif
10701099
};

0 commit comments

Comments
 (0)
0