8000 MNT Replace pytest.warns(None) in test_neighbors (#23142) · thomasjpfan/scikit-learn@e61040c · GitHub
[go: up one dir, main page]

Skip to content

Commit e61040c

Browse files
authored
MNT Replace pytest.warns(None) in test_neighbors (scikit-learn#23142)
1 parent 5817dce commit e61040c

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import product
2+
from contextlib import nullcontext
23

34
import pytest
45
import re
@@ -1529,15 +1530,14 @@ def test_neighbors_metrics(
15291530
neigh.fit(X_train)
15301531

15311532
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
1532-
ExceptionToAssert = None
15331533
if (
15341534
metric == "wminkowski"
15351535
and algorithm == "brute"
15361536
and sp_version >= parse_version("1.6.0")
15371537
):
1538-
ExceptionToAssert = FutureWarning
1539-
1540-
with pytest.warns(ExceptionToAssert):
1538+
with pytest.warns(FutureWarning):
1539+
results[algorithm] = neigh.kneighbors(X_test, return_distance=True)
1540+
else:
15411541
results[algorithm] = neigh.kneighbors(X_test, return_distance=True)
15421542

15431543
brute_dst, brute_idx = results["brute"]
@@ -1576,14 +1576,14 @@ def test_kneighbors_brute_backend(
15761576
metric_params_list = _generate_test_params_for(metric, n_features)
15771577

15781578
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
1579-
ExceptionToAssert = None
1579+
warn_context_manager = nullcontext()
15801580
if metric == "wminkowski" and sp_version >= parse_version("1.6.0"):
15811581
if global_dtype == np.float64:
15821582
# Warning from sklearn.metrics._dist_metrics.WMinkowskiDistance
1583-
ExceptionToAssert = FutureWarning
1583+
warn_context_manager = pytest.warns(FutureWarning)
15841584
if global_dtype == np.float32:
15851585
# Warning from Scipy
1586-
ExceptionToAssert = DeprecationWarning
1586+
warn_context_manager = pytest.warns(DeprecationWarning)
15871587

15881588
for metric_params in metric_params_list:
15891589
p = metric_params.pop("p", 2)
@@ -1597,7 +1597,7 @@ def test_kneighbors_brute_backend(
15971597
)
15981598

15991599
neigh.fit(X_train)
1600-
with pytest.warns(ExceptionToAssert):
1600+
with warn_context_manager:
16011601
with config_context(enable_cython_pairwise_dist=False):
16021602
# Use the legacy backend for brute
16031603
legacy_brute_dst, legacy_brute_idx = neigh.kneighbors(
@@ -2103,9 +2103,9 @@ def test_radius_neighbors_brute_backend(
21032103
metric_params_list = _generate_test_params_for(metric, n_features)
21042104

21052105
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
2106-
ExceptionToAssert = None
2106+
warn_context_manager = nullcontext()
21072107
if metric == "wminkowski" and sp_version >= parse_version("1.6.0"):
2108-
ExceptionToAssert = FutureWarning
2108+
warn_context_manager = pytest.warns(FutureWarning)
21092109

21102110
for metric_params in metric_params_list:
21112111
p = metric_params.pop("p", 2)
@@ -2119,7 +2119,7 @@ def test_radius_neighbors_brute_backend(
21192119
)
21202120

21212121
neigh.fit(X_train)
2122-
with pytest.warns(ExceptionToAssert):
2122+
with warn_context_manager:
21232123
with config_context(enable_cython_pairwise_dist=False):
21242124
# Use the legacy backend for brute
21252125
legacy_brute_dst, legacy_brute_idx = neigh.radius_neighbors(

0 commit comments

Comments
 (0)
0