8000 API Deprecate metrics other than euclidean and manhattan for NearestCentroid by Valentin-Laurent · Pull Request #24083 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

API Deprecate metrics other than euclidean and manhattan for NearestCentroid #24083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
4 changes: 4 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,10 @@ Changelog
when `n_neighbors` is large and `algorithm="brute"` with non Euclidean metrics.
:pr:`24076` by :user:`Meekail Zain <micky774>`, :user:`Julien Jerphanion <jjerphan>`.

- |API| The support for metrics other than `euclidean` and `manhattan` and for
callables in :class:`neighbors.NearestNeighbors` is deprecated and will be removed in
version 1.5. :pr:`24083` by :user:`Valentin Laurent <Valentin-Laurent>`.

:mod:`sklearn.neural_network`
............................ 8000 .

Expand Down
26 changes: 24 additions & 2 deletions sklearn/neighbors/_nearest_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
If the `"manhattan"` metric is provided, this centroid is the median
and for all other metrics, the centroid is now set to be the mean.

.. deprecated:: 1.3
Support for metrics other than `euclidean` and `manhattan` and for
callables was deprecated in version 1.3 and will be removed in
version 1.5.

.. versionchanged:: 0.19
`metric='precomputed'` was deprecated and now raises an error

Expand Down Expand Up @@ -101,10 +106,12 @@ class NearestCentroid(ClassifierMixin, BaseEstimator):
[1]
"""

_valid_metrics = set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}

_parameter_constraints: dict = {
"metric": [
StrOptions(
set(_VALID_METRICS) - {"mahalanobis", "seuclidean", "wminkowski"}
_valid_metrics, deprecated=_valid_metrics - {"manhattan", "euclidean"}
),
callable,
],
Expand Down Expand Up @@ -134,6 +141,20 @@ def fit(self, X, y):
Fitted estimator.
"""
self._validate_params()

if isinstance(self.metric, str) and self.metric not in (
"manhattan",
"euclidean",
):
warnings.warn(
(
"Support for distance metrics other than euclidean and "
"manhattan and for callables was deprecated in version "
"1.3 and will be removed in version 1.5."
),
FutureWarning,
)

# If X is sparse and the metric is "manhattan", store it in a csc
# format is easier to calculate the median.
if self.metric == "manhattan":
Expand Down Expand Up @@ -167,14 +188,14 @@ def fit(self, X, y):
if is_X_sparse:
center_mask = np.where(center_mask)[0]

# XXX: Update other averaging methods according to the metrics.
if self.metric == "manhattan":
# NumPy does not calculate median of sparse matrices.
if not is_X_sparse:
self.centroids_[cur_class] = np.median(X[center_mask], axis=0)
else:
self.centroids_[cur_class] = csc_median_axis_0(X[center_mask])
else:
# TODO(1.5) remove warning when metric is only manhattan or euclidean
if self.metric != "euclidean":
warnings.warn(
"Averaging for metrics other than "
Expand Down Expand Up @@ -209,6 +230,7 @@ def fit(self, X, y):
self.centroids_ = dataset_centroid_[np.newaxis, :] + msd
return self

# TODO(1.5) remove note about precomputed metric
def predict(self, X):
"""Perform classification on an array of test vectors `X`.

Expand Down
18 changes: 18 additions & 0 deletions sklearn/neighbors/tests/test_nearest_centroid.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def test_classification_toy():
assert_array_equal(clf.predict(T_csr.tolil()), true_result)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris():
# Check consistency on dataset iris.
for metric in ("euclidean", "cosine"):
Expand All @@ -61,6 +63,8 @@ def test_iris():
assert score > 0.9, "Failed with score = " + str(score)


# TODO(1.5): Remove filterwarnings when support for some metrics is removed
@pytest.mark.filterwarnings("ignore:Support for distance metrics:FutureWarning:sklearn")
def test_iris_shrinkage():
# Check consistency on dataset iris, when using shrinkage.
for metric in ("euclidean", "cosine"):
Expand Down Expand Up @@ -142,6 +146,20 @@ def test_manhattan_metric():
assert_array_equal(dense_centroid, [[-1, -1], [1, 1]])


# TODO(1.5): remove this test
@pytest.mark.parametrize(
"metric", sorted(list(NearestCentroid._valid_metrics - {"manhattan", "euclidean"}))
)
def test_deprecated_distance_metric_supports(metric):
# Check that a warning is raised for all deprecated distance metric supports
clf = NearestCentroid(metric=metric)
with pytest.warns(
FutureWarning,
match="Support for distance metrics other than euclidean and manhattan",
):
clf.fit(X, y)


def test_features_zero_var():
# Test that features with 0 variance throw error

Expand Down
0