8000 Merge pull request #14504 from charris/backport-14464 · numpy/numpy@4bd4d98 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4bd4d98

Browse files
authored
Merge pull request #14504 from charris/backport-14464
BUG: add a specialized loop for boolean matmul.
2 parents e75e878 + 2e6c104 commit 4bd4d98

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

numpy/core/src/umath/matmul.c.src

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,16 +196,14 @@ NPY_NO_EXPORT void
196196
* FLOAT, DOUBLE, HALF,
197197
* CFLOAT, CDOUBLE, CLONGDOUBLE,
198198
* UBYTE, USHORT, UINT, ULONG, ULONGLONG,
199-
* BYTE, SHORT, INT, LONG, LONGLONG,
200-
* BOOL#
199+
* BYTE, SHORT, INT, LONG, LONGLONG#
201200
* #typ = npy_longdouble,
202201
* npy_float,npy_double,npy_half,
203202
* npy_cfloat, npy_cdouble, npy_clongdouble,
204203
* npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong,
205-
* npy_byte, npy_short, npy_int, npy_long, npy_longlong,
206-
* npy_bool#
207-
* #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*11#
208-
* #IS_HALF = 0, 0, 0, 1, 0*14#
204+
* npy_byte, npy_short, npy_int, npy_long, npy_longlong#
205+
* #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*10#
206+
* #IS_HALF = 0, 0, 0, 1, 0*13#
209207
*/
210208

211209
NPY_NO_EXPORT void
@@ -266,7 +264,44 @@ NPY_NO_EXPORT void
266264
}
267265

268266
/**end repeat**/
267+
NPY_NO_EXPORT void
268+
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
269+
void *_ip2, npy_intp is2_n, npy_intp is2_p,
270+
void *_op, npy_intp os_m, npy_intp os_p,
271+
npy_intp dm, npy_intp dn, npy_intp dp)
272+
273+
{
274+
npy_intp m, n, p;
275+
npy_intp ib2_p, ob_p;
276+
char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op;
269277

278+
ib2_p = is2_p * dp;
279+
ob_p = os_p * dp;
280+
281+
for (m = 0; m < dm; m++) {
282+
for (p = 0; p < dp; p++) {
283+
char *ip1tmp = ip1;
284+
char *ip2tmp = ip2;
285+
*(npy_bool *)op = NPY_FALSE;
286+
for (n = 0; n < dn; n++) {
287+
npy_bool val1 = (*(npy_bool *)ip1tmp);
288+
npy_bool val2 = (*(npy_bool *)ip2tmp);
289+
if (val1 != 0 && val2 != 0) {
290+
*(npy_bool *)op = NPY_TRUE;
291+
break;
292+
}
293+
ip2tmp += is2_n;
294+
ip1tmp += is1_n;
295+
}
296+
op += os_p;
297+
ip2 += is2_p;
298+
}
299+
op -= ob_p;
300+
ip2 -= ib2_p;
301+
ip1 += is1_m;
302+
op += os_m;
303+
}
304+
}
270305

271306
NPY_NO_EXPORT void
272307
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,

numpy/core/tests/test_multiarray.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6251,6 +6251,23 @@ def __mul__(self, other):
62516251
with assert_raises(TypeError):
62526252
b = np.matmul(a, a)
62536253

6254+
def test_matmul_bool(self):
6255+
# gh-14439
6256+
a = np.array([[1, 0],[1, 1]], dtype=bool)
6257+
assert np.max(a.view(np.uint8)) == 1
6258+
b = np.matmul(a, a)
6259+
# matmul with boolean output should always be 0, 1
6260+
assert np.max(b.view(np.uint8)) == 1
6261+
6262+
rg = np.random.default_rng(np.random.PCG64(43))
6263+
d = rg.integers(2, size=4*5, dtype=np.int8)
6264+
d = d.reshape(4, 5) > 0
6265+
out1 = np.matmul(d, d.reshape(5, 4))
6266+
out2 = np.dot(d, d.reshape(5, 4))
6267+
assert_equal(out1, out2)
6268+
6269+
c = np.matmul(np.zeros((2, 0), dtype=bool), np.zeros(0, dtype=bool))
6270+
assert not np.any(c)
62546271

62556272

62566273
if sys.version_info[:2] >= (3, 5):

0 commit comments

Comments
 (0)
0