8000 ENH: Add support for the "@" matrix multiply operator to ndarray. · numpy/numpy@f981ef2 · GitHub
[go: up one dir, main page]

Skip to content

Commit f981ef2

Browse files
committed
ENH: Add support for the "@" matrix multiply operator to ndarray.
Also simplify the code for the ndarray.dot method.
1 parent 68304eb commit f981ef2

File tree

5 files changed

+168
-178
lines changed

5 files changed

+168
-178
lines changed

numpy/core/code_generators/numpy_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,6 @@
343343
# End 1.8 API
344344
# End 1.9 API
345345
'PyArray_CheckAnyScalarExact': (300, NonNull(1)),
346-
'PyArray_Matmul': (301,),
347346
# End 1.10 API
348347
}
349348

numpy/core/src/multiarray/methods.c

Lines changed: 24 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "item_selection.h"
1919
#include "conversion_utils.h"
2020
#include "shape.h"
21+
#include "multiarraymodule.h"
2122

2223
#include "methods.h"
2324

@@ -1996,58 +1997,6 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
19961997
}
19971998

19981999

1999-
static PyObject *
2000-
array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
2001-
{
2002-
static PyUFuncObject *cached_npy_dot = NULL;
2003-
int errval;
2004-
PyObject *override = NULL;
2005-
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2006-
PyObject *newargs;
2007-
PyArrayObject *ret;
2008-
char* kwlist[] = {"b", "out", NULL };
2009-
2010-
2011-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2012-
return NULL;
2013-
}
2014-
2015-
if (cached_npy_dot == NULL) {
2016-
PyObject *module = PyImport_ImportModule("numpy.core.multiarray");
2017-
cached_npy_dot = (PyUFuncObject*)PyDict_GetItemString(
2018-
PyModule_GetDict(module), "dot");
2019-
2020-
Py_INCREF(cached_npy_dot);
2021-
Py_DECREF(module);
2022-
}
2023-
2024-
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2025-
return NULL;
2026-
}
2027-
errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__",
2028-
newargs, NULL, &override, 2);
2029-
Py_DECREF(newargs);
2030-
2031-
if (errval) {
2032-
return NULL;
2033-
}
2034-
else if (override) {
2035-
return override;
2036-
}
2037-
2038-
if (o == Py_None) {
2039-
o = NULL;
2040-
}
2041-
if (o != NULL && !PyArray_Check(o)) {
2042-
PyErr_SetString(PyExc_TypeError,
2043-
"'out' must be an array");
2044-
return NULL;
2045-
}
2046-
ret = (PyArrayObject *)PyArray_MatrixProduct2(a, b, (PyArrayObject *)o);
2047-
return PyArray_Return(ret);
2048-
}
2049-
2050-
20512000
static PyObject *
20522001
array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
20532002
{
@@ -2317,18 +2266,31 @@ array_newbyteorder(PyArrayObject *self, PyObject *args)
23172266
}
23182267

23192268

2320-
/*static PyObject **/
2321-
/*array_matmul(PyArrayObject *self, PyObject *rhs)*/
2322-
/*{*/
2323-
/*return PyArray_Matmul(self, rhs, out=NULL);*/
2324-
/*}*/
2269+
static PyObject *
2270+
array_dot(PyObject *self, PyObject *args, PyObject *kwds)
2271+
{
2272+
static PyObject *newkwargs = NULL;
2273+
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2274+
PyArrayObject *ret;
2275+
PyObject *newargs;
2276+
char* kwlist[] = {"b", "out", NULL };
23252277

23262278

2327-
/*static PyObject **/
2328-
/*array_rmatmul(PyArrayObject *self, PyObject *lhs)*/
2329-
/*{*/
2330-
/*return PyArray_Matmul(lhs, self, out=NULL);*/
2331-
/*}*/
2279+
if (newkwargs == NULL) {
2280+
if ((newkwargs = PyDict_New()) == NULL) {
2281+
return NULL;
2282+
}
2283+
}
2284+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2285+
return NULL;
2286+
}
2287+
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2288+
return NULL;
2289+
}
2290+
ret = array_matrixproduct(NULL, newargs, newkwargs);
2291+
Py_DECREF(newargs);
2292+
return ret;
2293+
}
23322294

23332295

23342296
NPY_NO_EXPORT PyMethodDef array_methods[] = {
@@ -2357,14 +2319,6 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
23572319
(PyCFunction)array_deepcopy,
23582320
METH_VARARGS, NULL},
23592321

