8000 TST Adapts wminkowski for scipy 1.6.0 (#19096) · scikit-learn/scikit-learn@ef4e95f · GitHub
[go: up one dir, main page]

Skip to content

Commit ef4e95f

Browse files
authored
TST Adapts wminkowski for scipy 1.6.0 (#19096)
1 parent e325bf7 commit ef4e95f

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

sklearn/neighbors/tests/test_dist_metrics.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,19 @@ def test_cdist(metric):
5555
keys = argdict.keys()
5656
for vals in itertools.product(*argdict.values()):
5757
kwargs = dict(zip(keys, vals))
58-
D_true = cdist(X1, X2, metric, **kwargs)
58+
if metric == "wminkowski":
59+
if sp_version >= parse_version("1.8.0"):
60+
pytest.skip("wminkowski will be removed in SciPy 1.8.0")
61+
62+
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
63+
ExceptionToAssert = None
64+
if sp_version >= parse_version("1.6.0"):
65+
ExceptionToAssert = DeprecationWarning
66+
with pytest.warns(ExceptionToAssert):
67+
D_true = cdist(X1, X2, metric, **kwargs)
68+
else:
69+
D_true = cdist(X1, X2, metric, **kwargs)
70+
5971
check_cdist(metric, kwargs, D_true)
6072

6173

@@ -83,7 +95,19 @@ def test_pdist(metric):
8395
keys = argdict.keys()
8496
for vals in itertools.product(*argdict.values()):
8597
kwargs = dict(zip(keys, vals))
86-
D_true = cdist(X1, X1, metric, **kwargs)
98+
if metric == "wminkowski":
99+
if sp_version >= parse_version("1.8.0"):
100+
pytest.skip("wminkowski will be removed in SciPy 1.8.0")
101+
102+
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
103+
ExceptionToAssert = None
104+
if sp_version >= parse_version("1.6.0"):
105+
ExceptionToAssert = DeprecationWarning
106+
with pytest.warns(ExceptionToAssert):
107+
D_true = cdist(X1, X1, metric, **kwargs)
108+
else:
109+
D_true = cdist(X1, X1, metric, **kwargs)
110+
87111
check_pdist(metric, kwargs, D_true)
88112

89113

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sklearn.utils._testing import assert_raise_message
2727
from sklearn.utils._testing import ignore_warnings
2828
from sklearn.utils.validation import check_random_state
29+
from sklearn.utils.fixes import sp_version, parse_version
2930

3031
import joblib
3132

@@ -1244,6 +1245,9 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
12441245
test = rng.rand(n_query_pts, n_features)
12451246

12461247
for metric, metric_params in metrics:
1248+
if metric == "wminkowski" and sp_version >= parse_version("1.8.0"):
1249+
# wminkowski will be removed in SciPy 1.8.0
1250+
continue
12471251
results = {}
12481252
p = metric_params.pop('p', 2)
12491253
for algorithm in algorithms:
@@ -1265,8 +1269,16 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
12651269
if metric == 'haversine' else slice(None))
12661270

12671271
neigh.fit(X[:, feature_sl])
1268-
results[algorithm] = neigh.kneighbors(test[:, feature_sl],
1269-
return_distance=True)
1272+
1273+
# wminkoski is deprecated in SciPy 1.6.0 and removed in 1.8.0
1274+
ExceptionToAssert = None
1275+
if (metric == "wminkowski" and algorithm == 'brute'
1276+
and sp_version >= parse_version("1.6.0")):
1277+
ExceptionToAssert = DeprecationWarning
1278+
1279+
with pytest.warns(ExceptionToAssert):
1280+
results[algorithm] = neigh.kneighbors(test[:, feature_sl],
1281+
return_distance=True)
12701282

12711283
assert_array_almost_equal(results['brute'][0], results['ball_tree'][0])
12721284
assert_array_almost_equal(results['brute'][1], results['ball_tree'][1])

0 commit comments

Comments
 (0)
0