8000 API Deprecate metrics other than euclidean and manhattan for NearestCentroid by Valentin-Laurent · Pull Request #24083 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
Merged
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,10 @@ Changelog
when `n_neighbors` is large and `algorithm="brute"` with non Euclidean metrics.
:pr:`24076` by :user:`Meekail Zain <micky774>`, :user:`Julien Jerphanion <jjerphan>`.

- |API| The support for metrics other than `euclidean` and `manhattan` and for
callables in :class:`neighbors.NearestNeighbors` is deprecated and will be removed in
version 1.5. :pr:`24083` by :user:`Valentin Laurent <Valentin-Laurent>`.

:mod:`sklearn.neural_network`
.............................

Expand Down
26 changes: 24 additions & 2 deletions sklearn/neighbors/_nearest_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
If the `"manhattan"` metric is provided, this centroid is the median
and for all other metrics, the centroid is now set to be the mean.

.. deprecated:: 1.3
Support for metrics other than `euclidean` and `manhattan` and for
callables was deprecated in version 1.3 and will be removed in
version 1.5.

.. versionchanged:: 0.19
`metric='precomputed'` was deprecated and now raises an error

Expand Down Expand Up @@ -101,10 +106,12 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
[1]
"""

_valid_metrics = set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}

_parameter_constraints: dict = {
"metric": [
StrOptions(
set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}
_valid_metrics, deprecated=_valid_metrics - {"manhattan", "euclidean"}
),
callable,
],
Expand Down Expand Up @@ -134,6 +141,20 @@ def fit(self, X, y):
Fitted estimator.
"""
self._validate_params()

if isinstance(self.metric, str) and self.metric not in (
"manhattan",
"euclidean",
):
warnings.warn(
(
"Support for distance metrics other than euclidean and "
"manhattan and for callables was deprecated in version "
"1.3 and will be removed in version 1.5."
),
FutureWarning,
)

# If X is sparse and the metric is "manhattan", store it in a csc
# format is easier to calculate the median.
if self.metric == "manhattan":
Expand Down Expand Up @@ -167,14 +188,14 @@ def fit(self, X, y):
if is_X_sparse:
center_mask = np.where(center_mask)[0]

# XXX: Update other averaging methods according to the metrics.
if self.metric == "manhattan":
# NumPy does not calculate median of sparse matrices.
if not is_X_sparse:
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
else:
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
else:
# TODO(1.5) remove warning when metric is only manhattan or euclidean
if self.metric != "euclidean":
warnings.warn(
"Averaging for metrics other than "
Expand Down Expand Up @@ -209,6 +230,7 @@ def fit(self, X, y):
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd
return self

# TODO(1.5) remove note about precomputed metric
def predict(self, X):
"""Perform classification on an array of test vectors `X`.

Expand Down
18 changes: 18 additions & 0 deletions sklearn/neighbors/tests/test_nearest_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def test_classification_toy():
assert_array_equal(clf.predict(T_csr.tolil()), true_result)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris():
# Check consistency on dataset iris.
for metric in ("euclidean", "cosine"):
Expand All @@ -61,6 +63,8 @@ def test_iris():
assert score > 0.9, "Failed with score = " + str(score)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris_shrinkage():
# Check consistency on dataset iris, when using shrinkage.
for metric in ("euclidean", "cosine"):
Expand Down Expand Up @@ -142,6 +146,20 @@ def test_manhattan_metric():
assert_array_equal(dense_centroid, [[-1, -1], [1, 1]])


# TODO(1.5): remove this test
@pytest.mark.parametrize(
"metric", sorted(list(NearestCentroid._valid_metrics - {"manhattan", "euclidean"}))
)
def test_deprecated_distance_metric_supports(metric):
# Check that a warning is raised for all deprecated distance metric supports
clf = NearestCentroid(metric=metric)
with pytest.warns(
FutureWarning,
match="Support for distance metrics other than euclidean and manhattan",
):
clf.fit(X, y)


def test_features_zero_var():
# Test that features with 0 variance throw error

Expand Down
0