8000 API Deprecate metrics other than euclidean and manhattan for NearestC… · REDVM/scikit-learn@af7ef17 · GitHub
[go: up one dir, main page]

Skip to content

Commit af7ef17

Browse files
Valentin-Laurentjeremiedbbjjerphan
authored andcommitted
API Deprecate metrics other than euclidean and manhattan for NearestCentroid (scikit-learn#24083)
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 8706d3e commit af7ef17

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,10 @@ Changelog
495495
when `n_neighbors` is large and `algorithm="brute"` with non Euclidean metrics.
496496
:pr:`24076` by :user:`Meekail Zain <micky774>`, :user:`Julien Jerphanion <jjerphan>`.
497497

498+
- |API| The support for metrics other than `euclidean` and `manhattan` and for
499+
callables in :class:`neighbors.NearestNeighbors` is deprecated and will be removed in
500+
version 1.5. :pr:`24083` by :user:`Valentin Laurent <Valentin-Laurent>`.
501+
498502
:mod:`sklearn.neural_network`
499503
.............................
500504

sklearn/neighbors/_nearest_centroid.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
4747
If the `"manhattan"` metric is provided, this centroid is the median
4848
and for all other metrics, the centroid is now set to be the mean.
4949
50+
.. deprecated:: 1.3
51+
Support for metrics other than `euclidean` and `manhattan` and for
52+
callables was deprecated in version 1.3 and will be removed in
53+
version 1.5.
54+
5055
.. versionchanged:: 0.19
5156
`metric='precomputed'` was deprecated and now raises an error
5257
@@ -101,10 +106,12 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
101106
[1]
102107
"""
103108

109+
_valid_metrics = set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}
110+
104111
_parameter_constraints: dict = {
105112
"metric": [
106113
StrOptions(
107-
set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}
114+
_valid_metrics, deprecated=_valid_metrics - {"manhattan", "euclidean"}
108115
),
109116
callable,
110117
],
@@ -134,6 +141,20 @@ def fit(self, X, y):
134141
Fitted estimator.
135142
"""
136143
self._validate_params()
144+
145+
if isinstance(self.metric, str) and self.metric not in (
146+
"manhattan",
147+
"euclidean",
148+
):
149+
warnings.warn(
150+
(
151+
"Support for distance metrics other than euclidean and "
152+
"manhattan and for callables was deprecated in version "
153+
"1.3 and will be removed in version 1.5."
154+
),
155+
FutureWarning,
156+
)
157+
137158
# If X is sparse and the metric is "manhattan", store it in a csc
138159
# format is easier to calculate the median.
139160
if self.metric == "manhattan":
@@ -167,14 +188,14 @@ def fit(self, X, y):
167188
if is_X_sparse:
168189
center_mask = np.where(center_mask)[0]
169190

170-
# XXX: Update other averaging methods according to the metrics.
171191
if self.metric == "manhattan":
172192
# NumPy does not calculate median of sparse matrices.
173193
if not is_X_sparse:
174194
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
175195
else:
176196
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
177197
else:
198+
# TODO(1.5) remove warning when metric is only manhattan or euclidean
178199
if self.metric != "euclidean":
179200
warnings.warn(
180201
"Averaging for metrics other than "
@@ -209,6 +230,7 @@ def fit(self, X, y):
209230
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd
210231
return self
211232

233+
# TODO(1.5) remove note about precomputed metric
212234
def predict(self, X):
213235
"""Perform classification on an array of test vectors `X`.
214236

sklearn/neighbors/tests/test_nearest_centroid.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def test_classification_toy():
5353
assert_array_equal(clf.predict(T_csr.tolil()), true_result)
5454

5555

56+
# TODO(1.5): Remove filterwarnings when support for some metrics is removed
57+
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
5658
def test_iris():
5759
# Check consistency on dataset iris.
5860
for metric in ("euclidean", "cosine"):
@@ -61,6 +63,8 @@ def test_iris():
6163
assert score > 0.9, "Failed with score = " + str(score)
6264

6365

66+
# TODO(1.5): Remove filterwarnings when support for some metrics is removed
67+
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
6468
def test_iris_shrinkage():
6569
# Check consistency on dataset iris, when using shrinkage.
6670
for metric in ("euclidean", "cosine"):
@@ -142,6 +146,20 @@ def test_manhattan_metric():
142146
assert_array_equal(dense_centroid, [[-1, -1], [1, 1]])
143147

144148

149+
# TODO(1.5): remove this test
150+
@pytest.mark.parametrize(
151+
"metric", sorted(list(NearestCentroid._valid_metrics - {"manhattan", "euclidean"}))
152+
)
153+
def test_deprecated_distance_metric_supports(metric):
154+
# Check that a warning is raised for all deprecated distance metric supports
155+
clf = NearestCentroid(metric=metric)
156+
with pytest.warns(
157+
FutureWarning,
158+
match="Support for distance metrics other than euclidean and manhattan",
159+
):
160+
clf.fit(X, y)
161+
162+
145163
def test_features_zero_var():
146164
# Test that features with 0 variance throw error
147165

0 commit comments

Comments
 (0)
0