8000 ENH Allow 0 < p < 1 for Minkowski distance for `algorithm="brute"` in `NeighborsBase` by RudreshVeerkhare · Pull Request #24750 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Allow 0 < p < 1 for Minkowski distance for algorithm="brute" in NeighborsBase #24750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
630e734
modified constraints to allow p < 1
RudreshVeerkhare Oct 24, 2022
eb991dc
Merge remote-tracking branch 'upstream/main' into neighbors_minkowski…
RudreshVeerkhare Oct 24, 2022
cdf7e6a
Added feature to conditionally allow minkowski with p < 1
RudreshVeerkhare Oct 24, 2022
40a1932
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Oct 24, 2022
81b8fa0
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Oct 25, 2022
956f470
for lgorithm=auto set _fit_method=brute when p < 1
RudreshVeerkhare Oct 25, 2022
8cd4007
black formatting
RudreshVeerkhare Oct 25, 2022
f38bb21
test added for validation of NeighborsBase for minkowski with p < 1
8000 RudreshVeerkhare Oct 25, 2022
d5d5dea
modified error msg to include suggestion
RudreshVeerkhare Oct 25, 2022
06c1634
Update sklearn/neighbors/_base.py
RudreshVeerkhare Oct 26, 2022
f88f424
Apply suggestions from code review
RudreshVeerkhare Oct 26, 2022
f4f48b5
balck formatting
RudreshVeerkhare Oct 26, 2022
a8a7c9d
modified tests accoring to suggestions
RudreshVeerkhare Oct 26, 2022
24e6dc8
added enhancement changelog to v1.2
RudreshVeerkhare Oct 26, 2022
6eefd85
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Oct 26, 2022
548fe87
Update doc/whats_new/v1.2.rst
RudreshVeerkhare Oct 26, 2022
cfcfc0d
Update sklearn/neighbors/_base.py
RudreshVeerkhare Oct 26, 2022
0d1b5c6
modified _parameter_constraints to exclude p=0
RudreshVeerkhare Oct 26, 2022
e22e811
added test to validate minkowski with p=0
RudreshVeerkhare Oct 26, 2022
a06addc
Update sklearn/neighbors/_base.py
RudreshVeerkhare Oct 26, 2022
160b842
Update sklearn/neighbors/_base.py
RudreshVeerkhare Oct 26, 2022
a7a27e4
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Oct 26, 2022
49f2978
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Oct 29, 2022
4d57b15
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Nov 1, 2022
7eca748
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Nov 1, 2022
9044c3f
Merge branch 'scikit-learn:main' into neighbors_minkowski_calculation…
RudreshVeerkhare Nov 2, 2022
cc47bac
Update sklearn/neighbors/_base.py
RudreshVeerkhare Nov 3, 2022
65b033d
Apply suggestions from code review
RudreshVeerkhare Nov 3, 2022
9151b1c
Update sklearn/neighbors/_base.py
RudreshVeerkhare Nov 3, 2022
ab9b0da
Merge branch 'main' into neighbors_minkowski_calculation_exception
RudreshVeerkhare Nov 3, 2022
90d5db5
modified error and warning messages according to review
RudreshVeerkhare Nov 5, 2022
670c8d3
rearranged tests to use parameterization
RudreshVeerkhare Nov 5, 2022
dfd6ada
Update sklearn/neighbors/_base.py
RudreshVeerkhare Nov 5, 2022
082582b
Merge branch 'main' into neighbors_minkowski_calculation_exception
RudreshVeerkhare Nov 5, 2022
1e9b9f4
spelling fix
RudreshVeerkhare Nov 5, 2022
e42c5c2
Merge branch 'main' into neighbors_minkowski_calculation_exception
RudreshVeerkhare Nov 5, 2022
0b66385
Merge branch 'main' into neighbors_minkowski_calculation_exception
RudreshVeerkhare Nov 7, 2022
92e6200
Apply suggestions from code review
glemaitre Nov 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,11 @@ Changelog
:pr:`10468` by :user:`Ruben <icfly2>` and :pr:`22993` by
:user:`Jovan Stojanovic <jovan-stojanovic>`.

- |Enhancement| :class:`neighbors.NeighborsBase` now accepts
Minkowski semi-metric (i.e. when :math:`0 < p < 1` for
`metric="minkowski"`) for `algorithm="auto"` or `algorithm="brute"`.
:pr:`24750` by :user:`Rudresh Veerkhare <RudreshVeerkhare>`

