8000 Check `out` kwarg for __nump_ufunc__ override and set index appropria… · numpy/numpy@288239d · GitHub
[go: up one dir, main page]

Skip to content

Commit 288239d

Browse files
cowlicksmhvk
authored andcommitted
Check out kwarg for __nump_ufunc__ override and set index appropriately
for the case where self is among outputs but not among inputs. Ensure it works both out passed on as an argument and with out in a keyword argument.
1 parent d033b6e commit 288239d

File tree

2 files changed
+97
-12
lines changed

2 files changed

+97
-12
lines changed

numpy/core/src/private/ufunc_override.h

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,13 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
180180
int override_pos; /* Position of override in args.*/
181181
int j;
182182

183-
int nargs = PyTuple_GET_SIZE(args);
183+
int nargs;
184+
int nout_kwd = 0;
185+
int out_kwd_is_tuple = 0;
184186
int noa = 0; /* Number of overriding args.*/
185187

186188
PyObject *obj;
189+
PyObject *out_kwd_obj = NULL;
187190
PyObject *other_obj;
188191

189192
PyObject *method_name = NULL;
@@ -204,16 +207,40 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
204207
"with non-tuple");
205208
goto fail;
206209
}
207-
208-
if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) {
210+
nargs = PyTuple_GET_SIZE(args);
211+
if (nargs > NPY_MAXARGS) {
209212
PyErr_SetString(PyExc_ValueError,
210213
"Internal Numpy error: too many arguments in call "
211214
"to PyUFunc_CheckOverride");
212215
goto fail;
213216
}
214217

215-
for (i = 0; i < nargs; ++i) {
216-
obj = PyTuple_GET_ITEM(args, i);
218+
/* be sure to include possible 'out' keyword argument. */
219+
if ((kwds)&& (PyDict_CheckExact(kwds))) {
220+
out_kwd_obj = PyDict_GetItemString(kwds, "out");
221+
if (out_kwd_obj != NULL) {
222+
out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj);
223+
if (out_kwd_is_tuple) {
224+
nout_kwd = PyTuple_GET_SIZE(out_kwd_obj);
225+
}
226+
else {
227+
nout_kwd = 1;
228+
}
229+
}
230+
}
231+
232+
for (i = 0; i < nargs + nout_kwd; ++i) {
233+
if (i < nargs) {
234+
obj = PyTuple_GET_ITEM(args, i);
235+
}
236+
else {
237+
if (out_kwd_is_tuple) {
238+
obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs);
239+
}
240+
else {
241+
obj = out_kwd_obj;
242+
}
243+
}
217244
/*
218245
* TODO: could use PyArray_GetAttrString_SuppressException if it
219246
* weren't private to multiarray.so

numpy/core/tests/test_multiarray.py

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,15 +2383,15 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
23832383
return "ufunc"
23842384
else:
23852385
inputs = list(inputs)
2386-
inputs[i] = np.asarray(self)
2386+
if i < len(inputs):
2387+
inputs[i] = np.asarray(self)
23872388
func = getattr(ufunc, method)
2389+
if ('out' in kw) and (kw['out'] is not None):
2390+
kw['out'] = np.asarray(kw['out'])
23882391
r = func(*inputs, **kw)
2389-
if 'out' in kw:
2390-
return r
2391-
else:
2392-
x = self.__class__(r.shape, dtype=r.dtype)
2393-
x[...] = r
2394-
return x
2392+
x = self.__class__(r.shape, dtype=r.dtype)
2393+
x[...] = r
2394+
return x
23952395

23962396
class SomeClass3(SomeClass2):
23972397
def __rsub__(self, other):
@@ -2475,6 +2475,64 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
24752475
assert_('sig' not in kw and 'signature' in kw)
24762476
assert_equal(kw['signature'], 'ii->i')
24772477

2478+
def test_numpy_ufunc_index(self):
2479+
# Check that index is set appropriately, also if only an output
2480+
# is passed on (latter is another regression tests for github bug 4753)
2481+
class CheckIndex(object):
2482+
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
2483+
return i
2484+
2485+
a = CheckIndex()
2486+
dummy = np.arange(2.)
2487+
# 1 input, 1 output
2488+
assert_equal(np.sin(a), 0)
2489+
assert_equal(np.sin(dummy, a), 1)
2490+
assert_equal(np.sin(dummy, out=a), 1)
2491+
assert_equal(np.sin(dummy, out=(a,)), 1)
2492+
assert_equal(np.sin(a, a), 0)
2493+
assert_equal(np.sin(a, out=a), 0)
2494+
assert_equal(np.sin(a, out=(a,)), 0)
2495+
# 1 input, 2 outputs
2496+
assert_equal(np.modf(dummy, a), 1)
2497+
assert_equal(np.modf(dummy, None, a), 2)
2498+
assert_equal(np.modf(dummy, dummy, a), 2)
2499+
assert_equal(np.modf(dummy, out=a), 1)
2500+
assert_equal(np.modf(dummy, out=(a,)), 1)
2501+
assert_equal(np.modf(dummy, out=(a, None)), 1)
2502+
assert_equal(np.modf(dummy, out=(a, dummy)), 1)
2503+
assert_equal(np.modf(dummy, out=(None, a)), 2)
2504+
assert_equal(np.modf(dummy, out=(dummy, a)), 2)
2505+
assert_equal(np.modf(a, out=(dummy, a)), 0)
2506+
# 2 inputs, 1 output
2507+
assert_equal(np.add(a, dummy), 0)
2508+
assert_equal(np.add(dummy, a), 1)
2509+
assert_equal(np.add(dummy, dummy, a), 2)
2510+
assert_equal(np.add(dummy, a, a), 1)
2511+
assert_equal(np.add(dummy, dummy, out=a), 2)
2512+
assert_equal(np.add(dummy, dummy, out=(a,)), 2)
2513+
assert_equal(np.add(a, dummy, out=a), 0)
2514+
2515+
def test_out_override(self):
2516+
# regression test for github bug 4753
2517+
class OutClass(ndarray):
2518+
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
2519+
if 'out' in kw:
2520+
tmp_kw = kw.copy()
2521+
tmp_kw.pop('out')
2522+
func = getattr(ufunc, method)
2523+
kw['out'][...] = func(*inputs, **tmp_kw)
2524+
2525+
A = np.array([0]).view(OutClass)
2526+
B = np.array([5])
2527+
C = np.array([6])
2528+
np.multiply(C, B, A)
2529+
assert_equal(A[0], 30)
2530+
assert_(isinstance(A, OutClass))
2531+
A[0] = 0
2532+
np.multiply(C, B, out=A)
2533+
assert_equal(A[0], 30)
2534+
assert_(isinstance(A, OutClass))
2535+
24782536

24792537
class TestCAPI(TestCase):
24802538
def test_IsPythonScalar(self):

0 commit comments

Comments
 (0)
0