@@ -239,7 +239,7 @@ PyArray_AsCArray(PyObject **op, void *ptr, npy_intp *dims, int nd,
239
239
* op = (PyObject * )ap ;
240
240
return 0 ;
241
241
242
- fail :
242
+ fail :
243
243
PyErr_SetString (PyExc_MemoryError , "no memory" );
244
244
return -1 ;
245
245
}
@@ -930,7 +930,7 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
930
930
Py_DECREF (ap2 );
931
931
return (PyObject * )ret ;
932
932
933
- fail :
933
+ fail :
934
934
Py_XDECREF (ap1 );
935
935
Py_XDECREF (ap2 );
936
936
Py_XDECREF (ret );
@@ -1049,7 +1049,8 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
1049
1049
goto fail ;
1050
1050
}
1051
1051
1052
- op = PyArray_DATA (ret ); os = PyArray_DESCR (ret )-> elsize ;
1052
+ op = PyArray_DATA (ret );
1053
+ os = PyArray_DESCR (ret )-> elsize ;
1053
1054
axis = PyArray_NDIM (ap1 )- 1 ;
1054
1055
it1 = (PyArrayIterObject * )
1055
1056
PyArray_IterAllButAxis ((PyObject * )ap1 , & axis );
@@ -1083,7 +1084,108 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
1083
1084
Py_DECREF (ap2 );
1084
1085
return (PyObject * )ret ;
1085
1086
1086
- fail :
1087
+ fail :
1088
+ Py_XDECREF (ap1 );
1089
+ Py_XDECREF (ap2 );
1090
+ Py_XDECREF (ret );
1091
+ return NULL ;
1092
+ }
1093
+
1094
+
1095
+ /*NUMPY_API
1096
+ * matmul
1097
+ *
1098
+ * Implements the protocol used by the '@' operator defined in PEP 364.
1099
+ *
1100
+ *
1101
+ * in1: Left hand side operand
1102
+ * in2: Right hand side operand
1103
+ * out: Either NULL, or an array into which the output should be placed.
1104
+ *
1105
+ * Returns NULL on error.
1106
+ */
1107
+ NPY_NO_EXPORT PyObject *
1108
+ PyArray_Matmul (PyObject * in1 , PyObject * in2 , PyArrayObject * out )
1109
+ {
1110
+ PyArrayObject * ap1 , * ap2 , * ret = NULL ;
1111
+ NPY_ORDER order = NPY_KEEPORDER ;
1112
+ NPY_CASTING casting = NPY_SAFE_CASTING ;
1113
+ PyArray_Descr * dtype ;
1114
+ int typenum ;
1115
+ char * subscripts ;
1116
+ PyArrayObject * ops [2 ];
1117
+ NPY_BEGIN_THREADS_DEF ;
1118
+
1119
+ dtype = PyArray_DescrFromObject (in1 , NULL );
1120
+ dtype = PyArray_DescrFromObject (in2 , dtype );
1121
+ if (dtype == NULL ) {
1122
+ PyErr_SetString (PyExc_ValueError ,
1123
+ "Cannot find a common data type." );
1124
+ return NULL ;
1125
+ }
1126
+
1127
+ Py_INCREF (dtype );
1128
+ ap1 = (PyArrayObject * )PyArray_FromAny (in1 , dtype , 0 , 0 ,
1129
+ NPY_ARRAY_ALIGNED , NULL );
1130
+ if (ap1 == NULL ) {
1131
+ Py_DECREF (dtype );
1132
+ return NULL ;
1133
+ }
1134
+ ap2 = (PyArrayObject * )PyArray_FromAny (in2 , dtype , 0 , 0 ,
1135
+ NPY_ARRAY_ALIGNED , NULL );
1136
+ if (ap2 == NULL ) {
1137
+ Py_DECREF (ap1 );
1138
+ return NULL ;
1139
+ }
1140
+
1141
+ if (PyArray_NDIM (ap1 ) == 0 || PyArray_NDIM (ap2 ) == 0 ) {
1142
+ /* Scalars are rejected */
1143
+ PyErr_SetString (PyExc_TypeError ,
1144
+ "Scalar operands are not allowed, use '*' instead" );
1145
+ return NULL ;
1146
+ }
1147
+
1148
+ typenum = dtype -> type_num ;
1149
+ #if defined(HAVE_CBLAS )
1150
+ if (PyArray_NDIM (ap1 ) <= 2 && PyArray_NDIM (ap2 ) <= 2 &&
1151
+ (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
1152
+ NPY_FLOAT == typenum || NPY_CFLOAT == typenum )) {
1153
+ return cblas_matrixproduct (typenum , ap1 , ap2 , out );
1154
+ }
1155
+ #endif
1156
+
1157
+ /*
1158
+ * Use einsum for the stacked cases. This is a quick implementation
1159
+ * to avoid setting up the proper iterators.
1160
+ */
1161
+ ops [0 ] = ap1 ;
1162
+ ops [1 ] = ap2 ;
1163
+ if (PyArray_NDIM (ap1 ) == 1 && PyArray_NDIM (ap2 ) == 1 ) {
1164
+ /* vector vector */
1165
+ subscripts = "i, i" ;
1166
+ }
1167
+ else if (PyArray_NDIM (ap1 ) == 1 ) {
1168
+ /* vector matrix */
1169
+ subscripts = "i, ...ij" ;
1170
+ }
1171
+ else if (PyArray_NDIM (ap2 ) == 1 ) {
1172
+ /* matrix vector */
1173
+ subscripts = "...i, i" ;
1174
+ }
1175
+ else {
1176
+ /* matrix * matrix */
1177
+ subscripts = "...ij, ...jk" ;
1178
+ }
1179
+
1180
+ ret = PyArray_EinsteinSum (subscripts , 2 , ops , dtype , order , casting , out );
1181
+ if (ret == NULL ) {
1182
+ goto fail ;
1183
+ }
1184
+ Py_DECREF (ap1 );
1185
+ Py_DECREF (ap2 );
1186
+ return (PyObject * )ret ;
1187
+
1188
+ fail :
1087
1189
Py_XDECREF (ap1 );
1088
1190
Py_XDECREF (ap2 );
1089
1191
Py_XDECREF (ret );
@@ -1844,7 +1946,7 @@ array_copyto(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
1844
1946
1845
1947
Py_RETURN_NONE ;
1846
1948
1847
- fail :
1949
+ fail :
1848
1950
Py_XDECREF (src );
1849
1951
Py_XDECREF (wheremask );
1850
1952
return NULL ;
@@ -1887,7 +1989,7 @@ array_empty(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
1887
1989
PyDimMem_FREE (shape .ptr );
1888
1990
return (PyObject * )ret ;
1889
1991
1890
- fail :
1992
+ fail :
1891
1993
Py_XDECREF (typecode );
1892
1994
PyDimMem_FREE (shape .ptr );
1893
1995
return NULL ;
@@ -1918,7 +2020,7 @@ array_empty_like(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
1918
2020
1919
2021
return (PyObject * )ret ;
1920
2022
1921
- fail :
2023
+ fail :
1922
2024
Py_XDECREF (prototype );
1923
2025
Py_XDECREF (dtype );
1924
2026
return NULL ;
@@ -2041,7 +2143,7 @@ array_zeros(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kwds)
2041
2143
PyDimMem_FREE (shape .ptr );
2042
2144
return (PyObject * )ret ;
2043
2145
2044
- fail :
2146
+ fail :
2045
2147
Py_XDECREF (typecode );
2046
2148
PyDimMem_FREE (shape .ptr );
2047
2149
return (PyObject * )ret ;
@@ -2369,6 +2471,50 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
2369
2471
}
2370
2472
2371
2473
2474
+ static PyObject *
2475
+ array_matmul (PyObject * NPY_UNUSED (m ), PyObject * args , PyObject * kwds )
2476
+ {
2477
+ static PyUFuncObject * cached_npy_matmul = NULL ;
2478
+ int errval ;
2479
+ PyObject * override = NULL ;
2480
+ PyObject * v , * a , * o = NULL ;
2481
+ char * kwlist [] = {"a" , "b" , "out" , NULL };
2482
+
2483
+ if (cached_npy_matmul == NULL ) {
2484
+ PyObject * module , * dict , * matmul ;
2485
+
2486
+ module = PyImport_ImportModule ("numpy.core.multiarray" );
2487
+ dict = PyModule_GetDict (module );
2488
+ matmul = PyDict_GetItemString (dict , "matmul" );
2489
+ cached_npy_matmul = (PyUFuncObject * )matmul ;
2490
+ Py_INCREF (cached_npy_matmul );
2491
+ Py_DECREF (module );
2492
+ }
2493
+
2494
+ errval = PyUFunc_CheckOverride (cached_npy_matmul , "__call__" ,
2495
+ args , kwds , & override , 2 );
2496
+ if (errval ) {
2497
+ return NULL ;
2498
+ }
2499
+ else if (override ) {
2500
+ return override ;
2501
+ }
2502
+
2503
+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "OO|O" , kwlist , & a , & v , & o )) {
2504
+ return NULL ;
2505
+ }
2506
+ if (o == Py_None ) {
2507
+ o = NULL ;
2508
+ }
2509
+ if (o != NULL && !PyArray_Check (o )) {
2510
+ PyErr_SetString (PyExc_TypeError ,
2511
+ "'out' must be an array" );
2512
+ return NULL ;
2513
+ }
2514
+ return PyArray_Matmul (a , v , (PyArrayObject * )o );
2515
+ }
2516
+
2517
+
2372
2518
static int
2373
2519
einsum_sub_op_from_str (PyObject * args , PyObject * * str_obj , char * * subscripts ,
2374
2520
PyArrayObject * * op )
@@ -2862,7 +3008,7 @@ array__reconstruct(PyObject *NPY_UNUSED(dummy), PyObject *args)
2862
3008
2863
3009
return ret ;
2864
3010
2865
- fail :
3011
+ fail :
2866
3012
evil_global_disable_warn_O4O8_flag = 0 ;
2867
3013
2868
3014
Py_XDECREF (dtype );
@@ -3090,7 +3236,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
3090
3236
return ret ;
3091
3237
}
3092
3238
3093
- fail :
3239
+ fail :
3094
3240
Py_DECREF (arr );
3095
3241
Py_XDECREF (ax );
3096
3242
Py_XDECREF (ay );
@@ -3936,6 +4082,9 @@ static struct PyMethodDef array_module_methods[] = {
3936
4082
{"vdot" ,
3937
4083
(PyCFunction )array_vdot ,
3938
4084
METH_VARARGS | METH_KEYWORDS , NULL },
4085
+ {"matmul" ,
4086
+ (PyCFunction )array_matmul ,
4087
+ METH_VARARGS | METH_KEYWORDS , NULL },
3939
4088
{"einsum" ,
3940
4089
(PyCFunction )array_einsum ,
3941
4090
METH_VARARGS |METH_KEYWORDS , NULL },
0 commit comments