8000 Merge pull request #9129 from mhvk/array_ufunc_fast_scalar_power_back… · numpy/numpy@1b53503 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1b53503

Browse files
authored
Merge pull request #9129 from mhvk/array_ufunc_fast_scalar_power_backport
BUG: ndarray.__pow__ does not check result of fast_scalar_power
2 parents 7c6ace1 + f652cdf commit 1b53503

File tree

4 files changed

+61
-38
lines changed
  • doc/neps
  • numpy/core
    • < 10000 div style="width:100%;display:flex">
      src/multiarray
  • tests
  • 4 files changed

    +61
    -38
    lines changed

    doc/neps/ufunc-overrides.rst

    Lines changed: 11 additions & 5 deletions
    Original file line numberDiff line numberDiff line change
    @@ -664,13 +664,13 @@ Symbol Operator NumPy Ufunc(s)
    664664
    ``//`` ``floordiv`` :func:`floor_divide`
    665665
    ``%`` ``mod`` :func:`remainder`
    666666
    NA ``divmod`` :func:`divmod`
    667-
    ``**`` ``pow`` :func:`power`
    667+
    ``**`` ``pow`` :func:`power` [10]_
    668668
    ``<<`` ``lshift`` :func:`left_shift`
    669669
    ``>>`` ``rshift`` :func:`right_shift`
    670670
    ``&`` ``and_`` :func:`bitwise_and`
    671671
    ``^`` ``xor_`` :func:`bitwise_xor`
    672672
    ``|`` ``or_`` :func:`bitwise_or`
    673-
    ``@`` ``matmul`` Not yet implemented as a ufunc [10]_
    673+
    ``@`` ``matmul`` Not yet implemented as a ufunc [11]_
    674674
    ====== ============ =========================================
    675675

    676676
    And here is the list of unary operators:
    @@ -679,16 +679,22 @@ And here is the list of unary operators:
    679679
    Symbol Operator NumPy Ufunc(s)
    680680
    ====== ============ =========================================
    681681
    ``-`` ``neg`` :func:`negative`
    682-
    ``+`` ``pos`` :func:`positive` [11]_
    682+
    ``+`` ``pos`` :func:`positive` [12]_
    683683
    NA ``abs`` :func:`absolute`
    684684
    ``~`` ``invert`` :func:`invert`
    685685
    ====== ============ =========================================
    686686

    687-
    .. [10] Because NumPy's :func:`matmul` is not a ufunc, it is
    687+
    .. [10] class :`ndarray` takes short cuts for ``__pow__`` for the
    688+
    cases where the power equals ``1`` (:func:`positive`),
    689+
    ``-1`` (:func:`reciprocal`), ``2`` (:func:`square`), ``0`` (an
    690+
    otherwise private ``_ones_like`` ufunc), and ``0.5``
    691+
    (:func:`sqrt`), and the array is float or complex (or integer
    692+
    for square).
    693+
    .. [11] Because NumPy's :func:`matmul` is not a ufunc, it is
    688694
    `currently not possible <https://github.com/numpy/numpy/issues/9028>`_
    689695
    to override ``numpy_array @ other`` with ``other`` taking precedence
    690696
    if ``other`` implements ``__array_func__``.
    691-
    .. [11] :class:`ndarray` currently does a copy instead of using this ufunc.
    697+
    .. [12] :class:`ndarray` currently does a copy instead of using this ufunc.
    692698
    693699
    Future extensions to other functions
    694700
    ------------------------------------

    numpy/core/src/multiarray/number.c

    Lines changed: 28 additions & 33 deletions
    Original file line numberDiff line numberDiff line change
    @@ -91,6 +91,7 @@ PyArray_SetNumericOps(PyObject *dict)
    9191
    SET(sqrt);
    9292
    SET(cbrt);
    9393
    SET(negative);
    94+
    SET(positive);
    9495
    SET(absolute);
    9596
    SET(invert);
    9697
    SET(left_shift);
    @@ -143,6 +144,7 @@ PyArray_GetNumericOps(void)
    143144
    GET(_ones_like);
    144145
    GET(sqrt);
    145146
    GET(negative);
    147+
    GET(positive);
    146148
    GET(absolute);
    147149
    GET(invert);
    148150
    GET(left_shift);
    @@ -453,9 +455,14 @@ is_scalar_with_conversion(PyObject *o2, double* out_exponent)
    453455
    return NPY_NOSCALAR;
    454456
    }
    455457

    456-
    /* optimize float array or complex array to a scalar power */
    457-
    static PyObject *
    458-
    fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
    458+
    /*
    459+
    * optimize float array or complex array to a scalar power
    460+
    * returns 0 on success, -1 if no optimization is possible
    461+
    * the result is in value (can be NULL if an error occurred)
    462+
    */
    463+
    static int
    464+
    fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace,
    465+
    PyObject **value)
    459466
    {
    460467
    double exponent;
    461468
    NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */
    @@ -464,17 +471,7 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
    464471
    PyObject *fastop = NULL;
    465472
    if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
    466473
    if (exponent == 1.0) {
    467-
    /* we have to do this one special, as the
    468-
    "copy" method of array objects isn't set
    469-
    up early enough to be added
    470-
    by PyArray_SetNumericOps.
    471-
    */
    472-
    if (inplace) {
    473-
    Py_INCREF(a1);
    474-
    return (PyObject *)a1;
    475-
    } else {
    476-
    return PyArray_Copy(a1);
    477-
    }
    474+
    fastop = n_ops.positive;
    478475
    }
    479476
    else if (exponent == -1.0) {
    480477
    fastop = n_ops.reciprocal;
    @@ -489,15 +486,16 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
    489486
    fastop = n_ops.square;
    490487
    }
    491488
    else {
    492-
    return NULL;
    489+
    return -1;
    493490
    }
    494491

    495492
    if (inplace || can_elide_temp_unary(a1)) {
    496-
    return PyArray_GenericInplaceUnaryFunction(a1, fastop);
    493+
    *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
    497494
    }
    498495
    else {
    499-
    return PyArray_GenericUnaryFunction(a1, fastop);
    496+
    *value = PyArray_GenericUnaryFunction(a1, fastop);
    500497
    }
    498+
    return 0;
    501499
    }
    502500
    /* Because this is called with all arrays, we need to
    503501
    * change the output if the kind of the scalar is different
    @@ -507,36 +505,35 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
    507505
    else if (exponent == 2.0) {
    508506
    fastop = n_ops.square;
    509507
    if (inplace) {
    510-
    return PyArray_GenericInplaceUnaryFunction(a1, fastop);
    508+
    *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
    511509
    }
    512510
    else {
    513511
    /* We only special-case the FLOAT_SCALAR and integer types */
    514512
    if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
    515-
    PyObject *res;
    516513
    PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
    517514
    a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
    518515
    PyArray_ISFORTRAN(a1));
    519-
    if (a1 == NULL) {
    520-
    return NULL;
    516+
    if (a1 != NULL) {
    517+
    /* cast always creates a new array */
    518+
    *value = PyArray_GenericInplaceUnaryFunction(a1, fastop);
    519+
    Py_DECREF(a1);
    521520
    }
    522-
    /* cast always creates a new array */
    523-
    res = PyArray_GenericInplaceUnaryFunction(a1, fastop);
    524-
    Py_DECREF(a1);
    525-
    return res;
    526521
    }
    527522
    else {
    528-
    return PyArray_GenericUnaryFunction(a1, fastop);
    523+
    *value = PyArray_GenericUnaryFunction(a1, fastop);
    529524
    }
    530525
    }
    526+
    return 0;
    531527
    }
    532528
    }
    533-
    return NULL;
    529+
    /* no fast operation found */
    530+
    return -1;
    534531
    }
    535532

    536533
    static PyObject *
    537534
    array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
    538535
    {
    539-
    PyObject *value;
    536+
    PyObject *value = NULL;
    540537

    541538
    if (modulo != Py_None) {
    542539
    /* modular exponentiation is not implemented (gh-8804) */
    @@ -545,8 +542,7 @@ array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
    545542
    }
    546543

    547544
    BINOP_GIVE_UP_IF_NEEDED(a1, o2, nb_power, array_power);
    548-
    value = fast_scalar_power(a1, o2, 0);
    549-
    if (!value) {
    545+
    if (fast_scalar_power(a1, o2, 0, &value) != 0) {
    550546
    value = PyArray_GenericBinaryFunction(a1, o2, n_ops.power);
    551547
    }
    552548
    return value;
    @@ -686,12 +682,11 @@ static PyObject *
    686682
    array_inplace_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo))
    687683
    {
    688684
    /* modulo is ignored! */
    689-
    PyObject *value;
    685+
    PyObject *value = NULL;
    690686

    691687
    INPLACE_GIVE_UP_IF_NEEDED(
    692688
    a1, o2, nb_inplace_power, array_inplace_power);
    693-
    value = fast_scalar_power(a1, o2, 1);
    694-
    if (!value) {
    689+
    if (fast_scalar_power(a1, o2, 1, &value) != 0) {
    695690
    value = PyArray_GenericInplaceBinaryFunction(a1, o2, n_ops.power);
    696691
    }
    697692
    return value;

    numpy/core/src/multiarray/number.h

    Lines changed: 1 addition & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -15,6 +15,7 @@ typedef struct {
    1515
    PyObject *sqrt;
    1616
    PyObject *cbrt;
    1717
    PyObject *negative;
    18+
    PyObject *positive;
    1819
    PyObject *absolute;
    1920
    PyObject *invert;
    2021
    PyObject *left_shift;

    numpy/core/tests/test_multiarray.py

    Lines changed: 21 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -3069,6 +3069,27 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kw):
    30693069
    assert_equal(A[0], 30)
    30703070
    assert_(isinstance(A, OutClass))
    30713071

    3072+
    def test_pow_override_with_errors(self):
    3073+
    # regression test for gh-9112
    3074+
    class PowerOnly(np.ndarray):
    3075+
    def __array_ufunc__(self, ufunc, method, *inputs, **kw):
    3076+
    if ufunc is not np.power:
    3077+
    raise NotImplementedError
    3078+
    return "POWER!"
    3079+
    # explicit cast to float, to ensure the fast power path is taken.
    3080+
    a = np.array(5., dtype=np.float64).view(PowerOnly)
    3081+
    assert_equal(a ** 2.5, "POWER!")
    3082+
    with assert_raises(NotImplementedError):
    3083+
    a ** 0.5
    3084+
    with assert_raises(NotImplementedError):
    3085+
    a ** 0
    3086+
    with assert_raises(NotImplementedError):
    3087+
    a ** 1
    3088+
    with assert_raises(NotImplementedError):
    3089+
    a ** -1
    3090+
    with assert_raises(NotImplementedError):
    3091+
    a ** 2
    3092+
    30723093

    30733094
    class TestTemporaryElide(TestCase):
    30743095
    # elision is only triggered on relatively large arrays

    0 commit comments

    Comments
     (0)
    0