8000 ENH: np.dot: better "matrices not aligned" message · numpy/numpy@d8af083 · GitHub
[go: up one dir, main page]

Skip to content

Commit d8af083

Browse files
committed
ENH: np.dot: better "matrices not aligned" message
1 parent 594b0de commit d8af083

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

numpy/core/blasdot/_dotblas.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "numpy/arrayobject.h"
1010
#include "npy_config.h"
1111
#include "npy_pycompat.h"
12+
#include "common.h"
1213
#include "ufunc_override.h"
1314
#ifndef CBLAS_HEADER
1415
#define CBLAS_HEADER "cblas.h"
@@ -529,7 +530,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
529530
l = PyArray_DIM(oap1, PyArray_NDIM(oap1) - 1);
530531

531532
if (PyArray_DIM(oap2, 0) != l) {
532-
PyErr_SetString(PyExc_ValueError, "matrices are not aligned");
533+
not_aligned(PyArray_NDIM(oap1) - 1, 0,
534+
l, PyArray_DIM(oap2, 0));
533535
goto fail;
534536
}
535537
nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
@@ -579,7 +581,8 @@ dotblas_matrixproduct(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject* kwa
579581
l = PyArray_DIM(ap1, PyArray_NDIM(ap1) - 1);
580582

581583
if (PyArray_DIM(ap2, 0) != l) {
582-
PyErr_SetString(PyExc_ValueError, "matrices are not aligned");
584+
not_aligned(PyArray_NDIM(ap1) - 1, 0,
585+
l, PyArray_DIM(ap2, 0));
583586
goto fail;
584587
}
585588
nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
@@ -1007,7 +1010,8 @@ dotblas_innerproduct(PyObject *NPY_UNUSED(dummy), PyObject *args)
10071010
l = PyArray_DIM(ap1, PyArray_NDIM(ap1)-1);
10081011

10091012
if (PyArray_DIM(ap2, PyArray_NDIM(ap2)-1) != l) {
1010-
PyErr_SetString(PyExc_ValueError, "matrices are not aligned");
1013+
not_aligned(PyArray_NDIM(ap1) - 1, PyArray_NDIM(ap2) - 1,
1014+
l, PyArray_DIM(ap2, PyArray_NDIM(ap2) - 1));
10111015
goto fail;
10121016
}
10131017
nd = PyArray_NDIM(ap1)+PyArray_NDIM(ap2)-2;

numpy/core/src/multiarray/common.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,18 @@ _is_basic_python_type(PyObject * obj)
208208
return 0;
209209
}
210210

211+
/*
212+
* Sets ValueError with "matrices not aligned" message for np.dot and friends
213+
* when shape[i] = m doesn't match shape[j] = n.
214+
*/
215+
static NPY_INLINE void
216+
not_aligned(int i, int j, Py_ssize_t m, Py_ssize_t n)
217+
{
218+
PyErr_Format(PyExc_ValueError,
219+
"matrices are not aligned: shape[%d] (%zd) != shape[%d] (%zd)",
220+
i, m, j, n);
221+
}
222+
211223
#include "ucsnarrow.h"
212224

213225
#endif

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,8 @@ PyArray_InnerProduct(PyObject *op1, PyObject *op2)
841841

842842
l = PyArray_DIMS(ap1)[PyArray_NDIM(ap1) - 1];
843843
if (PyArray_DIMS(ap2)[PyArray_NDIM(ap2) - 1] != l) {
844-
PyErr_SetString(PyExc_ValueError, "matrices are not aligned");
844+
not_aligned(PyArray_NDIM(ap1) - 1, PyArray_NDIM(ap2) - 1,
845+
l, PyArray_DIMS(ap2)[PyArray_NDIM(ap2) - 1]);
845846
goto fail;
846847
}
847848

@@ -961,7 +962,8 @@ PyArray_MatrixProduct2(PyObject *op1, PyObject *op2, PyArrayObject* out)
961962
matchDim = 0;
962963
}
963964
if (PyArray_DIMS(ap2)[matchDim] != l) {
964-
PyErr_SetString(PyExc_ValueError, "objects are not aligned");
965+
not_aligned(PyArray_NDIM(ap1) - 1, matchDim,
966+
l, PyArray_DIMS(ap2)[matchDim]);
965967
goto fail;
966968
}
967969
nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;

0 commit comments

Comments
 (0)
0