8000 ENH: Add a matmul function to multiarray. · numpy/numpy@68304eb · GitHub
[go: up one dir, main page]

Skip to content

Commit 68304eb

Browse files
committed
ENH: Add a matmul function to multiarray.
This is the functional counterpart of the '@' operator that will be available in Python 3.5, with the addition of an `out` keyword. It operates like the dot function except that (1) scalar multiplication is not allowed, and (2) multiplication of arrays with more than 2 dimensions broadcasts. The latter means that when arrays have more than 2 dimensions they are treated as stacks of matrices, and those stacks are broadcast against each other, unlike the current behavior of dot, which does an outer product. Like dot, matmul is aware of `__numpy_ufunc__` and can be overridden. The current version of the function uses einsum when cblas cannot be used, hence object arrays are not yet supported.
1 parent 0174f2a commit 68304eb

File tree

4 files changed

+190
-16
lines changed

4 files changed

+190
-16
lines changed

numpy/core/code_generators/numpy_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@
343343
# End 1.8 API
344344
# End 1.9 API
345345
'PyArray_CheckAnyScalarExact': (300, NonNull(1)),
346+
'PyArray_Matmul': (301,),
346347
# End 1.10 API
347348
}
348349

numpy/core/numeric.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
'Inf', 'inf', 'infty', 'Infinity',
4444
'nan', 'NaN', 'False_', 'True_', 'bitwise_not',
4545
'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS',
46-
'ComplexWarning', 'may_share_memory', 'full', 'full_like']
46+
'ComplexWarning', 'may_share_memory', 'full', 'full_like',
47+
'matmul']
4748

4849
if sys.version_info[0] < 3:
4950
__all__.extend(['getbuffer', 'newbuffer'])
@@ -390,6 +391,11 @@ def extend_all(module):
390391
compare_chararrays = multiarray.compare_chararrays
391392
putmask = multiarray.putmask
392393
einsum = multiarray.einsum
394+
dot = multiarray.dot
395+
inner = multiarray.inner
396+
vdot = multiarray.vdot
397+
matmul = multiarray.matmul
398+
393399

394400
def asarray(a, dtype=None, order=None):
395401
"""
@@ -1081,11 +1087,6 @@ def outer(a, b, out=None):
10811087
b = asarray(b)
10821088
return multiply(a.ravel()[:, newaxis], b.ravel()[newaxis,:], out)
10831089

1084-
# try to import blas optimized dot if available
1085-
envbak = os.environ.copy()
1086-
dot = multiarray.dot
1087-
inner = multiarray.inner
1088-
vdot = multiarray.vdot
10891090

10901091
def alterdot():
10911092
"""

numpy/core/src/multiarray/methods.c

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2316,6 +2316,21 @@ array_newbyteorder(PyArrayObject *self, PyObject *args)
23162316

23172317
}
23182318

2319+
2320+
/*static PyObject **/
2321+
/*array_matmul(PyArrayObject *self, PyObject *rhs)*/
2322+
/*{*/
2323+
/*return PyArray_Matmul(self, rhs, out=NULL);*/
2324+
/*}*/
2325+
2326+
2327+
/*static PyObject **/
2328+
/*array_rmatmul(PyArrayObject *self, PyObject *lhs)*/
2329+
/*{*/
2330+
/*return PyArray_Matmul(lhs, self, out=NULL);*/
2331+
/*}*/
2332+
2333+
23192334
NPY_NO_EXPORT PyMethodDef array_methods[] = {
23202335

23212336
/* for subtypes */
@@ -2342,6 +2357,14 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
23422357
(PyCFunction)array_deepcopy,
23432358
METH_VARARGS, NULL},
23442359

2360+
/* for '@' operator in Python >= 3.5 */
2361+
/*{"__matmul__",*/
2362+
/*(PyCFunction)array_matmul,*/
2363+
/*METH_O, NULL},*/
2364+
/*{"__rmatmul__",*/
2365+
/*(PyCFunction)array_rmatmul,*/
2366+
/*METH_O, NULL},*/
2367+
23452368
/* for Pickling */
23462369
{"__reduce__",
23472370
(PyCFunction) array_reduce,

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 159 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ PyArray_AsCArray(PyObject **op, void *ptr, npy_intp *dims, int nd,
239239
*op = (PyObject *)ap;
240240
return 0;
241241

242-
fail:
242+
fail:
243243
PyErr_SetString(PyExc_MemoryError, "no memory");
244244
return -1;
245245
}
@@ -930,7 +930,7 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
930930
Py_DECREF(ap2);
931931
return (PyObject *)ret;
932932

933-
fail:
933+
fail:
934934
Py_XDECREF(ap1);
935935
Py_XDECREF(ap2);
936936
Py_XDECREF(ret);
@@ -1049,7 +1049,8 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
10491049
goto fail;
10501050
}
10511051

