diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index 3d5e51ba92e2..08505c5c7e50 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -264,10 +264,20 @@ gentype_@name@(PyObject *m1, PyObject *m2) static PyObject * gentype_multiply(PyObject *m1, PyObject *m2) { - PyObject *ret = NULL; npy_intp repeat; + /* + * If the other object supports sequence repeat and not number multiply + * we should call sequence repeat to support e.g. list repeat by numpy + * scalars (they may be converted to ndarray otherwise). + * A python defined class will always only have the nb_multiply slot and + * some classes may have neither defined. For the latter we want need + * to give the normal case a chance to convert the object to ndarray. + * Probably no class has both defined, but if they do, prefer number. + */ if (!PyArray_IsScalar(m1, Generic) && + ((Py_TYPE(m1)->tp_as_sequence != NULL) && + (Py_TYPE(m1)->tp_as_sequence->sq_repeat != NULL)) && ((Py_TYPE(m1)->tp_as_number == NULL) || (Py_TYPE(m1)->tp_as_number->nb_multiply == NULL))) { /* Try to convert m2 to an int and try sequence repeat */ @@ -276,9 +286,11 @@ gentype_multiply(PyObject *m1, PyObject *m2) return NULL; } /* Note that npy_intp is compatible to Py_Ssize_t */ - ret = PySequence_Repeat(m1, repeat); + return PySequence_Repeat(m1, repeat); } - else if (!PyArray_IsScalar(m2, Generic) && + if (!PyArray_IsScalar(m2, Generic) && + ((Py_TYPE(m2)->tp_as_sequence != NULL) && + (Py_TYPE(m2)->tp_as_sequence->sq_repeat != NULL)) && ((Py_TYPE(m2)->tp_as_number == NULL) || (Py_TYPE(m2)->tp_as_number->nb_multiply == NULL))) { /* Try to convert m1 to an int and try sequence repeat */ @@ -286,13 +298,11 @@ gentype_multiply(PyObject *m1, PyObject *m2) if (repeat == -1 && PyErr_Occurred()) { return NULL; } - ret = PySequence_Repeat(m2, repeat); - } - if (ret == NULL) { - PyErr_Clear(); /* no effect if not set */ - ret = PyArray_Type.tp_as_number->nb_multiply(m1, m2); + return PySequence_Repeat(m2, repeat); } - return ret; + + /* All normal cases are handled by PyArray's multiply */ + return PyArray_Type.tp_as_number->nb_multiply(m1, m2); } /**begin repeat diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index 12b1a0fe3356..62faa763f857 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -9,7 +9,7 @@ from numpy.testing.utils import _gen_alignment_data from numpy.testing import ( TestCase, run_module_suite, assert_, assert_equal, assert_raises, - assert_almost_equal, assert_allclose + assert_almost_equal, assert_allclose, assert_array_equal ) types = [np.bool_, np.byte, np.ubyte, np.short, np.ushort, np.intc, np.uintc, @@ -448,6 +448,44 @@ def test_error(self): assert_raises(TypeError, d.__sizeof__, "a") +class TestMultiply(TestCase): + def test_seq_repeat(self): + # Test that basic sequences get repeated when multiplied with + # numpy integers. And errors are raised when multiplied with others. + # Some of this behaviour may be controversial and could be open for + # change. + for seq_type in (list, tuple): + seq = seq_type([1, 2, 3]) + for numpy_type in np.typecodes["AllInteger"]: + i = np.dtype(numpy_type).type(2) + assert_equal(seq * i, seq * int(i)) + assert_equal(i * seq, int(i) * seq) + + for numpy_type in np.typecodes["All"].replace("V", ""): + if numpy_type in np.typecodes["AllInteger"]: + continue + i = np.dtype(numpy_type).type() + assert_raises(TypeError, operator.mul, seq, i) + assert_raises(TypeError, operator.mul, i, seq) + + def test_no_seq_repeat_basic_array_like(self): + # Test that an array-like which does not know how to be multiplied + # does not attempt sequence repeat (raise TypeError). + # See also gh-7428. + class ArrayLike(object): + def __init__(self, arr): + self.arr = arr + def __array__(self): + return self.arr + + # Test for simple ArrayLike above and memoryviews (original report) + for arr_like in (ArrayLike(np.ones(3)), memoryview(np.ones(3))): + assert_array_equal(arr_like * np.float32(3.), np.full(3, 3.)) + assert_array_equal(np.float32(3.) * arr_like, np.full(3, 3.)) + assert_array_equal(arr_like * np.int_(3), np.full(3, 3)) + assert_array_equal(np.int_(3) * arr_like, np.full(3, 3)) + + class TestAbs(TestCase): def _test_abs_func(self, absfunc):