Merge pull request #29191 from charris/backport-29179 · numpy/numpy@270eef7

Commit 270eef7

Merge pull request #29191 from charris/backport-29179
BUG: fix matmul with transposed out arg (#29179)
2 parents 35b19a4 + cb0ff22 commit 270eef7

4 files changed, +17 -1 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+Fix bug in ``matmul`` for non-contiguous out kwarg parameter
+------------------------------------------------------------
+In some cases, if ``out`` was non-contiguous, ``np.matmul`` would cause
+memory corruption or a c-level assert. This was new to v2.3.0 and fixed in v2.3.1.
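
The kind of call this release note describes can be sketched as follows. This is a minimal illustration, assuming a NumPy build that includes the fix; the array shapes and the names `a`, `b`, `buf`, and `out` are made up for the example, and on the affected v2.3.0 such a call may hit the bug the note describes.

```python
import numpy as np

a = np.arange(6, dtype=np.float64).reshape(2, 3)
b = np.arange(12, dtype=np.float64).reshape(3, 4)
expected = a @ b  # reference result in a freshly allocated array

# A Fortran-ordered buffer twice as large in each dimension. Taking every
# other row and column gives a view with the result's shape that is
# neither C- nor F-contiguous.
buf = np.zeros((4, 8), order="F")
out = buf[::2, ::2]
assert not out.flags.c_contiguous and not out.flags.f_contiguous

np.matmul(a, b, out=out)

# The strided view holds the product; the skipped rows stay untouched.
np.testing.assert_array_equal(out, expected)
np.testing.assert_array_equal(buf[1::2, :], 0)
```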

doc/source/release/2.3.0-notes.rst

Lines changed: 6 additions & 0 deletions
@@ -414,6 +414,12 @@ the best performance.

 (`gh-28769 <https://github.com/numpy/numpy/pull/28769>`__)

+Performance improvements for ``np.matmul``
+------------------------------------------
+Enable using BLAS for ``matmul`` even when operands are non-contiguous by copying
+if needed.
+
+(`gh-23752 <https://github.com/numpy/numpy/pull/23752>`__)

 Changes
 =======
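
As a rough illustration of the note above, non-contiguous operands can be passed to `np.matmul` directly, and the result should match the one computed from contiguous copies. The shapes and names below are assumptions chosen for the example; the sketch checks results only and makes no claim about which internal path (BLAS or otherwise) a given build takes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Strided views: every other row and column of larger arrays.
a = rng.standard_normal((6, 8))[::2, ::2]   # shape (3, 4)
b = rng.standard_normal((8, 10))[::2, ::2]  # shape (4, 5)
assert not a.flags.c_contiguous and not b.flags.c_contiguous

r_view = np.matmul(a, b)
r_copy = np.matmul(np.ascontiguousarray(a), np.ascontiguousarray(b))

# Different code paths may round differently, so compare with a tolerance.
assert np.allclose(r_view, r_copy)
```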

numpy/_core/src/umath/matmul.c.src

Lines changed: 1 addition & 1 deletion
@@ -596,7 +596,7 @@ NPY_NO_EXPORT void
      * Use transpose equivalence:
      * matmul(a, b, o) == matmul(b.T, a.T, o.T)
      */
-    if (o_f_blasable) {
+    if (o_transpose) {
         @TYPE@_matmul_matrixmatrix(
             ip2_, is2_p_, is2_n_,
             ip1_, is1_n_, is1_m_,
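
The transpose equivalence named in the comment above can be checked from Python. This is a sketch of the identity only, not of the C code path: the names `a`, `b`, and `o` mirror the comment, while the shapes and the Fortran-ordered output are assumptions made for the example.

```python
import numpy as np

a = np.arange(12, dtype=np.float64).reshape(3, 4)
b = np.arange(20, dtype=np.float64).reshape(4, 5)

# Fortran-ordered output, so its transpose is a C-contiguous view.
o = np.empty((3, 5), order="F")
assert o.T.flags.c_contiguous

# matmul(a, b, o) == matmul(b.T, a.T, o.T): write the transposed product
# into o.T, which shares memory with o.
np.matmul(b.T, a.T, out=o.T)

assert np.allclose(o, np.matmul(a, b))
```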

numpy/_core/tests/test_multiarray.py

Lines changed: 6 additions & 0 deletions
@@ -7317,6 +7317,12 @@ def test_dot_equivalent(self, args):
         r3 = np.matmul(args[0].copy(), args[1].copy())
         assert_equal(r1, r3)

+        # matrix matrix, issue 29164
+        if [len(args[0].shape), len(args[1].shape)] == [2, 2]:
+            out_f = np.zeros((r2.shape[0] * 2, r2.shape[1] * 2), order='F')
+            r4 = np.matmul(*args, out=out_f[::2, ::2])
+            assert_equal(r2, r4)
+
     def test_matmul_object(self):
         import fractions