1052-
op = PyArray_DATA(ret); os = PyArray_DESCR(ret)->elsize;
1052+
op = PyArray_DATA(ret);
1053+
os = PyArray_DESCR(ret)->elsize;
10531054
axis = PyArray_NDIM(ap1)-1;
10541055
it1 = (PyArrayIterObject *)
10551056
PyArray_IterAllButAxis((PyObject *)ap1, &axis);
@@ -1083,7 +1084,108 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
10831084
Py_DECREF(ap2);
10841085
return (PyObject *)ret;
10851086

1086-
fail:
1087+
fail:
1088+
Py_XDECREF(ap1);
1089+
Py_XDECREF(ap2);
1090+
Py_XDECREF(ret);
1091+
return NULL;
1092+
}
1093+
1094+
1095+
/*NUMPY_API
1096+
* matmul
1097+
*
1098+
* Implements the protocol used by the '@' operator defined in PEP 465.
1099+
*
1100+
*
1101+
* in1: Left hand side operand
1102+
* in2: Right hand side operand
1103+
* out: Either NULL, or an array into which the output should be placed.
1104+
*
1105+
* Returns NULL on error.
1106+
*/
1107+
NPY_NO_EXPORT PyObject *
1108+
PyArray_Matmul(PyObject *in1, PyObject *in2, PyArrayObject *out)
1109+
{
1110+
PyArrayObject *ap1, *ap2, *ret = NULL;
1111+
NPY_ORDER order = NPY_KEEPORDER;
1112+
NPY_CASTING casting = NPY_SAFE_CASTING;
1113+
PyArray_Descr *dtype;
1114+
int typenum;
1115+
char *subscripts;
1116+
PyArrayObject *ops[2];
1117+
NPY_BEGIN_THREADS_DEF;
1118+
1119+
dtype = PyArray_DescrFromObject(in1, NULL);
1120+
dtype = PyArray_DescrFromObject(in2, dtype);
1121+
if (dtype == NULL) {
1122+
PyErr_SetString(PyExc_ValueError,
1123+
"Cannot find a common data type.");
1124+
return NULL;
1125+
}
1126+
1127+
Py_INCREF(dtype);
1128+
ap1 = (PyArrayObject *)PyArray_FromAny(in1, dtype, 0, 0,
1129+
NPY_ARRAY_ALIGNED, NULL);
1130+
if (ap1 == NULL) {
1131+
Py_DECREF(dtype);
1132+
return NULL;
1133+
}
1134+
ap2 = (PyArrayObject *)PyArray_FromAny(in2, dtype, 0, 0,
1135+
NPY_ARRAY_ALIGNED, NULL);
1136+
if (ap2 == NULL) {
1137+
Py_DECREF(ap1);
1138+
return NULL;
1139+
}
1140+
1141+
if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
1142+
/* Scalars are rejected */
1143+
PyErr_SetString(PyExc_TypeError,
1144+
"Scalar operands are not allowed, use '*' instead");
1145+
return NULL;
1146+
}
1147+
1148+
typenum = dtype->type_num;
1149+
#if defined(HAVE_CBLAS)
1150+
if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&
1151+
(NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
1152+
NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
1153+
return cblas_matrixproduct(typenum, ap1, ap2, out);
1154+
}
1155+
#endif
1156+
1157+
/*
1158+
* Use einsum for the stacked cases. This is a quick implementation
1159+
* to avoid setting up the proper iterators.
1160+
*/
1161+
ops[0] = ap1;
1162+
ops[1] = ap2;
1163+
if (PyArray_NDIM(ap1) == 1 && PyArray_NDIM(ap2) == 1) {
1164+
/* vector vector */
1165+
subscripts = "i, i";
1166+
}
1167+
else if (PyArray_NDIM(ap1) == 1) {
1168+
/* vector matrix */
1169+
subscripts = "i, ...ij";
1170+
}
1171+
else if (PyArray_NDIM(ap2) == 1) {
1172+
/* matrix vector */
1173+
subscripts = "...i, i";
1174+
}
1175+
else {
1176+
/* matrix * matrix */
1177+
subscripts = "...ij, ...jk";
1178+
}
1179+
1180+
ret = PyArray_EinsteinSum(subscripts, 2, ops, dtype, order, casting, out);
1181+
if (ret == NULL) {
1182+
goto fail;
1183+
}
1184+
Py_DECREF(ap1);
1185+
Py_DECREF(ap2);
1186+
return (PyObject *)ret;
1187+
1188+
fail:
10871189
Py_XDECREF(ap1);
10881190
Py_XDECREF(ap2);
10891191
Py_XDECREF(ret);
@@ -1844,7 +1946,7 @@ array_copyto(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
18441946

18451947
Py_RETURN_NONE;
18461948

1847-
fail:
1949+
fail:
18481950
Py_XDECREF(src);
18491951
Py_XDECREF(wheremask);
18501952
return NULL;
@@ -1887,7 +1989,7 @@ array_empty(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
18871989
PyDimMem_FREE(shape.ptr);
18881990
return (PyObject *)ret;
18891991

1890-
fail:
1992+
fail:
18911993
Py_XDECREF(typecode);
18921994
PyDimMem_FREE(shape.ptr);
18931995
return NULL;
@@ -1918,7 +2020,7 @@ array_empty_like(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
19182020

19192021
return (PyObject *)ret;
19202022

1921-
fail:
2023+
fail:
19222024
Py_XDECREF(prototype);
19232025
Py_XDECREF(dtype);
19242026
return NULL;
@@ -2041,7 +2143,7 @@ array_zeros(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
20412143
PyDimMem_FREE(shape.ptr);
20422144
return (PyObject *)ret;
20432145

2044-
fail:
2146+
fail:
20452147
Py_XDECREF(typecode);
20462148
PyDimMem_FREE(shape.ptr);
20472149
return (PyObject *)ret;
@@ -2369,6 +2471,50 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
23692471
}
23702472

23712473

2474+
/*
 * Python-level entry point for ``multiarray.matmul(a, b, out=None)``.
 *
 * Steps, in order:
 *   1. On first call, look up and cache the module's own ``matmul`` object
 *      so it can be handed to the override machinery.
 *   2. Ask PyUFunc_CheckOverride whether an operand wants to take over the
 *      call; if so, return whatever the override produced.
 *   3. Parse ``a``, ``b`` and the optional ``out`` argument. ``out=None`` is
 *      treated the same as omitting it; any other non-array raises TypeError.
 *   4. Delegate the actual product to PyArray_Matmul.
 *
 * Returns a new reference on success, or NULL with a Python exception set.
 */
static PyObject *
array_matmul(PyObject *NPY_UNUSED(m), PyObject *args, PyObject* kwds)
{
    /*
     * NOTE(review): ``matmul`` is registered as a plain module function, not
     * a ufunc; the cast below only satisfies PyUFunc_CheckOverride's
     * signature -- confirm the callee never dereferences ufunc fields.
     */
    static PyUFuncObject *cached_npy_matmul = NULL;
    int errval;
    PyObject *override = NULL;
    PyObject *v, *a, *o = NULL;
    char* kwlist[] = {"a", "b", "out", NULL };

    if (cached_npy_matmul == NULL) {
        PyObject *module, *matmul;

        module = PyImport_ImportModule("numpy.core.multiarray");
        if (module == NULL) {
            /* Import failed; the import machinery already set the error. */
            return NULL;
        }
        /* New reference on success; sets AttributeError on failure. */
        matmul = PyObject_GetAttrString(module, "matmul");
        Py_DECREF(module);
        if (matmul == NULL) {
            return NULL;
        }
        /* The cache owns this reference for the life of the process. */
        cached_npy_matmul = (PyUFuncObject *)matmul;
    }

    /* The trailing 2 is presumably the number of inputs -- TODO confirm. */
    errval = PyUFunc_CheckOverride(cached_npy_matmul, "__call__",
                                   args, kwds, &override, 2);
    if (errval) {
        return NULL;
    }
    else if (override) {
        return override;
    }

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O", kwlist, &a, &v, &o)) {
        return NULL;
    }
    if (o == Py_None) {
        o = NULL;
    }
    if (o != NULL && !PyArray_Check(o)) {
        PyErr_SetString(PyExc_TypeError,
                        "'out' must be an array");
        return NULL;
    }
    return PyArray_Matmul(a, v, (PyArrayObject *)o);
}
2516+
2517+
23722518
static int
23732519
einsum_sub_op_from_str(PyObject *args, PyObject **str_obj, char **subscripts,
23742520
PyArrayObject **op)
@@ -2862,7 +3008,7 @@ array__reconstruct(PyObject *NPY_UNUSED(dummy), PyObject *args)
28623008

28633009
return ret;
28643010

2865-
fail:
3011+
fail:
28663012
evil_global_disable_warn_O4O8_flag = 0;
28673013

28683014
Py_XDECREF(dtype);
@@ -3090,7 +3236,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
30903236
return ret;
30913237
}
30923238

3093-
fail:
3239+
fail:
30943240
Py_DECREF(arr);
30953241
Py_XDECREF(ax);
30963242
Py_XDECREF(ay);
@@ -3936,6 +4082,9 @@ static struct PyMethodDef array_module_methods[] = {
39364082
{"vdot",
39374083
(PyCFunction)array_vdot,
39384084
METH_VARARGS | METH_KEYWORDS, NULL},
4085+
{"matmul",
4086+
(PyCFunction)array_matmul,
4087+
METH_VARARGS | METH_KEYWORDS, NULL},
39394088
{"einsum",
39404089
(PyCFunction)array_einsum,
39414090
METH_VARARGS|METH_KEYWORDS, NULL},

0 commit comments

Comments
 (0)
0