8000 ENH Allow 0 < p < 1 for Minkowski distance for `algorithm="brute"` in… · scikit-learn/scikit-learn@13b5b61 · GitHub
[go: up one dir, main page]

Skip to content

Commit 13b5b61

Browse files
RudreshVeerkharejjerphanglemaitre
authored
ENH Allow 0 < p < 1 for Minkowski distance for algorithm="brute" in NeighborsBase (#24750)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 539cd6c commit 13b5b61

File tree

3 files changed

+92
-7
lines changed

3 files changed

+92
-7
lines changed

doc/whats_new/v1.2.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,11 @@ Changelog
530530
:pr:`10468` by :user:`Ruben <icfly2>` and :pr:`22993` by
531531
:user:`Jovan Stojanovic <jovan-stojanovic>`.
532532

533+
- |Enhancement| :class:`neighbors.NeighborsBase` now accepts
534+
Minkowski semi-metric (i.e. when :math:`0 < p < 1` for
535+
`metric="minkowski"`) for `algorithm="auto"` or `algorithm="brute"`.
536+
:pr:`24750` by :user:`Rudresh Veerkhare <RudreshVeerkhare>`
537+
533538
- |Efficiency| :class:`neighbors.NearestCentroid` is faster and requires
534539
less memory as it better leverages CPUs' caches to compute predictions.
535540
:pr:`24645` by :user:`Olivier Grisel <ogrisel>`.

sklearn/neighbors/_base.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
382382
"radius": [Interval(Real, 0, None, closed="both"), None],
383383
"algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})],
384384
"leaf_size": [Interval(Integral, 1, None, closed="left")],
385-
"p": [Interval(Real, 1, None, closed="both"), None],
385+
"p": [Interval(Real, 0, None, closed="right"), None],
386386
"metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable],
387387
"metric_params": [dict, None],
388388
"n_jobs": [Integral, None],
@@ -447,12 +447,6 @@ def _check_algorithm_metric(self):
447447
SyntaxWarning,
448448
stacklevel=3,
449449
)
450-
effective_p = self.metric_params["p"]
451-
else:
452-
effective_p = self.p
453-
454-
if self.metric in ["wminkowski", "minkowski"] and effective_p < 1:
455-
raise ValueError("p must be greater or equal to one for minkowski metric")
456450

457451
def _fit(self, X, y=None):
458452
if self._get_tags()["requires_y"]:
@@ -596,6 +590,12 @@ def _fit(self, X, y=None):
596590
self._fit_method = "brute"
597591
else:
598592
if (
593+
# TODO(1.3): remove "wminkowski"
594+
self.effective_metric_ in ("wminkowski", "minkowski")
595+
and self.effective_metric_params_["p"] < 1
596+
):
597+
self._fit_method = "brute"
598+
elif (
599599
self.effective_metric_ == "minkowski"
600600
and self.effective_metric_params_.get("w") is not None
601601
):
@@ -619,6 +619,29 @@ def _fit(self, X, y=None):
619619
else:
620620
self._fit_method = "brute"
621621

622+
if (
623+
# TODO(1.3): remove "wminkowski"
624+
self.effective_metric_ in ("wminkowski", "minkowski")
625+
and self.effective_metric_params_["p"] < 1
626+
):
627+
# For 0 < p < 1 Minkowski distances aren't valid distance
628+
# metric as they do not satisfy triangular inequality:
629+
# they are semi-metrics.
630+
# algorithm="kd_tree" and algorithm="ball_tree" can't be used because
631+
# KDTree and BallTree require a proper distance metric to work properly.
632+
# Howe 8000 ver, the brute-force algorithm supports semi-metrics.
633+
if self._fit_method == "brute":
634+
warnings.warn(
635+
"Mind that for 0 < p < 1, Minkowski metrics are not distance"
636+
" metrics. Continuing the execution with `algorithm='brute'`."
637+
)
638+
else: # self._fit_method in ("kd_tree", "ball_tree")
639+
raise ValueError(
640+
f'algorithm="{self._fit_method}" does not support 0 < p < 1 for '
641+
"the Minkowski metric. To resolve this problem either "
642+
'set p >= 1 or algorithm="brute".'
643+
)
644+
622645
if self._fit_method == "ball_tree":
623646
self._tree = BallTree(
624647
X,

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,63 @@ def test_neighbors_validate_parameters(Estimator):
14981498
nbrs.predict([[]])
14991499

15001500

1501+
@pytest.mark.parametrize(
1502+
"Estimator",
1503+
[
1504+
neighbors.KNeighborsClassifier,
1505+
neighbors.RadiusNeighborsClassifier,
1506+
neighbors.KNeighborsRegressor,
1507+
neighbors.RadiusNeighborsRegressor,
1508+
],
1509+
)
1510+
@pytest.mark.parametrize("n_features", [2, 100])
1511+
@pytest.mark.parametrize("algorithm", ["auto", "brute"])
1512+
def test_neighbors_minkowski_semimetric_algo_warn(Estimator, n_features, algorithm):
1513+
"""
1514+
Validation of all classes extending NeighborsBase with
1515+
Minkowski semi-metrics (i.e. when 0 < p < 1). That proper
1516+
Warning is raised for `algorithm="auto"` and "brute".
1517+
"""
1518+
X = rng.random_sample((10, n_features))
1519+
y = np.ones(10)
1520+
1521+
model = Estimator(p=0.1, algorithm=algorithm)
1522+
msg = (
1523+
"Mind that for 0 < p < 1, Minkowski metrics are not distance"
1524+
" metrics. Continuing the execution with `algorithm='brute'`."
1525+
)
1526+
with pytest.warns(UserWarning, match=msg):
1527+
model.fit(X, y)
1528+
1529+
assert model._fit_method == "brute"
1530+
1531+
1532+
@pytest.mark.parametrize(
1533+
"Estimator",
1534+
[
1535+
neighbors.KNeighborsClassifier,
1536+
neighbors.RadiusNeighborsClassifier,
1537+
neighbors.KNeighborsRegressor,
1538+
neighbors.RadiusNeighborsRegressor,
1539+
],
1540+
)
1541+
@pytest.mark.parametrize("n_features", [2, 100])
1542+
@pytest.mark.parametrize("algorithm", ["kd_tree", "ball_tree"])
1543+
def test_neighbors_minkowski_semimetric_algo_error(Estimator, n_features, algorithm):
1544+
"""Check that we raise a proper error if `algorithm!='brute'` and `p<1`."""
1545+
X = rng.random_sample((10, 2))
1546+
y = np.ones(10)
1547+
1548+
model = Estimator(algorithm=algorithm, p=0.1)
1549+
msg = (
1550+
f'algorithm="{algorithm}" does not support 0 < p < 1 for '
1551+
"the Minkowski metric. To resolve this problem either "
1552+
'set p >= 1 or algorithm="brute".'
1553+
)
1554+
with pytest.raises(ValueError, match=msg):
1555+
model.fit(X, y)
1556+
1557+
15011558
# TODO: remove when NearestNeighbors methods uses parameter validation mechanism
15021559
def test_nearest_neighbors_validate_params():
15031560
"""Validate parameter of NearestNeighbors."""

0 commit comments

Comments
 (0)
0