- |Efficiency| :class:`neighbors.NearestCentroid` is faster and requires
less memory as it better leverages CPUs' caches to compute predictions.
:pr:`24645` by :user:`Olivier Grisel <ogrisel>`.
Expand Down
37 changes: 30 additions & 7 deletions sklearn/neighbors/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ class NeighborsBase(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
"radius": [Interval(Real, 0, None, closed="both"), None],
"algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})],
"leaf_size": [Interval(Integral, 1, None, closed="left")],
"p": [Interval(Real, 1, None, closed="both"), None],
"p": [Interval(Real, 0, None, closed="right"), None],
"metric": [StrOptions(set(itertools.chain(*VALID_METRICS.values()))), callable],
"metric_params": [dict, None],
"n_jobs": [Integral, None],
Expand Down Expand Up @@ -447,12 +447,6 @@ def _check_algorithm_metric(self):
SyntaxWarning,
stacklevel=3,
)
effective_p = self.metric_params["p"]
else:
effective_p = self.p

if self.metric in ["wminkowski", "minkowski"] and effective_p < 1:
raise ValueError("p must be greater or equal to one for minkowski metric")

def _fit(self, X, y=None):
if self._get_tags()["requires_y"]:
Expand Down Expand Up @@ -596,6 +590,12 @@ def _fit(self, X, y=None):
self._fit_method = "brute"
else:
if (
# TODO(1.3): remove "wminkowski"
self.effective_metric_ in ("wminkowski", "minkowski")
and self.effective_metric_params_["p"] < 1
):
self._fit_method = "brute"
elif (
self.effective_metric_ == "minkowski"
and self.effective_metric_params_.get("w") is not None
):
Expand All @@ -619,6 +619,29 @@ def _fit(self, X, y=None):
else:
self._fit_method = "brute"

if (
# TODO(1.3): remove "wminkowski"
self.effective_metric_ in ("wminkowski", "minkowski")
and self.effective_metric_params_["p"] < 1
):
# For 0 < p < 1 Minkowski distances aren't valid distance
# metric as they do not satisfy triangular inequality:
# they are semi-metrics.
# algorithm="kd_tree" and algorithm="ball_tree" can't be used because
# KDTree and BallTree require a proper distance metric to work properly.
# However, the brute-force algorithm supports semi-metrics.
if self._fit_method == "brute":
warnings.warn(
"Mind that for 0 < p < 1, Minkowski metrics are not distance"
" metrics. Continuing the execution with `algorithm='brute'`."
)
else: # self._fit_method in ("kd_tree", "ball_tree")
raise ValueError(
f'algorithm="{self._fit_method}" does not support 0 < p < 1 for '
"the Minkowski metric. To resolve this problem either "
'set p >= 1 or algorithm="brute".'
)

if self._fit_method == "ball_tree":
self._tree = BallTree(
X,
Expand Down
57 changes: 57 additions & 0 deletions sklearn/neighbors/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,63 @@ def test_neighbors_validate_parameters(Estimator):
nbrs.predict([[]])


@pytest.mark.parametrize(
"Estimator",
[
neighbors.KNeighborsClassifier,
neighbors.RadiusNeighborsClassifier,
neighbors.KNeighborsRegressor,
neighbors.RadiusNeighborsRegressor,
],
)
@pytest.mark.parametrize("n_features", [2, 100])
@pytest.mark.parametrize("algorithm", ["auto", "brute"])
def test_neighbors_minkowski_semimetric_algo_warn(Estimator, n_features, algorithm):
"""
Validation of all classes extending NeighborsBase with
Minkowski semi-metrics (i.e. when 0 < p < 1). That proper
Warning is raised for `algorithm="auto"` and "brute".
"""
X = rng.random_sample((10, n_features))
y = np.ones(10)

model = Estimator(p=0.1, algorithm=algorithm)
msg = (
"Mind that for 0 < p < 1, Minkowski metrics are not distance"
" metrics. Continuing the execution with `algorithm='brute'`."
)
with pytest.warns(UserWarning, match=msg):
model.fit(X, y)

assert model._fit_method == "brute"


@pytest.mark.parametrize(
"Estimator",
[
neighbors.KNeighborsClassifier,
neighbors.RadiusNeighborsClassifier,
neighbors.KNeighborsRegressor,
neighbors.RadiusNeighborsRegressor,
],
)
@pytest.mark.parametrize("n_features", [2, 100])
@pytest.mark.parametrize("algorithm", ["kd_tree", "ball_tree"])
def test_neighbors_minkowski_semimetric_algo_error(Estimator, n_features, algorithm):
"""Check that we raise a proper error if `algorithm!='brute'` and `p<1`."""
X = rng.random_sample((10, 2))
y = np.ones(10)

model = Estimator(algorithm=algorithm, p=0.1)
msg = (
f'algorithm="{algorithm}" does not support 0 < p < 1 for '
"the Minkowski metric. To resolve this problem either "
'set p >= 1 or algorithm="brute".'
)
with pytest.raises(ValueError, match=msg):
model.fit(X, y)


# TODO: remove when NearestNeighbors methods uses parameter validation mechanism
def test_nearest_neighbors_validate_params():
"""Validate parameter of NearestNeighbors."""
Expand Down
0