8000 ENH: Add __matmul__ and __rmatmul__ methods to ndarray. · numpy/numpy@0c9e8fd · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c9e8fd

Browse files
committed
ENH: Add __matmul__ and __rmatmul__ methods to ndarray.
Also simplify the code for the ndarray.dot method.
1 parent 5ae0c1a commit 0c9e8fd

File tree

3 files changed

+80
-70
lines changed

3 files changed

+80
-70
lines changed

numpy/core/src/multiarray/methods.c

Lines changed: 72 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "item_selection.h"
1919
#include "conversion_utils.h"
2020
#include "shape.h"
21+
#include "multiarraymodule.h"
2122

2223
#include "methods.h"
2324

@@ -1996,58 +1997,6 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
19961997
}
19971998

19981999

1999-
static PyObject *
2000-
array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds)
2001-
{
2002-
static PyUFuncObject *cached_npy_dot = NULL;
2003-
int errval;
2004-
PyObject *override = NULL;
2005-
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2006-
PyObject *newargs;
2007-
PyArrayObject *ret;
2008-
char* kwlist[] = {"b", "out", NULL };
2009-
2010-
2011-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2012-
return NULL;
2013-
}
2014-
2015-
if (cached_npy_dot == NULL) {
2016-
PyObject *module = PyImport_ImportModule("numpy.core.multiarray");
2017-
cached_npy_dot = (PyUFuncObject*)PyDict_GetItemString(
2018-
PyModule_GetDict(module), "dot");
2019-
2020-
Py_INCREF(cached_npy_dot);
2021-
Py_DECREF(module);
2022-
}
2023-
2024-
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2025-
return NULL;
2026-
}
2027-
errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__",
2028-
newargs, NULL, &override, 2);
2029-
Py_DECREF(newargs);
2030-
2031-
if (errval) {
2032-
return NULL;
2033-
}
2034-
else if (override) {
2035-
return override;
2036-
}
2037-
2038-
if (o == Py_None) {
2039-
o = NULL;
2040-
}
2041-
if (o != NULL && !PyArray_Check(o)) {
2042-
PyErr_SetString(PyExc_TypeError,
2043-
"'out' must be an array");
2044-
return NULL;
2045-
}
2046-
ret = (PyArrayObject *)PyArray_MatrixProduct2(a, b, (PyArrayObject *)o);
2047-
return PyArray_Return(ret);
2048-
}
2049-
2050-
20512000
static PyObject *
20522001
array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
20532002
{
@@ -2317,18 +2266,73 @@ array_newbyteorder(PyArrayObject *self, PyObject *args)
23172266
}
23182267

23192268

2320-
/*static PyObject **/
2321-
/*array_matmul(PyArrayObject *self, PyObject *rhs)*/
2322-
/*{*/
2323-
/*return PyArray_Matmul(self, rhs, out=NULL);*/
2324-
/*}*/
2269+
static PyObject *
2270+
array_dot(PyObject *self, PyObject *args, PyObject *kwds)
2271+
{
2272+
static PyObject *newkwargs = NULL;
2273+
PyObject *a = (PyObject *)self, *b, *o = Py_None;
2274+
PyArrayObject *ret;
2275+
PyObject *newargs;
2276+
char* kwlist[] = {"b", "out", NULL };
2277+
2278+
2279+
if (newkwargs == NULL) {
2280+
if ((newkwargs = PyDict_New()) == NULL) {
2281+
return NULL;
2282+
}
2283+
}
2284+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) {
2285+
return NULL;
2286+
}
2287+
if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) {
2288+
return NULL;
2289+
}
2290+
ret = array_matrixproduct(NULL, newargs, newkwargs);
2291+
Py_DECREF(newargs);
2292+
return ret;
2293+
}
2294+
2295+
2296+
static PyObject *
2297+
array__matmul__(PyObject *self, PyObject *rhs)
2298+
{
2299+
static PyObject *newkwargs = NULL;
2300+
PyObject *newargs;
2301+
PyArrayObject *ret;
2302+
2303+
if (newkwargs == NULL) {
2304+
if ((newkwargs = PyDict_New()) == NULL) {
2305+
return NULL;
2306+
}
2307+
}
2308+
if ((newargs = PyTuple_Pack(2, self, rhs)) == NULL) {
2309+
return NULL;
2310+
}
2311+
ret = array_matmul(NULL, newargs, newkwargs);
2312+
Py_DECREF(newargs);
2313+
return ret;
2314+
}
2315+
23252316

2317+
static PyObject *
2318+
array__rmatmul__(PyObject *self, PyObject *lhs)
2319+
{
2320+
static PyObject *newkwargs = NULL;
2321+
PyObject *newargs;
2322+
PyArrayObject *ret;
23262323

2327-
/*static PyObject **/
2328-
/*array_rmatmul(PyArrayObject *self, PyObject *lhs)*/
2329-
/*{*/
2330-
/*return PyArray_Matmul(lhs, self, out=NULL);*/
2331-
/*}*/
2324+
if (newkwargs == NULL) {
2325+
if ((newkwargs = PyDict_New()) == NULL) {
2326+
return NULL;
2327+
}
2328+
}
2329+
if ((newargs = PyTuple_Pack(2, lhs, self)) == NULL) {
2330+
return NULL;
2331+
}
2332+
ret = array_matmul(NULL, newargs, newkwargs);
2333+
Py_DECREF(newargs);
2334+
return ret;
2335+
}
23322336

23332337

23342338
NPY_NO_EXPORT PyMethodDef array_methods[] = {
@@ -2358,12 +2362,12 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
23582362
METH_VARARGS, NULL},
23592363

23602364
/* for '@' operator in Python >= 3.5 */
2361-
/*{"__matmul__",*/
2362-
/*(PyCFunction)array_matmul,*/
2363-
/*METH_O, NULL},*/
2364-
/*{"__rmatmul__",*/
2365-
/*(PyCFunction)array_rmatmul,*/
2366-
/*METH_O, NULL},*/
2365+
{"__matmul__",
2366+
(PyCFunction)array__matmul__,
2367+
METH_O, NULL},
2368+
{"__rmatmul__",
2369+
(PyCFunction)array__rmatmul__,
2370+
METH_O, NULL},
23672371

23682372
/* for Pickling */
23692373
{"__reduce__",

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,7 +2318,7 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
23182318
return PyArray_Return((PyArrayObject *)PyArray_InnerProduct(a0, b0));
23192319
}
23202320

2321-
static PyObject *
2321+
NPY_NO_EXPORT PyObject *
23222322
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds)
23232323
{
23242324
static PyUFuncObject *cached_npy_dot = NULL;
@@ -2471,7 +2471,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
24712471
}
24722472

24732473

2474-
static PyObject *
2474+
NPY_NO_EXPORT PyObject *
24752475
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
24762476
{
24772477
static PyUFuncObject *cached_npy_matmul = NULL;

numpy/core/src/multiarray/multiarraymodule.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,10 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy;
1212
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
1313
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ndmin;
1414

15+
NPY_NO_EXPORT PyObject *
16+
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds);
17+
18+
NPY_NO_EXPORT PyObject *
19+
array_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwds);
20+
1521
#endif

0 commit comments

Comments
 (0)
0