8000 ENH: do integer**2. inplace · juliantaylor/numpy@9a2f0fe · GitHub
[go: up one dir, main page]

Skip to content

Commit 9a2f0fe

Browse files
committed
ENH: do integer**2. inplace
Squaring integer arrays with float argument always casts to double so the squaring itself can be done inplace. Also use square fastop instead of multiply.
1 parent d21cecf commit 9a2f0fe

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -583,30 +583,28 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
583583
* (thus, the input should be up-cast)
584584
*/
585585
else if (exponent == 2.0) {
586-
fastop = n_ops.multiply;
586+
fastop = n_ops.square;
587587
if (inplace) {
588-
return PyArray_GenericInplaceBinaryFunction
589-
(a1, (PyObject *)a1, fastop);
588+
return PyArray_GenericInplaceUnaryFunction(a1, fastop);
590589
}
591590
else {
592-
PyArray_Descr *dtype = NULL;
593-
PyObject *res;
594-
595591
/* We only special-case the FLOAT_SCALAR and integer types */
596592
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
597-
dtype = PyArray_DescrFromType(NPY_DOUBLE);
593+
PyObject *res;
594+
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_DOUBLE);
598595
a1 = (PyArrayObject *)PyArray_CastToType(a1, dtype,
599596
PyArray_ISFORTRAN(a1));
600597
if (a1 == NULL) {
601598
return NULL;
602599
}
600+
/* cast always creates a new array */
601+
res = PyArray_GenericInplaceUnaryFunction(a1, fastop);
602+
Py_DECREF(a1);
603+
return res;
603604
}
604605
else {
605-
Py_INCREF(a1);
606+
return PyArray_GenericUnaryFunction(a1, fastop);
606607
}
607-
res = PyArray_GenericBinaryFunction(a1, (PyObject *)a1, fastop);
608-
Py_DECREF(a1);
609-
return res;
610608
}
611609
}
612610
}

numpy/core/tests/test_umath.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,12 @@ def assert_complex_equal(x, y):
456456

457457
def test_fast_power(self):
458458
x = np.array([1, 2, 3], np.int16)
459-
assert_((x**2.00001).dtype is (x**2.0).dtype)
459+
res = x**2.0
460+
assert_((x**2.00001).dtype is res.dtype)
461+
assert_array_equal(res, [1, 4, 9])
462+
# check the inplace operation on the casted copy doesn't mess with x
463+
assert_(not np.may_share_memory(res, x))
464+
assert_array_equal(x, [1, 2, 3])
460465

461466
# Check that the fast path ignores 1-element not 0-d arrays
462467
res = x ** np.array([[[2]]])

0 commit comments

Comments
 (0)
0