8000 MAINT: use tmp pointers to allow early break; add tests. · numpy/numpy@3b7fd8a · GitHub
[go: up one dir, main page]

Skip to content

Commit 3b7fd8a

Browse files
mattipcharris
authored andcommitted
MAINT: use tmp pointers to allow early break; add tests.
1 parent 084b46c commit 3b7fd8a

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

numpy/core/src/umath/matmul.c.src

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -282,18 +282,19 @@ BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
282282

283283
for (m = 0; m < dm; m++) {
284284
for (p = 0; p < dp; p++) {
285+
npy_bool *ip1tmp = ip1;
286+
npy_bool *ip2tmp = ip2;
285287
*(npy_bool *)op = NPY_FALSE;
286288
for (n = 0; n < dn; n++) {
287-
npy_bool val1 = (*(npy_bool *)ip1);
288-
npy_bool val2 = (*(npy_bool *)ip2);
289+
npy_bool val1 = (*(npy_bool *)ip1tmp);
290+
npy_bool val2 = (*(npy_bool *)ip2tmp);
289291
if (val1 != 0 && val2 != 0) {
290292
*(npy_bool *)op = NPY_TRUE;
293+
break;
291294
}
292-
ip2 += is2_n;
293-
ip1 += is1_n;
295+
ip2tmp += is2_n;
296+
ip1tmp += is1_n;
294297
}
295-
ip1 -= ib1_n;
296-
ip2 -= ib2_n;
297298
op += os_p;
298299
ip2 += is2_p;
299300
}

numpy/core/tests/test_multiarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6254,10 +6254,20 @@ def __mul__(self, other):
62546254
def test_matmul_bool(self):
62556255
# gh-14439
62566256
a = np.array([[1, 0],[1, 1]], dtype=bool)
6257-
assert np.max(a.view(np.int8)) == 1
6257+
assert np.max(a.view(np.uint8)) == 1
62586258
b = np.matmul(a, a)
62596259
# matmul with boolean output should always be 0, 1
6260-
assert np.max(b.view(np.int8)) == 1
6260+
assert np.max(b.view(np.uint8)) == 1
6261+
6262+
rg = np.random.default_rng(np.random.PCG64(43))
6263+
d = rg.integers(2, size=4*5, dtype=np.int8)
6264+
d = d.reshape(4, 5) > 0
6265+
out1 = np.matmul(d, d.reshape(5, 4))
6266+
out2 = np.dot(d, d.reshape(5, 4))
6267+
assert_equal(out1, out2)
6268+
6269+
c = np.matmul(np.zeros((2, 0), dtype=bool), np.zeros(0, dtype=bool))
6270+
assert not np.any(c)
62616271

62626272

3E56
62636273
if sys.version_info[:2] >= (3, 5):

0 commit comments

Comments
 (0)
0