@@ -64,6 +64,109 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
64
64
/* Only here for API compatibility */
65
65
NPY_NO_EXPORT PyTypeObject PyBigArray_Type ;
66
66
67
+
68
+ /*
69
+ * matmul
70
+ *
71
+ * Implements the protocol used by the '@' operator defined in PEP 364.
72
+ * Not in the NUMPY API at this time, maybe later.
73
+ *
74
+ *
75
+ * in1: Left hand side operand
76
+ * in2: Right hand side operand
77
+ * out: Either NULL, or an array into which the output should be placed.
78
+ *
79
+ * Returns NULL on error.
80
+ */
81
+ static PyObject *
82
+ PyArray_Matmul (PyObject * in1 , PyObject * in2 , PyArrayObject * out )
83
+ {
84
+ PyArrayObject * ap1 , * ap2 , * ret = NULL ;
85
+ NPY_ORDER order = NPY_KEEPORDER ;
86
+ NPY_CASTING casting = NPY_SAFE_CASTING ;
87
+ PyArray_Descr * dtype ;
88
+ int typenum ;
89
+ char * subscripts ;
90
+ PyArrayObject * ops [2 ];
91
+ NPY_BEGIN_THREADS_DEF ;
92
+
93
+ dtype = PyArray_DescrFromObject (in1 , NULL );
94
+ dtype = PyArray_DescrFromObject (in2 , dtype );
95
+ if (dtype == NULL ) {
96
+ PyErr_SetString (PyExc_ValueError ,
97
+ "Cannot find a common data type." );
98
+ return NULL ;
99
+ }
100
+
101
+ Py_INCREF (dtype );
102
+ ap1 = (PyArrayObject * )PyArray_FromAny (in1 , dtype , 0 , 0 ,
103
+ NPY_ARRAY_ALIGNED , NULL );
104
+ if (ap1 == NULL ) {
105
+ Py_DECREF (dtype );
106
+ return NULL ;
107
+ }
108
+ ap2 = (PyArrayObject * )PyArray_FromAny (in2 , dtype , 0 , 0 ,
109
+ NPY_ARRAY_ALIGNED , NULL );
110
+ if (ap2 == NULL ) {
111
+ Py_DECREF (ap1 );
112
+ return NULL ;
113
+ }
114
+
115
+ if (PyArray_NDIM (ap1 ) == 0 || PyArray_NDIM (ap2 ) == 0 ) {
116
+ /* Scalars are rejected */
117
+ PyErr_SetString (PyExc_ValueError ,
118
+ "Scalar operands are not allowed, use '*' instead" );
119
+ return NULL ;
120
+ }
121
+
122
+ typenum = dtype -> type_num ;
123
+ #if defined(HAVE_CBLAS )
124
+ if (PyArray_NDIM (ap1 ) <= 2 && PyArray_NDIM (ap2 ) <= 2 &&
125
+ (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
126
+ NPY_FLOAT == typenum || NPY_CFLOAT == typenum )) {
127
+ return cblas_matrixproduct (typenum , ap1 , ap2 , out );
128
+ }
129
+ #endif
130
+
131
+ /*
132
+ * Use einsum for the stacked cases. This is a quick implementation
133
+ * to avoid setting up the proper iterators.
134
+ */
135
+ ops [0 ] = ap1 ;
136
+ ops [1 ] = ap2 ;
137
+ if (PyArray_NDIM (ap1 ) == 1 && PyArray_NDIM (ap2 ) == 1 ) {
138
+ /* vector vector */
139
+ subscripts = "i, i" ;
140
+ }
141
+ else if (PyArray_NDIM (ap1 ) == 1 ) {
142
+ /* vector matrix */
143
+ subscripts = "i, ...ij" ;
144
+ }
145
+ else if (PyArray_NDIM (ap2 ) == 1 ) {
146
+ /* matrix vector */
147
+ subscripts = "...i, i" ;
148
+ }
149
+ else {
150
+ /* matrix * matrix */
151
+ subscripts = "...ij, ...jk" ;
152
+ }
153
+
154
+ ret = PyArray_EinsteinSum (subscripts , 2 , ops , dtype , order , casting , out );
155
+ if (ret == NULL ) {
156
+ goto fail ;
157
+ }
158
+ Py_DECREF (ap1 );
159
+ Py_DECREF (ap2 );
160
+ return (PyObject * )ret ;
161
+
162
+ fail :
163
+ Py_XDECREF (ap1 );
164
+ Py_XDECREF (ap2 );
165
+ Py_XDECREF (ret );
166
+ return NULL ;
167
+ }
168
+
169
+
67
170
/*NUMPY_API
68
171
* Get Priority from object
69
172
*/
@@ -1092,107 +1195,6 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
1092
1195
}
1093
1196
1094
1197
1095
- /*NUMPY_API
1096
- * matmul
1097
- *
1098
- * Implements the protocol used by the '@' operator defined in PEP 364.
1099
- *
1100
- *
1101
- * in1: Left hand side operand
1102
- * in2: Right hand side operand
1103
- * out: Either NULL, or an array into which the output should be placed.
1104
- *
1105
- * Returns NULL on error.
1106
- */
1107
- NPY_NO_EXPORT PyObject *
1108
- PyArray_Matmul (PyObject * in1 , PyObject * in2 , PyArrayObject * out )
1109
- {
1110
- PyArrayObject * ap1 , * ap2 , * ret = NULL ;
1111
- NPY_ORDER order = NPY_KEEPORDER ;
1112
- NPY_CASTING casting = NPY_SAFE_CASTING ;
1113
- PyArray_Descr * dtype ;
1114
- int typenum ;
1115
- char * subscripts ;
1116
- PyArrayObject * ops [2 ];
1117
- NPY_BEGIN_THREADS_DEF ;
1118
-
1119
- dtype = PyArray_DescrFromObject (in1 , NULL );
1120
- dtype = PyArray_DescrFromObject (in2 , dtype );
1121
- if (dtype == NULL ) {
1122
- PyErr_SetString (PyExc_ValueError ,
1123
- "Cannot find a common data type." );
1124
- return NULL ;
1125
- }
1126
-
1127
- Py_INCREF (dtype );
1128
- ap1 = (PyArrayObject * )PyArray_FromAny (in1 , dtype , 0 , 0 ,
1129
- NPY_ARRAY_ALIGNED , NULL );
1130
- if (ap1 == NULL ) {
1131
- Py_DECREF (dtype );
1132
- return NULL ;
1133
- }
1134
- ap2 = (PyArrayObject * )PyArray_FromAny (in2 , dtype , 0 , 0 ,
1135
- NPY_ARRAY_ALIGNED , NULL );
1136
- if (ap2 == NULL ) {
1137
- Py_DECREF (ap1 );
1138
- return NULL ;
1139
- }
1140
-
1141
- if (PyArray_NDIM (ap1 ) == 0 || PyArray_NDIM (ap2 ) == 0 ) {
1142
- /* Scalars are rejected */
1143
- PyErr_SetString (PyExc_TypeError ,
1144
- "Scalar operands are not allowed, use '*' instead" );
1145
- return NULL ;
1146
- }
1147
-
1148
- typenum = dtype -> type_num ;
1149
- #if defined(HAVE_CBLAS )
1150
- if (PyArray_NDIM (ap1 ) <= 2 && PyArray_NDIM (ap2 ) <= 2 &&
1151
- (NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
1152
- NPY_FLOAT == typenum || NPY_CFLOAT == typenum )) {
1153
- return cblas_matrixproduct (typenum , ap1 , ap2 , out );
1154
- }
1155
- #endif
1156
-
1157
- /*
1158
- * Use einsum for the stacked cases. This is a quick implementation
1159
- * to avoid setting up the proper iterators.
1160
- */
1161
- ops [0 ] = ap1 ;
1162
- ops [1 ] = ap2 ;
1163
- if (PyArray_NDIM (ap1 ) == 1 && PyArray_NDIM (ap2 ) == 1 ) {
1164
- /* vector vector */
1165
- subscripts = "i, i" ;
1166
- }
1167
- else if (PyArray_NDIM (ap1 ) == 1 ) {
1168
- /* vector matrix */
1169
- subscripts = "i, ...ij" ;
1170
- }
1171
- else if (PyArray_NDIM (ap2 ) == 1 ) {
1172
- /* matrix vector */
1173
- subscripts = "...i, i" ;
1174
- }
1175
- else {
1176
- /* matrix * matrix */
1177
- subscripts = "...ij, ...jk" ;
1178
- }
1179
-
1180
- ret = PyArray_EinsteinSum (subscripts , 2 , ops , dtype , order , casting , out );
1181
- if (ret == NULL ) {
1182
- goto fail ;
1183
- }
1184
- Py_DECREF (ap1 );
1185
- Py_DECREF (ap2 );
1186
- return (PyObject * )ret ;
1187
-
1188
- fail :
1189
- Py_XDECREF (ap1 );
1190
- Py_XDECREF (ap2 );
1191
- Py_XDECREF (ret );
1192
- return NULL ;
1193
- }
1194
-
1195
-
1196
1198
/*NUMPY_API
1197
1199
* Copy and Transpose
1198
1200
*
@@ -2318,7 +2320,7 @@ array_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
2318
2320
return PyArray_Return ((PyArrayObject * )PyArray_InnerProduct (a0 , b0 ));
2319
2321
}
2320
2322
2321
- static PyObject *
2323
+ NPY_NO_EXPORT PyObject *
2322
2324
array_matrixproduct (PyObject * NPY_UNUSED (dummy ), PyObject * args , PyObject * kwds )
2323
2325
{
2324
2326
static PyUFuncObject * cached_npy_dot = NULL ;
@@ -2471,7 +2473,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
2471
2473
}
2472
2474
2473
2475
2474
- static PyObject *
2476
+ NPY_NO_EXPORT PyObject *
2475
2477
array_matmul (PyObject * NPY_UNUSED (m ), PyObject * args , PyObject * kwds )
2476
2478
{
2477
2479
static PyUFuncObject * cached_npy_matmul = NULL ;
0 commit comments