8000 MAINT: Fixes tests with __array_function__ disabled · numpy/numpy@9216a1d · GitHub
[go: up one dir, main page]

Skip to content

Commit 9216a1d

Browse files
committed
MAINT: Fixes tests with __array_function__ disabled
1 parent 37df5e6 commit 9216a1d

File tree

6 files changed

+44
-10
lines changed

6 files changed

+44
-10
lines changed

numpy/core/overrides.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numpy.compat._inspect import getargspec
1010

1111

12-
ENABLE_ARRAY_FUNCTION = bool(
12+
ARRAY_FUNCTION_ENABLED = bool(
1313
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
1414

1515

@@ -142,7 +142,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
142142
Function suitable for decorating the implementation of a NumPy function.
143143
"""
144144

145-
if not ENABLE_ARRAY_FUNCTION:
145+
if not ARRAY_FUNCTION_ENABLED:
146146
def decorator(implementation):
147147
if docs 8000 _from_dispatcher:
148148
add_docstring(implementation, dispatcher.__doc__)

numpy/core/shape_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ def vstack(tup):
273273
[4]])
274274
275275
"""
276+
if not overrides.ARRAY_FUNCTION_ENABLED:
277+
# raise warning if necessary
278+
_arrays_for_stack_dispatcher(tup, stacklevel=2)
276279
return _nx.concatenate([atleast_2d(_m) for _m in tup], 0)
277280

278281

@@ -324,6 +327,10 @@ def hstack(tup):
324327
[3, 4]])
325328
326329
"""
330+
if not overrides.ARRAY_FUNCTION_ENABLED:
331+
# raise warning if necessary
332+
_arrays_for_stack_dispatcher(tup, stacklevel=2)
333+
327334
arrs = [atleast_1d(_m) for _m in tup]
328335
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
329336
if arrs and arrs[0].ndim == 1:
@@ -400,6 +407,10 @@ def stack(arrays, axis=0, out=None):
400407
[3, 4]])
401408
402409
"""
410+
if not overrides.ARRAY_FUNCTION_ENABLED:
411+
# raise warning if necessary
412+
_arrays_for_stack_dispatcher(arrays, stacklevel=2)
413+
403414
arrays = [asanyarray(arr) for arr in arrays]
404415
if not arrays:
405416
raise ValueError('need at least one array to stack')

numpy/core/tests/test_overrides.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
assert_, assert_equal, assert_raises, assert_raises_regex)
1010
from numpy.core.overrides import (
1111
_get_implementing_args, array_function_dispatch,
12-
verify_matching_signatures, ENABLE_ARRAY_FUNCTION)
12+
verify_matching_signatures, ARRAY_FUNCTION_ENABLED)
1313
from numpy.compat import pickle
1414
import pytest
1515

1616

1717
requires_array_function = pytest.mark.skipif(
18-
not ENABLE_ARRAY_FUNCTION,
18+
not ARRAY_FUNCTION_ENABLED,
1919
reason="__array_function__ dispatch not enabled.")
2020

2121

numpy/lib/shape_base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numpy.core.numeric import (
88
asarray, zeros, outer, concatenate, array, asanyarray
99
)
10-
from numpy.core.fromnumeric import product, reshape, transpose
10+
from numpy.core.fromnumeric import reshape, transpose
1111
from numpy.core.multiarray import normalize_axis_index
1212
from numpy.core import overrides
1313
from numpy.core import vstack, atleast_3d
@@ -628,6 +628,10 @@ def column_stack(tup):
628628
[3, 4]])
629629
630630
"""
631+
if not overrides.ARRAY_FUNCTION_ENABLED:
632+
# raise warning if necessary
633+
_arrays_for_stack_dispatcher(tup, stacklevel=2)
634+
631635
arrays = []
632636
for v in tup:
633637
arr = array(v, copy=False, subok=True)
@@ -692,6 +696,10 @@ def dstack(tup):
692696
[[3, 4]]])
693697
694698
"""
699+
if not overrides.ARRAY_FUNCTION_ENABLED:
700+
# raise warning if necessary
701+
_arrays_for_stack_dispatcher(tup, stacklevel=2)
702+
695703
return _nx.concatenate([atleast_3d(_m) for _m in tup], 2)
696704

697705

numpy/lib/ufunclike.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
__all__ = ['fix', 'isneginf', 'isposinf']
99

1010
import numpy.core.numeric as nx
11-
from numpy.core.overrides import array_function_dispatch
11+
from numpy.core.overrides import (
12+
array_function_dispatch, ARRAY_FUNCTION_ENABLED,
13+
)
1214
import warnings
1315
import functools
1416

@@ -43,7 +45,7 @@ def _fix_out_named_y(f):
4345
Allow the out argument to be passed as the name `y` (deprecated)
4446
4547
This decorator should only be used if _deprecate_out_named_y is used on
46-
a corresponding dispatcher fucntion.
48+
a corresponding dispatcher function.
4749
"""
4850
@functools.wraps(f)
4951
def func(x, out=None, **kwargs):
@@ -55,13 +57,23 @@ def func(x, out=None, **kwargs):
5557
return func
5658

5759

60+
def _fix_and_maybe_deprecate_out_named_y(f):
61+
"""
62+
Use the appropriate decorator, depending upon if dispatching is being used.
63+
"""
64+
if ARRAY_FUNCTION_ENABLED:
65+
return _fix_out_named_y(f)
66+
else:
67+
return _deprecate_out_named_y(f)
68+
69+
5870
@_deprecate_out_named_y
5971
def _dispatcher(x, out=None):
6072
return (x, out)
6173

6274

6375
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
64-
@_fix_out_named_y
76+
@_fix_and_maybe_deprecate_out_named_y
6577
def fix(x, out=None):
6678
"""
6779
Round to nearest integer towards zero.
@@ -108,7 +120,7 @@ def fix(x, out=None):
108120

109121

110122
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
111-
@_fix_out_named_y
123+
@_fix_and_maybe_deprecate_out_named_y
112124
def isposinf(x, out=None):
113125
"""
114126
Test element-wise for positive infinity, return result as bool array.
@@ -177,7 +189,7 @@ def isposinf(x, out=None):
177189

178190

179191
@array_function_dispatch(_dispatcher, verify=False, module='numpy')
180-
@_fix_out_named_y
192+
@_fix_and_maybe_deprecate_out_named_y
181193
def isneginf(x, out=None):
182194
"""
183195
Test element-wise for negative infinity, return result as bool array.

numpy/testing/tests/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
1818
tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
1919
)
20+
from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
2021

2122

2223
class _GenericTest(object):
@@ -179,6 +180,8 @@ def __ne__(self, other):
179180
self._test_not_equal(a, b)
180181
self._test_not_equal(b, a)
181182

183+
@pytest.mark.skipif(
184+
not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
182185
def test_subclass_that_does_not_implement_npall(self):
183186
class MyArray(np.ndarray):
184187
def __array_function__(self, *args, **kwargs):

0 commit comments

Comments
 (0)
0