|
18 | 18 | #include "item_selection.h"
|
19 | 19 | #include "conversion_utils.h"
|
20 | 20 | #include "shape.h"
|
| 21 | +#include "multiarraymodule.h" |
21 | 22 |
|
22 | 23 | #include "methods.h"
|
23 | 24 |
|
@@ -1996,58 +1997,6 @@ array_cumprod(PyArrayObject *self, PyObject *args, PyObject *kwds)
|
1996 | 1997 | }
|
1997 | 1998 |
|
1998 | 1999 |
|
1999 |
| -static PyObject * |
2000 |
| -array_dot(PyArrayObject *self, PyObject *args, PyObject *kwds) |
2001 |
| -{ |
2002 |
| - static PyUFuncObject *cached_npy_dot = NULL; |
2003 |
| - int errval; |
2004 |
| - PyObject *override = NULL; |
2005 |
| - PyObject *a = (PyObject *)self, *b, *o = Py_None; |
2006 |
| - PyObject *newargs; |
2007 |
| - PyArrayObject *ret; |
2008 |
| - char* kwlist[] = {"b", "out", NULL }; |
2009 |
| - |
2010 |
| - |
2011 |
| - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) { |
2012 |
| - return NULL; |
2013 |
| - } |
2014 |
| - |
2015 |
| - if (cached_npy_dot == NULL) { |
2016 |
| - PyObject *module = PyImport_ImportModule("numpy.core.multiarray"); |
2017 |
| - cached_npy_dot = (PyUFuncObject*)PyDict_GetItemString( |
2018 |
| - PyModule_GetDict(module), "dot"); |
2019 |
| - |
2020 |
| - Py_INCREF(cached_npy_dot); |
2021 |
| - Py_DECREF(module); |
2022 |
| - } |
2023 |
| - |
2024 |
| - if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) { |
2025 |
| - return NULL; |
2026 |
| - } |
2027 |
| - errval = PyUFunc_CheckOverride(cached_npy_dot, "__call__", |
2028 |
| - newargs, NULL, &override, 2); |
2029 |
| - Py_DECREF(newargs); |
2030 |
| - |
2031 |
| - if (errval) { |
2032 |
| - return NULL; |
2033 |
| - } |
2034 |
| - else if (override) { |
2035 |
| - return override; |
2036 |
| - } |
2037 |
| - |
2038 |
| - if (o == Py_None) { |
2039 |
| - o = NULL; |
2040 |
| - } |
2041 |
| - if (o != NULL && !PyArray_Check(o)) { |
2042 |
| - PyErr_SetString(PyExc_TypeError, |
2043 |
| - "'out' must be an array"); |
2044 |
| - return NULL; |
2045 |
| - } |
2046 |
| - ret = (PyArrayObject *)PyArray_MatrixProduct2(a, b, (PyArrayObject *)o); |
2047 |
| - return PyArray_Return(ret); |
2048 |
| -} |
2049 |
| - |
2050 |
| - |
2051 | 2000 | static PyObject *
|
2052 | 2001 | array_any(PyArrayObject *self, PyObject *args, PyObject *kwds)
|
2053 | 2002 | {
|
@@ -2317,18 +2266,73 @@ array_newbyteorder(PyArrayObject *self, PyObject *args)
|
2317 | 2266 | }
|
2318 | 2267 |
|
2319 | 2268 |
|
2320 |
| -/*static PyObject **/ |
2321 |
| -/*array_matmul(PyArrayObject *self, PyObject *rhs)*/ |
2322 |
| -/*{*/ |
2323 |
| - /*return PyArray_Matmul(self, rhs, out=NULL);*/ |
2324 |
| -/*}*/ |
| 2269 | +static PyObject * |
| 2270 | +array_dot(PyObject *self, PyObject *args, PyObject *kwds) |
| 2271 | +{ |
| 2272 | + static PyObject *newkwargs = NULL; |
| 2273 | + PyObject *a = (PyObject *)self, *b, *o = Py_None; |
| 2274 | + PyArrayObject *ret; |
| 2275 | + PyObject *newargs; |
| 2276 | + char* kwlist[] = {"b", "out", NULL }; |
| 2277 | + |
| 2278 | + |
| 2279 | + if (newkwargs == NULL) { |
| 2280 | + if ((newkwargs = PyDict_New()) == NULL) { |
| 2281 | + return NULL; |
| 2282 | + } |
| 2283 | + } |
| 2284 | + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O", kwlist, &b, &o)) { |
| 2285 | + return NULL; |
| 2286 | + } |
| 2287 | + if ((newargs = PyTuple_Pack(3, a, b, o)) == NULL) { |
| 2288 | + return NULL; |
| 2289 | + } |
| 2290 | + ret = array_matrixproduct(NULL, newargs, newkwargs); |
| 2291 | + Py_DECREF(newargs); |
| 2292 | + return ret; |
| 2293 | +} |
| 2294 | + |
| 2295 | + |
| 2296 | +static PyObject * |
| 2297 | +array__matmul__(PyObject *self, PyObject *rhs) |
| 2298 | +{ |
| 2299 | + static PyObject *newkwargs = NULL; |
| 2300 | + PyObject *newargs; |
| 2301 | + PyArrayObject *ret; |
| 2302 | + |
| 2303 | + if (newkwargs == NULL) { |
| 2304
F438
td> | + if ((newkwargs = PyDict_New()) == NULL) { |
| 2305 | + return NULL; |
| 2306 | + } |
| 2307 | + } |
| 2308 | + if ((newargs = PyTuple_Pack(2, self, rhs)) == NULL) { |
| 2309 | + return NULL; |
| 2310 | + } |
| 2311 | + ret = array_matmul(NULL, newargs, newkwargs); |
| 2312 | + Py_DECREF(newargs); |
| 2313 | + return ret; |
| 2314 | +} |
| 2315 | + |
2325 | 2316 |
|
| 2317 | +static PyObject * |
| 2318 | +array__rmatmul__(PyObject *self, PyObject *lhs) |
| 2319 | +{ |
| 2320 | + static PyObject *newkwargs = NULL; |
| 2321 | + PyObject *newargs; |
| 2322 | + PyArrayObject *ret; |
2326 | 2323 |
|
2327 |
| -/*static PyObject **/ |
2328 |
| -/*array_rmatmul(PyArrayObject *self, PyObject *lhs)*/ |
2329 |
| -/*{*/ |
2330 |
| - /*return PyArray_Matmul(lhs, self, out=NULL);*/ |
2331 |
| -/*}*/ |
| 2324 | + if (newkwargs == NULL) { |
| 2325 | + if ((newkwargs = PyDict_New()) == NULL) { |
| 2326 | + return NULL; |
| 2327 | + } |
| 2328 | + } |
| 2329 | + if ((newargs = PyTuple_Pack(2, lhs, self)) == NULL) { |
| 2330 | + return NULL; |
| 2331 | + } |
| 2332 | + ret = array_matmul(NULL, newargs, newkwargs); |
| 2333 | + Py_DECREF(newargs); |
| 2334 | + return ret; |
| 2335 | +} |
2332 | 2336 |
|
2333 | 2337 |
|
2334 | 2338 | NPY_NO_EXPORT PyMethodDef array_methods[] = {
|
@@ -2358,12 +2362,12 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
|
2358 | 2362 | METH_VARARGS, NULL},
|
2359 | 2363 |
|
2360 | 2364 | /* for '@' operator in Python >= 3.5 */
|
2361 |
| - /*{"__matmul__",*/ |
2362 |
| - /*(PyCFunction)array_matmul,*/ |
2363 |
| - /*METH_O, NULL},*/ |
2364 |
| - /*{"__rmatmul__",*/ |
2365 |
| - /*(PyCFunction)array_rmatmul,*/ |
2366 |
| - /*METH_O, NULL},*/ |
| 2365 | + {"__matmul__", |
| 2366 | + (PyCFunction)array__matmul__, |
| 2367 | + METH_O, NULL}, |
| 2368 | + {"__rmatmul__", |
| 2369 | + (PyCFunction)array__rmatmul__, |
| 2370 | + METH_O, NULL}, |
2367 | 2371 |
|
2368 | 2372 | /* for Pickling */
|
2369 | 2373 | {"__reduce__",
|
|
0 commit comments