8000 Merge pull request #301 from crusaderky/test_assert_equal · data-apis/array-api-extra@dc44205 · GitHub
[go: up one dir, main page]

Skip to content

Commit dc44205

Browse files
authored
Merge pull request #301 from crusaderky/test_assert_equal
TST: rework tests for `xp_assert_equal`
2 parents e4ecb82 + 11b535c commit dc44205
8000

File tree

1 file changed

+132
-139
lines changed

1 file changed

+132
-139
lines changed

tests/test_testing.py

Lines changed: 132 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Callable
2-
from contextlib import nullcontext
32
from types import ModuleType
43 8000
from typing import cast
54

@@ -21,160 +20,154 @@
2120
from array_api_extra._lib._utils._typing import Array, Device
2221
from array_api_extra.testing import lazy_xp_function
2322

24-
# mypy: disable-error-code=decorated-any
23+
# mypy: disable-error-code="decorated-any, explicit-any"
2524
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
2625

27-
param_assert_equal_close = pytest.mark.parametrize(
28-
"func",
29-
[
30-
xp_assert_equal,
31-
xp_assert_less,
32-
pytest.param(
33-
xp_assert_close,
34-
marks=pytest.mark.xfail_xp_backend(
35-
Backend.SPARSE, reason="no isdtype", strict=False
36-
),
37-
),
38-
],
39-
)
40-
4126

4227
def test_as_numpy_array(xp: ModuleType, device: Device):
4328
x = xp.asarray([1, 2, 3], device=device)
4429
y = as_numpy_array(x, xp=xp)
4530
assert isinstance(y, np.ndarray)
4631

4732

48-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
49-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
50-
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
51-
func(xp.asarray(0), xp.asarray(0))
52-
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
53-
54-
with< 6D40 /span> pytest.raises(AssertionError, match="shapes do not match"):
55-
func(xp.asarray([0]), xp.asarray([[0]]))
56-
57-
with pytest.raises(AssertionError, match="dtypes do not match"):
58-
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
59-
60-
with pytest.raises(AssertionError):
61-
func(xp.asarray([1, 2]), xp.asarray([1, 3]))
62-
63-
with pytest.raises(AssertionError, match="hello"):
64-
func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello")
65-
66-
67-
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
68-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
69-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
70-
def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
71-
with pytest.raises(AssertionError, match="namespaces do not match"):
72-
func(xp.asarray(0), np.asarray(0))
73-
with pytest.raises(TypeError, match="Unrecognized array input"):
74-
func(xp.asarray(0), 0)
75-
with pytest.raises(TypeError, match="list is not a supported array type"):
76-
func(xp.asarray([0]), [0])
77-
78-
79-
@param_assert_equal_close
80-
@pytest.mark.parametrize("check_shape", [False, True])
81-
def test_assert_close_equal_less_shape( # type: ignore[explicit-any]
82-
xp: ModuleType,
83-
func: Callable[..., None],
84-
check_shape: bool,
85-
):
86-
context = (
87-
pytest.raises(AssertionError, match="shapes do not match")
88-
if check_shape
89-
else nullcontext()
90-
)
91-
with context:
92-
# note: NaNs are handled by all 3 checks
93-
func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
94-
95-
96-
@param_assert_equal_close
97-
@pytest.mark.parametrize("check_dtype", [False, True])
98-
def test_assert_close_equal_less_dtype( # type: ignore[explicit-any]
99-
xp: ModuleType,
100-
func: Callable[..., None],
101-
check_dtype: bool,
102-
):
103-
context = (
104-
pytest.raises(AssertionError, match="dtypes do not match")
105-
if check_dtype
106-
else nullcontext()
107-
)
108-
with context:
109-
func(
110-
xp.asarray(xp.nan, dtype=xp.float32),
111-
xp.asarray(xp.nan, dtype=xp.float64),
112-
check_dtype=check_dtype,
113-
)
114-
115-
116-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
117-
@pytest.mark.parametrize("check_scalar", [False, True])
118-
def test_assert_close_equal_less_scalar( # type: ignore[explicit-any]
119-
xp: ModuleType,
120-
func: Callable[..., None],
121-
check_scalar: bool,
122-
):
123-
context = (
124-
pytest.raises(AssertionError, match="array-ness does not match")
125-
if check_scalar
< 10000 div aria-hidden="true" class="position-absolute top-0 d-flex user-select-none DiffLineTableCellParts-module__comment-indicator--eI0hb">
126-
else nullcontext()
33+
class TestAssertEqualCloseLess:
34+
pr_assert_close = pytest.param( # pyright: ignore[reportUnannotatedClassAttribute]
35+
xp_assert_close,
36+
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
12737
)
128-
with context:
129-
func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)
13038