2360-
/* for '@' operator in Python >= 3.5 */
2361-
/*{"__matmul__",*/
2362-
/*(PyCFunction)array_matmul,*/
2363-
/*METH_O, NULL},*/
2364-
/*{"__rmatmul__",*/
2365-
/*(PyCFunction)array_rmatmul,*/
2366-
/*METH_O, NULL},*/
2367-
23682322
/* for Pickling */
23692323
{"__reduce__",
23702324
(PyCFunction) array_reduce,

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 105 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,109 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
6464
/* Only here for API compatibility */
6565
NPY_NO_EXPORT PyTypeObject PyBigArray_Type;
6666

67+
68+
/*
69+
* matmul
70+
*
71+
* Implements the protocol used by the '@' operator defined in PEP 364.
72+
* Not in the NUMPY API at this time, maybe later.
73+
*
74+
*
75+
* in1: Left hand side operand
76+
* in2: Right hand side operand
77+
* out: Either NULL, or an array into which the output should be placed.
78+
*
79+
* Returns NULL on error.
80+
*/
81+
static PyObject *
82+
PyArray_Matmul(PyObject *in1, PyObject *in2, PyArrayObject *out)
83+
{
84+
PyArrayObject *ap1, *ap2, *ret = NULL;
85+
NPY_ORDER order = NPY_KEEPORDER;
86+
NPY_CASTING casting = NPY_SAFE_CASTING;
87+
PyArray_Descr *dtype;
88+
int typenum;
89+
char *subscripts;
90+
PyArrayObject *ops[2];
91+
NPY_BEGIN_THREADS_DEF;
92+
93+
dtype = PyArray_DescrFromObject(in1, NULL);
94+
dtype = PyArray_DescrFromObject(in2, dtype);
95+
if (dtype == NULL) {
96+
PyErr_SetString(PyExc_ValueError,
97+
"Cannot find a common data type.");
98+
return NULL;
99+
}
100+
101+
Py_INCREF(dtype);
102+
ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0,
103+
NPY_ARRAY_ALIGNED, NULL);
104+
if (ap1 == NULL) {
105+
Py_DECREF(dtype);
106+
return NULL;
107+
}
108+
ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0,
109+
NPY_ARRAY_ALIGNED, NULL);
110+
if (ap2 == NULL) {
111+
Py_DECREF(ap1);
112+
return NULL;
113+
}
114+
115+
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
116+
/* Scalars are rejected */
117+
PyErr_SetString(PyExc_ValueError,
118+
"Scalar operands are not allowed, use '*' instead");
119+
return NULL;
120+
}
121+
122+
typenum = dtype->type_num;
123+
#if defined(HAVE_CBLAS)
124+
if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&
125+
(NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
126+
NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
127+
return cblas_matrixproduct(typenum, ap1, ap2, out);
128+
}
129+
#endif
130+
131+
/*
132+
* Use einsum for the stacked cases. This is a quick implementation
133+
* to avoid setting up the proper iterators.
134+
*/
135+
ops[0] = ap1;
136+
ops[1] = ap2;
137+
if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
138+
/* vector vector */
139+
subscripts = "i, i";
140+
}
141+
else if (PyArray_NDIM(ap1) == 1) {
142+
/* vector matrix */
143+
subscripts = "i, ...ij";
144+
}
145+
else if (PyArray_NDIM(ap2) == 1) {
146+
/* matrix vector */
147+
subscripts = "...i, i";
148+
}
149+
else {
150+
/* matrix * matrix */
151+
subscripts = "...ij, ...jk";
152+
}
153+
154+
ret = PyArray_EinsteinSum(subscripts, 2, ops, dtype, order, casting, out);
155+
if (ret == NULL) {
156+
goto fail;
157+
}
158+
Py_DECREF(ap1);
159+
Py_DECREF(ap2);
160+
return (PyObject *)ret;
161+
162+
fail:
163+
Py_XDECREF(ap1);
164+
Py_XDECREF(ap2);
165+
Py_XDECREF(ret);
166+
return NULL;
167+
}
168+
169+
67170
/*NUMPY_API
68171
* Get Priority from object
69172
*/
@@ -1092,107 +1195,6 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
10921195
}
10931196

10941197

