diff --git a/doc/release/1.16.6-notes.rst b/doc/release/1.16.6-notes.rst new file mode 100644 index 000000000000..4eafdb992d1f --- /dev/null +++ b/doc/release/1.16.6-notes.rst @@ -0,0 +1,53 @@ +========================== +NumPy 1.16.6 Release Notes +========================== + + +Highlights +========== + + +New functions +============= + +Allow matmul (`@` operator) to work with object arrays. +------------------------------------------------------- +This is an enhancement that was added in NumPy 1.17 and seems reasonable to +include in the LTS 1.16 release series. + + +New deprecations +================ + + +Expired deprecations +==================== + + +Future changes +============== + + +Compatibility notes +=================== + +Fix regression in matmul (`@` operator) for boolean types +--------------------------------------------------------- +Booleans were being treated as integers rather than booleans, +which was a regression from previous behavior. + + +C API changes +============= + + +New Features +============ + + +Improvements +============ + + +Changes +======= diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 0fac9b05eeff..daf5949d0625 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -911,6 +911,7 @@ def english_upper(s): docstrings.get('numpy.core.umath.matmul'), "PyUFunc_SimpleBinaryOperationTypeResolver", TD(notimes_or_obj), + TD(O), signature='(n?,k),(k,m?)->(n?,m?)', ), } diff --git a/numpy/core/src/umath/matmul.c.src b/numpy/core/src/umath/matmul.c.src index 0cb3c82ad228..bc00d3562d0f 100644 --- a/numpy/core/src/umath/matmul.c.src +++ b/numpy/core/src/umath/matmul.c.src @@ -196,16 +196,14 @@ NPY_NO_EXPORT void * FLOAT, DOUBLE, HALF, * CFLOAT, CDOUBLE, CLONGDOUBLE, * UBYTE, USHORT, UINT, ULONG, ULONGLONG, - * BYTE, SHORT, INT, LONG, LONGLONG, - * BOOL# + * BYTE, SHORT, INT, LONG, LONGLONG# * #typ = npy_longdouble, * npy_float,npy_double,npy_half, * npy_cfloat, npy_cdouble, npy_clongdouble, * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, - * npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_bool# - * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*11# - * #IS_HALF = 0, 0, 0, 1, 0*14# + * npy_byte, npy_short, npy_int, npy_long, npy_longlong# + * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*10# + * #IS_HALF = 0, 0, 0, 1, 0*13# */ NPY_NO_EXPORT void @@ -213,7 +211,6 @@ NPY_NO_EXPORT void void *_ip2, npy_intp is2_n, npy_intp is2_p, void *_op, npy_intp os_m, npy_intp os_p, npy_intp dm, npy_intp dn, npy_intp dp) - { npy_intp m, n, p; npy_intp ib1_n, ib2_n, ib2_p, ob_p; @@ -266,20 +263,126 @@ NPY_NO_EXPORT void } /**end repeat**/ +NPY_NO_EXPORT void +BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n, + void *_ip2, npy_intp is2_n, npy_intp is2_p, + void *_op, npy_intp os_m, npy_intp os_p, + npy_intp dm, npy_intp dn, npy_intp dp) +{ + npy_intp m, n, p; + npy_intp ib2_p, ob_p; + char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op; + + ib2_p = is2_p * dp; + ob_p = os_p * dp; + + for (m = 0; m < dm; m++) { + for (p = 0; p < dp; p++) { + char *ip1tmp = ip1; + char *ip2tmp = ip2; + *(npy_bool *)op = NPY_FALSE; + for (n = 0; n < dn; n++) { + npy_bool val1 = (*(npy_bool *)ip1tmp); + npy_bool val2 = (*(npy_bool *)ip2tmp); + if (val1 != 0 && val2 != 0) { + *(npy_bool *)op = NPY_TRUE; + break; + } + ip2tmp += is2_n; + ip1tmp += is1_n; + } + op += os_p; + ip2 += is2_p; + } + op -= ob_p; + ip2 -= ib2_p; + ip1 += is1_m; + op += os_m; + } +} + +NPY_NO_EXPORT void +OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n, + void *_ip2, npy_intp is2_n, npy_intp is2_p, + void *_op, npy_intp os_m, npy_intp os_p, + npy_intp dm, npy_intp dn, npy_intp dp) +{ + char *ip1 = (char *)_ip1, *ip2 = (char *)_ip2, *op = (char *)_op; + + npy_intp ib1_n = is1_n * dn; + npy_intp ib2_n = is2_n * dn; + npy_intp ib2_p = is2_p * dp; + npy_intp ob_p = os_p * dp; + npy_intp m, p, n; + + PyObject *product, *sum_of_products = NULL; + + for (m = 0; m < dm; m++) { + for (p = 0; p < dp; p++) { + if ( 0 == dn ) { + sum_of_products = PyLong_FromLong(0); + if (sum_of_products == NULL) { + return; + } + } + + for (n = 0; n < dn; n++) { + PyObject *obj1 = *(PyObject**)ip1, *obj2 = *(PyObject**)ip2; + if (obj1 == NULL) { + obj1 = Py_None; + } + if (obj2 == NULL) { + obj2 = Py_None; + } + + product = PyNumber_Multiply(obj1, obj2); + if (product == NULL) { + Py_XDECREF(sum_of_products); + return; + } + + if (n == 0) { + sum_of_products = product; + } + else { + Py_SETREF(sum_of_products, PyNumber_Add(sum_of_products, product)); + Py_DECREF(product); + if (sum_of_products == NULL) { + return; + } + } + + ip2 += is2_n; + ip1 += is1_n; + } + + *((PyObject **)op) = sum_of_products; + ip1 -= ib1_n; + ip2 -= ib2_n; + op += os_p; + ip2 += is2_p; + } + op -= ob_p; + ip2 -= ib2_p; + ip1 += is1_m; + op += os_m; + } +} + /**begin repeat * #TYPE = FLOAT, DOUBLE, LONGDOUBLE, HALF, * CFLOAT, CDOUBLE, CLONGDOUBLE, * UBYTE, USHORT, UINT, ULONG, ULONGLONG, * BYTE, SHORT, INT, LONG, LONGLONG, - * BOOL# + * BOOL, OBJECT# * #typ = npy_float,npy_double,npy_longdouble, npy_half, * npy_cfloat, npy_cdouble, npy_clongdouble, * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, * npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_bool# - * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*11# - * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*12# + * npy_bool,npy_object# + * #IS_COMPLEX = 0, 0, 0, 0, 1, 1, 1, 0*12# + * #USEBLAS = 1, 1, 0, 0, 1, 1, 0*13# */ @@ -398,5 +501,3 @@ NPY_NO_EXPORT void } /**end repeat**/ - - diff --git a/numpy/core/src/umath/matmul.h.src b/numpy/core/src/umath/matmul.h.src index 16be7675b945..a664b1b4e1fd 100644 --- a/numpy/core/src/umath/matmul.h.src +++ b/numpy/core/src/umath/matmul.h.src @@ -3,7 +3,7 @@ * CFLOAT, CDOUBLE, CLONGDOUBLE, * UBYTE, USHORT, UINT, ULONG, ULONGLONG, * BYTE, SHORT, INT, LONG, LONGLONG, - * BOOL# + * BOOL, OBJECT# **/ NPY_NO_EXPORT void @TYPE@_matmul(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func)); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 873aa9312b59..77e2567a98c4 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -6067,7 +6067,69 @@ def test_dot_equivalent(self, args): r3 = np.matmul(args[0].copy(), args[1].copy()) assert_equal(r1, r3) - + + def test_matmul_object(self): + import fractions + + f = np.vectorize(fractions.Fraction) + def random_ints(): + return np.random.randint(1, 1000, size=(10, 3, 3)) + M1 = f(random_ints(), random_ints()) + M2 = f(random_ints(), random_ints()) + + M3 = self.matmul(M1, M2) + + [N1, N2, N3] = [a.astype(float) for a in [M1, M2, M3]] + + assert_allclose(N3, self.matmul(N1, N2)) + + def test_matmul_object_type_scalar(self): + from fractions import Fraction as F + v = np.array([F(2,3), F(5,7)]) + res = self.matmul(v, v) + assert_(type(res) is F) + + def test_matmul_empty(self): + a = np.empty((3, 0), dtype=object) + b = np.empty((0, 3), dtype=object) + c = np.zeros((3, 3)) + assert_array_equal(np.matmul(a, b), c) + + def test_matmul_exception_multiply(self): + # test that matmul fails if `__mul__` is missing + class add_not_multiply(): + def __add__(self, other): + return self + a = np.full((3,3), add_not_multiply()) + with assert_raises(TypeError): + b = np.matmul(a, a) + + def test_matmul_exception_add(self): + # test that matmul fails if `__add__` is missing + class multiply_not_add(): + def __mul__(self, other): + return self + a = np.full((3,3), multiply_not_add()) + with assert_raises(TypeError): + b = np.matmul(a, a) + + def test_matmul_bool(self): + # gh-14439 + a = np.array([[1, 0],[1, 1]], dtype=bool) + assert np.max(a.view(np.uint8)) == 1 + b = np.matmul(a, a) + # matmul with boolean output should always be 0, 1 + assert np.max(b.view(np.uint8)) == 1 + + np.random.seed(42) + d = np.random.randint(2, size=4*5, dtype=np.int8) + d = d.reshape(4, 5) > 0 + out1 = np.matmul(d, d.reshape(5, 4)) + out2 = np.dot(d, d.reshape(5, 4)) + assert_equal(out1, out2) + + c = np.matmul(np.zeros((2, 0), dtype=bool), np.zeros(0, dtype=bool)) + assert not np.any(c) if sys.version_info[:2] >= (3, 5):