39+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close])
40+
def test_assert_equal_close_basic(self, xp: ModuleType, func: Callable[..., None]):
41+
func(xp.asarray(0), xp.asarray(0))
42+
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
13143

132-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
133-
def test_assert_close_tolerance(xp: ModuleType):
134-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)
135-
with pytest.raises(AssertionError):
136-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01)
44+
with pytest.raises(AssertionError, match="Mismatched elements"):
45+
func(xp.asarray([1, 2]), xp.asarray([2, 1]))
13746

138-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
139-
with pytest.raises(AssertionError):
140-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
47+
with pytest.raises(AssertionError, match="hello"):
48+
func(xp.asarray([1, 2]), xp.asarray([2, 1]), err_msg="hello")
14149

50+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
51+
def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
52+
with pytest.raises(AssertionError, match="shapes do not match"):
53+
func(xp.asarray([0]), xp.asarray([[0]]))
14254

143-
def test_assert_less_basic(xp: ModuleType):
144-
xp_assert_less(xp.asarray(-1), xp.asarray(0))
145-
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
146-
with pytest.raises(AssertionError):
147-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
148-
with pytest.raises(AssertionError, match="hello"):
149-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
55+
with pytest.raises(AssertionError, match="dtypes do not match"):
56+
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
15057

151-
152-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
153-
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
154-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
155-
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
156-
"""On Dask and other lazy backends, test that a shape with NaN's or None's
157-
can be compared to a real shape.
158-
"""
159-
a = xp.asarray([1, 2])
160-
a = a[a > 1]
161-
162-
func(a, xp.asarray([2]))
163-
with pytest.raises(AssertionError):
164-
func(a, xp.asarray([2, 3]))
165-
with pytest.raises(AssertionError):
166-
func(a, xp.asarray(2))
167-
with pytest.raises(AssertionError):
168-
func(a, xp.asarray([3]))
169-
170-
# Swap actual and desired
171-
func(xp.asarray([2]), a)
172-
with pytest.raises(AssertionError):
173-
func(xp.asarray([2, 3]), a)
174-
with pytest.raises(AssertionError):
175-
func(xp.asarray(2), a)
176-
with pytest.raises(AssertionError):
177-
func(xp.asarray([3]), a)
58+
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
59+
@pytest.mark.skip_xp_backend(
60+
Backend.NUMPY_READONLY, reason="test other ns vs. numpy"
61+
)
62+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
63+
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
64+
with pytest.raises(AssertionError, match="namespaces do not match"):
65+
func(xp.asarray(0), np.asarray(0))
66+
with pytest.raises(TypeError, match="Unrecognized array input"):
67+
func(xp.asarray(0), 0)
68+
with pytest.raises(TypeError, match="list is not a supported array type"):
69+
func(xp.asarray([0]), [0])
70+
71+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
72+
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
73+
a = xp.asarray([1] if func is xp_assert_less else [2])
74+
b = xp.asarray(2)
75+
c = xp.asarray(0)
76+
d = xp.asarray([2, 2])
77+
78+
with pytest.raises(AssertionError, match="shapes do not match"):
79+
func(a, b)
80+
func(a, b, check_shape=False)
81+
with pytest.raises(AssertionError, match="Mismatched elements"):
82+
func(a, c, check_shape=False)
83+
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
84+
func(a, d, check_shape=False)
85+
86+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
87+
def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
88+
a = xp.asarray(1 if func is xp_assert_less else 2)
89+
b = xp.asarray(2, dtype=xp.int16)
90+
c = xp.asarray(0, dtype=xp.int16)
91+
92+
with pytest.raises(AssertionError, match="dtypes do not match"):
93+
func(a, b)
94+
func(a, b, check_dtype=False)
95+
with pytest.raises(AssertionError, match="Mismatched elements"):
96+
func(a, c, check_dtype=False)
97+
98+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
99+
@pytest.mark.xfail_xp_backend(
100+
Backend.SPARSE, reason="sparse [()] returns np.generic"
101+
)
102+
def test_check_scalar(
103+
self, xp: ModuleType, library: Backend, func: Callable[..., None]
104+
):
105+
a = xp.asarray(1 if func is xp_assert_less else 2)
106+
b = xp.asarray(2)[()] # Note: only makes a difference on NumPy
107+
c = xp.asarray(0)
108+
109+
func(a, b)
110+
if library.like(Backend.NUMPY):
111+
with pytest.raises(AssertionError, match="array-ness does not match"):
112+
func(a, b, check_scalar=True)
113+
else:
114+
func(a, b, check_scalar=True)
115+
with pytest.raises(AssertionError, match="Mismatched elements"):
116+
func(a, c, check_scalar=True)
117+
118+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
119+
@pytest.mark.parametrize("dtype", ["int64", "float64"])
120+
def test_assert_close_tolerance(self, dtype: str, xp: ModuleType):
121+
a = xp.asarray([100], dtype=getattr(xp, dtype))
122+
b = xp.asarray([102], dtype=getattr(xp, dtype))
123+
124+
with pytest.raises(AssertionError, match="Mismatched elements"):
125+
xp_assert_close(a, b)
126+
127+
xp_assert_close(a, b, rtol=0.03)
128+
with pytest.raises(AssertionError, match="Mismatched elements"):
129+
xp_assert_close(a, b, rtol=0.01)
130+
131+
xp_assert_close(a, b, atol=3)
132+
with pytest.raises(AssertionError, match="Mismatched elements"):
133+
xp_assert_close(a, b, atol=1)
134+
135+
def test_assert_less(self, xp: ModuleType):
136+
xp_assert_less(xp.asarray(-1), xp.asarray(0))
137+
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
138+
with pytest.raises(AssertionError, match="Mismatched elements"):
139+
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
140+
141+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
142+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
143+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
144+
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
145+
"""On Dask and other lazy backends, test that a shape with NaN's or None's
146+
can be compared to a real shape.
147+
"""
148+
# actual has shape=(None, )
149+
a = xp.asarray([1] if func is xp_assert_less else [2])
150+
a = a[a > 0]
151+
152+
func(a, xp.asarray([2]))
153+
with pytest.raises(AssertionError, match="shapes do not match"):
154+
func(a, xp.asarray(2))
155+
with pytest.raises(AssertionError, match="shapes do not match"):
156+
func(a, xp.asarray([2, 3]))
157+
with pytest.raises(AssertionError, match="Mismatched elements"):
158+
func(a, xp.asarray([0]))
159+
160+
# desired has shape=(None, )
161+
a = xp.asarray([3] if func is xp_assert_less else [2])
162+
a = a[a > 0]
163+
164+
func(xp.asarray([2]), a)
165+
with pytest.raises(AssertionError, match="shapes do not match"):
166+
func(xp.asarray(2), a)
167+
with pytest.raises(AssertionError, match="shapes do not match"):
168+
func(xp.asarray([2, 3]), a)
169+
with pytest.raises(AssertionError, match="Mismatched elements"):
170+
func(xp.asarray([4]), a)
178171

179172

180173
def good_lazy(x: Array) -> Array:

0 commit comments

Comments
 (0)
0