1095-
/*NUMPY_API
1096-
* matmul
1097-
*
1098-
* Implements the protocol used by the '@' operator defined in PEP 364.
1099-
*
1100-
*
1101-
* in1: Left hand side operand
1102-
* in2: Right hand side operand
1103-
* out: Either NULL, or an array into which the output should be placed.
1104-
*
1105-
* Returns NULL on error.
1106-
*/
1107-
NPY_NO_EXPORT PyObject *
1108-
PyArray_Matmul(PyObject *in1, PyObject *in2, PyArrayObject *out)
1109-
{
1110-
PyArrayObject *ap1, *ap2, *ret = NULL;
1111-
NPY_ORDER order = NPY_KEEPORDER;
1112-
NPY_CASTING casting = NPY_SAFE_CASTING;
1113-
PyArray_Descr *dtype;
1114-
int typenum;
1115-
char *subscripts;
1116-
PyArrayObject *ops[2];
1117-
NPY_BEGIN_THREADS_DEF;
1118-
1119-
dtype = PyArray_DescrFromObject(in1, NULL);
1120-
dtype = PyArray_DescrFromObject(in2, dtype);
1121-
if (dtype == NULL) {
1122-
PyErr_SetString(PyExc_ValueError,
1123-
"Cannot find a common data type.");
1124-
return NULL;
1125-
}
1126-
1127-
Py_INCREF(dtype);
1128-
ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0,
1129-
NPY_ARRAY_ALIGNED, NULL);
1130-
if (ap1 == NULL) {
1131-
Py_DECREF(dtype);
1132-
return NULL;
1133-
}
1134-
ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0,
1135-
NPY_ARRAY_ALIGNED, NULL);
1136-
if (ap2 == NULL) {
1137-
Py_DECREF(ap1);
1138-
return NULL;
1139-
}
1140-
1141-
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
1142-
/* Scalars are rejected */
1143-
PyErr_SetString(PyExc_TypeError,
1144-
"Scalar operands are not allowed, use '*' instead");
1145-
return NULL;
1146-
}
1147-
1148-
typenum = dtype->type_num;
1149-
#if defined(HAVE_CBLAS)
1150-
if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&
1151-
(NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
1152-
NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
1153-
return cblas_matrixproduct(typenum, ap1, ap2, out);
1154-
}
1155-
#endif
1156-
1157-
/*
1158-
* Use einsum for the stacked cases. This is a quick implementation
1159-
* to avoid setting up the proper iterators.
1160-
*/
1161-
ops[0] = ap1;
1162-
ops[1] = ap2;
1163-
if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
1164-
/* vector vector */
1165-
subscripts = "i, i";
1166-
}
1167-
else if (PyArray_NDIM(ap1) == 1) {
1168-
/* vector matrix */
1169-
subscripts = "i, ...ij";
1170-
}
1171-
else if (PyArray_NDIM(ap2) == 1) {
1172-
/* matrix vector */
1173-
subscripts = "...i, i";
1174-
}
1175-
else {
1176-
/* matrix * matrix */
1177-
subscripts = "...ij, ...jk";
1178-
}
1179-
1180-
ret = PyArray_EinsteinSum(subscripts, 2, ops, dtype, order, casting, out);
1181-
if (ret == NULL) {
1182-
goto fail;
1183-
}
1184-
Py_DECREF(ap1);
1185-
Py_DECREF(ap2);
1186-
return (PyObject *)ret;
1187-
1188-
fail:
1189-
Py_XDECREF(ap1);
1190-
Py_XDECREF(ap2);
1191-
Py_XDECREF(ret);
1192-
return NULL;
1193-
}
1194-
1195-
11961198
/*NUMPY_API
11971199
* Copy and Transpose
11981200
*
@@ -2318,7 +2320,7 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
23182320
return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(a0, b0));
23192321
}
23202322

2321-
static PyObject *
2323+
NPY_NO_EXPORT PyObject *
23222324
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds)
23232325
{
23242326
static PyUFuncObject *cached_npy_dot = NULL;
@@ -2471,7 +2473,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
24712473
}
24722474

24732475

2474-
static PyObject *
2476+
NPY_NO_EXPORT PyObject *
24752477
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
24762478
{
24772479
static PyUFuncObject *cached_npy_matmul = NULL;

numpy/core/src/multiarray/multiarraymodule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy;
1212
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
1313
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ndmin;
1414

15+
NPY_NO_EXPORT PyObject *
16+
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds);
17+
18+
NPY_NO_EXPORT PyObject *
19+
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds);
20+
1521
#endif

0 commit comments

Comments
 (0)
0