1
1
from itertools import product
2
+ from contextlib import nullcontext
2
3
3
4
import pytest
4
5
import re
@@ -1529,15 +1530,14 @@ def test_neighbors_metrics(
1529
1530
neigh .fit (X_train )
1530
1531
1531
1532
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
1532
- ExceptionToAssert = None
1533
1533
if (
1534
1534
metric == "wminkowski"
1535
1535
and algorithm == "brute"
1536
1536
and sp_version >= parse_version ("1.6.0" )
1537
1537
):
1538
- ExceptionToAssert = FutureWarning
1539
-
1540
- with pytest . warns ( ExceptionToAssert ) :
1538
+ with pytest . warns ( FutureWarning ):
1539
+ results [ algorithm ] = neigh . kneighbors ( X_test , return_distance = True )
1540
+ else :
1541
1541
results [algorithm ] = neigh .kneighbors (X_test , return_distance = True )
1542
1542
1543
1543
brute_dst , brute_idx = results ["brute" ]
@@ -1576,14 +1576,14 @@ def test_kneighbors_brute_backend(
1576
1576
metric_params_list = _generate_test_params_for (metric , n_features )
1577
1577
1578
1578
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
1579
- ExceptionToAssert = None
1579
+ warn_context_manager = nullcontext ()
1580
1580
if metric == "wminkowski" and sp_version >= parse_version ("1.6.0" ):
1581
1581
if global_dtype == np .float64 :
1582
1582
# Warning from sklearn.metrics._dist_metrics.WMinkowskiDistance
1583
- ExceptionToAssert = FutureWarning
1583
+ warn_context_manager = pytest . warns ( FutureWarning )
1584
1584
if global_dtype == np .float32 :
1585
1585
# Warning from Scipy
1586
- ExceptionToAssert = DeprecationWarning
1586
+ warn_context_manager = pytest . warns ( DeprecationWarning )
1587
1587
1588
1588
for metric_params in metric_params_list :
1589
1589
p = metric_params .pop ("p" , 2 )
@@ -1597,7 +1597,7 @@ def test_kneighbors_brute_backend(
1597
1597
)
1598
1598
1599
1599
neigh .fit (X_train )
1600
- with pytest . warns ( ExceptionToAssert ) :
1600
+ with warn_context_manager :
1601
1601
with config_context (enable_cython_pairwise_dist = False ):
1602
1602
# Use the legacy backend for brute
1603
1603
legacy_brute_dst , legacy_brute_idx = neigh .kneighbors (
@@ -2103,9 +2103,9 @@ def test_radius_neighbors_brute_backend(
2103
2103
metric_params_list = _generate_test_params_for (metric , n_features )
2104
2104
2105
2105
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
2106
- ExceptionToAssert = None
2106
+ warn_context_manager = nullcontext ()
2107
2107
if metric == "wminkowski" and sp_version >= parse_version ("1.6.0" ):
2108
- ExceptionToAssert = FutureWarning
2108
+ warn_context_manager = pytest . warns ( FutureWarning )
2109
2109
2110
2110
for metric_params in metric_params_list :
2111
2111
p = metric_params .pop ("p" , 2 )
@@ -2119,7 +2119,7 @@ def test_radius_neighbors_brute_backend(
2119
2119
)
2120
2120
2121
2121
neigh .fit (X_train )
2122
- with pytest . warns ( ExceptionToAssert ) :
2122
+ with warn_context_manager :
2123
2123
with config_context (enable_cython_pairwise_dist = False ):
2124
2124
# Use the legacy backend for brute
2125
2125
legacy_brute_dst , legacy_brute_idx = neigh .radius_neighbors (
0 commit comments