8000 Merge pull request #5242 from juliantaylor/fix-ufunc-subok-out · numpy/numpy@994e98c · GitHub
[go: up one dir, main page]

Skip to content

Commit 994e98c

Browse files
committed
Merge pull request #5242 from juliantaylor/fix-ufunc-subok-out
BUG: fix not returning out array from ufuncs with subok=False set
2 parents c88fd91 + b40e686 commit 994e98c

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3932,18 +3932,19 @@ _find_array_wrap(PyObject *args, PyObject *kwds,
39323932
PyObject *with_wrap[NPY_MAXARGS], *wraps[NPY_MAXARGS];
39333933
PyObject *obj, *wrap = NULL;
39343934

3935-
/* If a 'subok' parameter is passed and isn't True, don't wrap */
3935+
/*
3936+
* If a 'subok' parameter is passed and isn't True, don't wrap but put None
3937+
* into slots with out arguments which means return the out argument
3938+
*/
39363939
if (kwds != NULL && (obj = PyDict_GetItem(kwds,
39373940
npy_um_str_subok)) != NULL) {
39383941
if (obj != Py_True) {
3939-
for (i = 0; i < nout; i++) {
3940-
output_wrap[i] = NULL;
3941-
}
3942-
return;
3942+
/* skip search for wrap members */
3943+
goto handle_out;
39433944
}
39443945
}
39453946

3946-
nargs = PyTuple_GET_SIZE(args);
3947+
39473948
for (i = 0; i < nin; i++) {
39483949
obj = PyTuple_GET_ITEM(args, i);
39493950
if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) {
@@ -4001,6 +4002,8 @@ _find_array_wrap(PyObject *args, PyObject *kwds,
40014002
* exact ndarray so that no PyArray_Return is
40024003
* done in that case.
40034004
*/
4005+
handle_out:
4006+
nargs = PyTuple_GET_SIZE(args);
40044007
for (i = 0; i < nout; i++) {
40054008
int j = nin + i;
40064009
int incref = 1;

numpy/core/tests/test_numeric.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,6 +1697,20 @@ def test_ddof2(self):
16971697
assert_almost_equal(std(self.A, ddof=2)**2,
16981698
self.real_var*len(self.A)/float(len(self.A)-2))
16991699

1700+
def test_out_scalar(self):
1701+
d = np.arange(10)
1702+
out = np.array(0.)
1703+
r = np.std(d, out=out)
1704+
assert_(r is out)
1705+
assert_array_equal(r, out)
1706+
r = np.var(d, out=out)
1707+
assert_(r is out)
1708+
assert_array_equal(r, out)
1709+
r = np.mean(d, out=out)
1710+
assert_(r is out)
1711+
assert_array_equal(r, out)
1712+
1713+
17001714
class TestStdVarComplex(TestCase):
17011715
def test_basic(self):
17021716
A = array([1, 1.j, -1, -1.j])

numpy/core/tests/test_umath.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,36 @@ def test_e(self):
3636
def test_euler_gamma(self):
3737
assert_allclose(ncu.euler_gamma, 0.5772156649015329, 1e-15)
3838

39+
class TestOut(TestCase):
40+
def test_out_subok(self):
41+
for b in (True, False):
42+
aout = np.array(0.5)
43+
44+
r = np.add(aout, 2, out=aout)
45+
assert_(r is aout)
46+
assert_array_equal(r, aout)
47+
48+
r = np.add(aout, 2, out=aout, subok=b)
49+
assert_(r is aout)
50+
assert_array_equal(r, aout)
51+
52+
r = np.add(aout, 2, aout, subok=False)
53+
assert_(r is aout)
54+
assert_array_equal(r, aout)
55+
56+
d = np.ones(5)
57+
o1 = np.zeros(5)
58+
o2 = np.zeros(5, dtype=np.int32)
59+
r1, r2 = np.frexp(d, o1, o2, subok=b)
60+
assert_(r1 is o1)
61+
assert_array_equal(r1, o1)
62+
assert_(r2 is o2)
63+
assert_array_equal(r2, o2)
64+
65+
r1, r2 = np.frexp(d, out=o1, subok=b)
66+
assert_(r1 is o1)
67+
assert_array_equal(r1, o1)
68+
3969

4070
class TestDivision(TestCase):
4171
def test_division_int(self):

numpy/ma/core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,8 @@ def __call__ (self, a, b):
780780
# component of numpy's import time.
781781
if self.tolerance is None:
782782
self.tolerance = np.finfo(float).tiny
783+
# don't call ma ufuncs from __array_wrap__ which would fail for scalars
784+
a, b = np.asarray(a), np.asarray(b)
783785
return umath.absolute(a) * self.tolerance >= umath.absolute(b)
784786

785787

0 commit comments

Comments
 (0)
0