8000 Update ufunc override to work properly with ufunc methods. by cowlicks · Pull Request #4626 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

Update ufunc override to work properly with ufunc methods. #4626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2014
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
10000 Loading
Diff view
Diff view
Prev Previous commit
TST: Add tests for ufunc.method overrides.
  • Loading branch information
cowlicks committed May 9, 2014
commit c8ac77c427a880eec455a05b92de9a2260d5f9d9
129 changes: 97 additions & 32 deletions numpy/core/tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,39 +1026,105 @@ def __numpy_ufunc__(self, func, method, pos, inputs, **kwargs):
def test_ufunc_override_methods(self):
class A(object):
def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
if method == "__call__":
return method
if method == "reduce":
return method
if method == "accumulate":
return method
if method == "reduceat":
return method
return self, ufunc, method, pos, inputs, kwargs

# __call__
a = A()
res = np.multiply(1, a)
assert_equal(res, "__call__")

res = np.multiply.reduce(1, a)
assert_equal(res, "reduce")

res = np.multiply.accumulate(1, a)
assert_equal(res, "accumulate")

res = np.multiply.reduceat(1, a)
assert_equal(res, "reduceat")

res = np.multiply(a, 1)
assert_equal(res, "__call__")

res = np.multiply.reduce(a, 1)
assert_equal(res, "reduce")

res = np.multiply.accumulate(a, 1)
assert_equal(res, "accumulate")

res = np.multiply.reduceat(a, 1)
assert_equal(res, "reduceat")
res = np.multiply.__call__(1, a, foo='bar', answer=42)
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], '__call__')
assert_equal(res[3], 1)
assert_equal(res[4], (1, a))
assert_equal(res[5], {'foo': 'bar', 'answer': 42})

# reduce, positional args
res = np.multiply.reduce(a, 'axis0', 'dtype0', 'out0', 'keep0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'reduce')
assert_equal(res[3], 0)
assert_equal(res[4], (a,))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'keepdims': 'keep0',
'axis': 'axis0'})

# reduce, kwargs
res = np.multiply.reduce(a, axis='axis0', dtype='dtype0', out='out0',
keepdims='keep0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'reduce')
assert_equal(res[3], 0)
assert_equal(res[4], (a,))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'keepdims': 'keep0',
'axis': 'axis0'})

# accumulate, pos args
res = np.multiply.accumulate(a, 'axis0', 'dtype0', 'out0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'accumulate')
assert_equal(res[3], 0)
assert_equal(res[4], (a,))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'axis': 'axis0'})

# accumulate, kwargs
res = np.multiply.accumulate(a, axis='axis0', dtype='dtype0',
out='out0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'accumulate')
assert_equal(res[3], 0)
assert_equal(res[4], (a,))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'axis': 'axis0'})

# reduceat, pos args
res = np.multiply.reduceat(a, [4, 2], 'axis0', 'dtype0', 'out0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'reduceat')
assert_equal(res[3], 0)
assert_equal(res[4], (a, [4, 2]))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'axis': 'axis0'})

# reduceat, kwargs
res = np.multiply.reduceat(a, [4, 2], axis='axis0', dtype='dtype0',
out='out0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'reduceat')
assert_equal(res[3], 0)
assert_equal(res[4], (a, [4, 2]))
assert_equal(res[5], {'dtype':'dtype0',
'out': 'out0',
'axis': 'axis0'})

# outer
res = np.multiply.outer(a, 42)
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'outer')
assert_equal(res[3], 0)
assert_equal(res[4], (a, 42))
assert_equal(res[5], {})

# at
res = np.multiply.at(a, [4, 2], 'b0')
assert_equal(res[0], a)
assert_equal(res[1], np.multiply)
assert_equal(res[2], 'at')
assert_equal(res[3], 0)
assert_equal(res[4], (a, [4, 2], 'b0'))

def test_ufunc_override_out(self):
class A(object):
Expand Down Expand Up @@ -1094,7 +1160,6 @@ def __numpy_ufunc__(self, ufunc, method, pos, inputs, **kwargs):
assert_equal(res7['out'][0], 'out0')
assert_equal(res7['out'][1], 'out1')


def test_ufunc_override_exception(self):
class A(object):
def __numpy_ufunc__(self, *a, **kwargs):
Expand Down
0