8000 Merge pull request #4753 from mhvk/bug-4753 · numpy/numpy@68e61c2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 68e61c2

Browse files
committed
Merge pull request #4753 from mhvk/bug-4753
BUG difference in behaviour for subclass output in ufuncs
2 parents c8ca8ab + 288239d commit 68e61c2

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

numpy/core/src/private/ufunc_override.h

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,13 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
180180
int override_pos; /* Position of override in args.*/
181181
int j;
182182

183-
int nargs = PyTuple_GET_SIZE(args);
183+
int nargs;
184+
int nout_kwd = 0;
185+
int out_kwd_is_tuple = 0;
184186
int noa = 0; /* Number of overriding args.*/
185187

186188
PyObject *obj;
189+
PyObject *out_kwd_obj = NULL;
187190
PyObject *other_obj;
188191

189192
PyObject *method_name = NULL;
@@ -204,16 +207,40 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
204207
"with non-tuple");
205208
goto fail;
206209
}
207-
208-
if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) {
210+
nargs = PyTuple_GET_SIZE(args);
211+
if (nargs > NPY_MAXARGS) {
209212
PyErr_SetString(PyExc_ValueError,
210213
"Internal Numpy error: too many arguments in call "
211214
"to PyUFunc_CheckOverride");
212215
goto fail;
213216
}
214217

215-
for (i = 0; i < nargs; ++i) {
216-
obj = PyTuple_GET_ITEM(args, i);
218+
/* be sure to include possible 'out' keyword argument. */
219+
if ((kwds)&& (PyDict_CheckExact(kwds))) {
220+
out_kwd_obj = PyDict_GetItemString(kwds, "out");
221+
if (out_kwd_obj != NULL) {
222+
out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj);
223+
if (out_kwd_is_tuple) {
224+
nout_kwd = PyTuple_GET_SIZE(out_kwd_obj);
225+
}
226+
else {
227+
nout_kwd = 1;
228+
}
229+
}
230+
}
231+
232+
for (i = 0; i < nargs + nout_kwd; ++i) {
233+
if (i < nargs) {
234+
obj = PyTuple_GET_ITEM(args, i);
235+
}
236+
else {
237+
if (out_kwd_is_tuple) {
238+
obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs);
239+
}
240+
else {
241+
obj = out_kwd_obj;
242+
}
243+
}
217244
/*
218245
* TODO: could use PyArray_GetAttrString_SuppressException if it
219246
* weren't private to multiarray.so

numpy/core/tests/test_multiarray.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2428,15 +2428,15 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
24282428
return "ufunc"
24292429
else:
24302430
inputs = list(inputs)
2431-
inputs[i] = np.asarray(self)
2431+
if i < len(inputs):
2432+
inputs[i] = np.asarray(self)
24322433
func = getattr(ufunc, method)
2434+
if ('out' in kw) and (kw['out'] is not None):
2435+
kw['out'] = np.asarray(kw['out'])
24332436
r = func(*inputs, **kw)
2434-
if 'out' in kw:
2435-
return r
2436-
else:
2437-
x = self.__class__(r.shape, dtype=r.dtype)
2438-
x[...] = r
2439-
return x
2437+
x = self.__class__(r.shape, dtype=r.dtype)
2438+
x[...] = r
2439+
return x
24402440

24412441
class SomeClass3(SomeClass2):
24422442
def __rsub__(self, other):
@@ -2520,6 +2520,64 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
25202520
assert_('sig' not in kw and 'signature' in kw)
25212521
assert_equal(kw['signature'], 'ii->i')
25222522

2523+
def test_numpy_ufunc_index(self):
2524+
# Check that index is set appropriately, also if only an output
2525+
# is passed on (latter is another regression tests for github bug 4753)
2526+
class CheckIndex(object):
2527+
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
2528+
return i
2529+
2530+
a = CheckIndex()
2531+
dummy = np.arange(2.)
2532+
# 1 input, 1 output
2533+
assert_equal(np.sin(a), 0)
2534+
assert_equal(np.sin(dummy, a), 1)
2535+
assert_equal(np.sin(dummy, out=a), 1)
2536+
assert_equal(np.sin(dummy, out=(a,)), 1)
2537+
assert_equal(np.sin(a, a), 0)
2538+
assert_equal(np.sin(a, out=a), 0)
2539+
assert_equal(np.sin(a, out=(a,)), 0)
2540+
# 1 input, 2 outputs
2541+
assert_equal(np.modf(dummy, a), 1)
2542+
assert_equal(np.modf(dummy, None, a), 2)
2543+
assert_equal(np.modf(dummy, dummy, a), 2)
2544+
assert_equal(np.modf(dummy, out=a), 1)
2545+
assert_equal(np.modf(dummy, out=(a,)), 1)
2546+
assert_equal(np.modf(dummy, out=(a, None)), 1)
2547+
assert_equal(np.modf(dummy, out=(a, dummy)), 1)
2548+
assert_equal(np.modf(dummy, out=(None, a)), 2)
2549+
assert_equal(np.modf(dummy, out=(dummy, a)), 2)
2550+
assert_equal(np.modf(a, out=(dummy, a)), 0)
2551+
# 2 inputs, 1 output
2552+
assert_equal(np.add(a, dummy), 0)
2553+
assert_equal(np.add(dummy, a), 1)
2554+
assert_equal(np.add(dummy, dummy, a), 2)
2555+
assert_equal(np.add(dummy, a, a), 1)
2556+
assert_equal(np.add(dummy, dummy, out=a), 2)
2557+
assert_equal(np.add(dummy, dummy, out=(a,)), 2)
2558+
assert_equal(np.add(a, dummy, out=a), 0)
2559+
2560+
def test_out_override(self):
2561+
# regression test for github bug 4753
2562+
class OutClass(ndarray):
2563+
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
2564+
if 'out' in kw:
2565+
tmp_kw = kw.copy()
2566+
tmp_kw.pop('out')
2567+
func = getattr(ufunc, method)
2568+
kw['out'][...] = func(*inputs, **tmp_kw)
2569+
2570+
A = np.array([0]).view(OutClass)
2571+
B = np.array([5])
2572+
C = np.array([6])
2573+
np.multiply(C, B, A)
2574+
assert_equal(A[0], 30)
2575+
assert_(isinstance(A, OutClass))
2576+
A[0] = 0
2577+
np.multiply(C, B, out=A)
2578+
assert_equal(A[0], 30)
2579+
assert_(isinstance(A, OutClass))
2580+
25232581

25242582
class TestCAPI(TestCase):
25252583
def test_IsPythonScalar(self):

0 commit comments

Comments
 (0)
0