8000 BUG: Fix vdot for uncontiguous arrays. · numpy/numpy@7f434e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f434e7

Browse files
sebergcharris
authored andcommitted
BUG: Fix vdot for uncontiguous arrays.
Note that using Newshape also means that less copying is done in principle, because ravel will always return a contiguous array.
1 parent 00ca7ea commit 7f434e7

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2254,8 +2254,10 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
22542254
{
22552255
int typenum;
22562256
char *ip1, *ip2, *op;
2257-
npy_intp n, stride;
2257+
npy_intp n, stride1, stride2;
22582258
PyObject *op1, *op2;
2259+
npy_intp newdimptr[1] = {-1};
2260+
PyArray_Dims newdims = {newdimptr, 1};
22592261
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ret = NULL;
22602262
PyArray_Descr *type;
22612263
PyArray_DotFunc *vdot;
@@ -2279,7 +2281,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
22792281
Py_DECREF(type);
22802282
goto fail;
22812283
}
2282-
op1 = PyArray_Ravel(ap1, NPY_CORDER);
2284+
2285+
op1 = PyArray_Newshape(ap1, &newdims, NPY_CORDER);
22832286
if (op1 == NULL) {
22842287
Py_DECREF(type);
22852288
goto fail;
@@ -2291,7 +2294,7 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
22912294
if (ap2 == NULL) {
22922295
goto fail;
22932296
}
2294-
op2 = PyArray_Ravel(ap2, NPY_CORDER);
2297+
op2 = PyArray_Newshape(ap2, &newdims, NPY_CORDER);
22952298
if (op2 == NULL) {
22962299
goto fail;
22972300
}
@@ -2311,7 +2314,8 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
23112314
}
23122315

23132316
n = PyArray_DIM(ap1, 0);
2314-
stride = type->elsize;
2317+
stride1 = PyArray_STRIDE(ap1, 0);
2318+
stride2 = PyArray_STRIDE(ap2, 0);
23152319
ip1 = PyArray_DATA(ap1);
23162320
ip2 = PyArray_DATA(ap2);
23172321
op = PyArray_DATA(ret);
@@ -2339,11 +2343,11 @@ array_vdot(PyObject *NPY_UNUSED(dummy), PyObject *args)
23392343
}
23402344

23412345
if (n < 500) {
2342-
vdot(ip1, stride, ip2, stride, op, n, NULL);
2346+
vdot(ip1, stride1, ip2, stride2, op, n, NULL);
23432347
}
23442348
else {
23452349
NPY_BEGIN_THREADS_DESCR(type);
2346-
vdot(ip1, stride, ip2, stride, op, n, NULL);
2350+
vdot(ip1, stride1, ip2, stride2, op, n, NULL);
23472351
NPY_END_THREADS_DESCR(type);
23482352
}
23492353

numpy/core/tests/test_multiarray.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3970,6 +3970,28 @@ def test_vdot_array_order(self):
39703970
assert_equal(np.vdot(b, a), res)
39713971
assert_equal(np.vdot(b, b), res)
39723972

3973+
def test_vdot_uncontiguous(self):
3974+
for size in [2, 1000]:
3975+
# Different sizes match different branches in vdot.
3976+
a = np.zeros((size, 2, 2))
3977+
b = np.zeros((size, 2, 2))
3978+
a[:, 0, 0] = 8000 np.arange(size)
3979+
b[:, 0, 0] = np.arange(size) + 1
3980+
# Make a and b uncontiguous:
3981+
a = a[..., 0]
3982+
b = b[..., 0]
3983+
3984+
assert_equal(np.vdot(a, b),
3985+
np.vdot(a.flatten(), b.flatten()))
3986+
assert_equal(np.vdot(a, b.copy()),
3987+
np.vdot(a.flatten(), b.flatten()))
3988+
assert_equal(np.vdot(a.copy(), b),
3989+
np.vdot(a.flatten(), b.flatten()))
3990+
assert_equal(np.vdot(a.copy('F'), b),
3991+
np.vdot(a.flatten(), b.flatten()))
3992+
assert_equal(np.vdot(a, b.copy('F')),
3993+
np.vdot(a.flatten(), b.flatten()))
3994+
39733995

39743996
class TestDot(TestCase):
39753997
def setUp(self):

0 commit comments

Comments
 (0)
0