ENH Allow 0<p<1 for Minkowski metric by Shreesha3112 · Pull Request #26760 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Allow 0<p<1 for Minkowski metric #26760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2fcc2e2
fix Allow 0<p<1 for minkowski metric
Jul 4, 2023
fa69fcf
added change log entry
Jul 5, 2023
de7dc52
resolve merge conflict in whatsnew 1.4
Jul 6, 2023
c868d44
added scipy version check for cdist
Jul 6, 2023
1ad3f0c
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 6, 2023
a728d1e
update scipy check in test_dist_metrics
Jul 6, 2023
ef9623f
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 6, 2023
d96fc3c
removed unnecsary return in test
Jul 6, 2023
4a1e039
updated whats_new 1.4
Jul 7, 2023
9d3bc07
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 7, 2023
e9b5039
fixed suggestions from Micky774
Jul 8, 2023
8e2deb7
updated doc of Minkowski Distance
Jul 8, 2023
bf3e12d
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 8, 2023
ab73914
update changelog
Jul 8, 2023
72c7879
remved mistakenly added
Jul 8, 2023
4c3fc42
remved mistakenly added
Jul 8, 2023
0326062
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 12, 2023
a292677
update whatsnew 1.4
Jul 12, 2023
a066803
Apply suggestions from code review
Shreesha3112 Jul 12, 2023
56669e5
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 12, 2023
08f095c
update whatsnew 1.4
Jul 12, 2023
5eb1753
fix title underline too short error
Jul 13, 2023
7d9edbc
added non regresssion test for p<1 in test_neighbors
Jul 14, 2023
81b28b6
fix merge conflict in whats_new 1.4
Jul 14, 2023
e1eea90
Merge remote-tracking branch 'upstream/main' into fix/MinkowskiDistan…
Jul 14, 2023
3f79c3b
Update doc/whats_new/v1.4.rst
Shreesha3112 Jul 20, 2023
6750028
Update v1.4.rst
jeremiedbb Jul 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ Changelog
object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin
Jalali`_.

:mod:`sklearn.neighbors`
........................

- |Fix| Neighbors-based estimators now work correctly when `metric="minkowski"` and the
metric parameter `p` is in the range `0 < p < 1`, regardless of the `dtype` of `X`.
:pr:`26760` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.

:mod:`sklearn.tree`
...................

Expand Down
18 changes: 13 additions & 5 deletions sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -1271,19 +1271,27 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):

Parameters
----------
p : int
p : float
The order of the p-norm of the difference (see above).

.. versionchanged:: 1.4.0
Minkowski distance allows `p` to be `0<p<1`.


w : (N,) array-like (optional)
The weight vector.

Minkowski Distance requires p >= 1 and finite. For p = infinity,
use ChebyshevDistance.
Minkowski Distance requires p > 0 and finite.
When :math:`p \in (0,1)`, it isn't a true metric but is permissible when
the triangle inequality isn't necessary.
For p = infinity, use ChebyshevDistance.
Note that for p=1, ManhattanDistance is more efficient, and for
p=2, EuclideanDistance is more efficient.

"""
def __init__(self, p, w=None):
if p < 1:
raise ValueError("p must be greater than 1")
if p <= 0:
raise ValueError("p must be greater than 0")
elif np.isinf(p):
raise ValueError("MinkowskiDistance requires finite p. "
"For p=inf, use ChebyshevDistance.")
Expand Down
25 changes: 22 additions & 3 deletions sklearn/metrics/tests/test_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_allclose, create_memmap_backed_data
from sklearn.utils.fixes import parse_version, sp_version


def dist_func(x1, x2, p):
Expand Down Expand Up @@ -42,18 +43,17 @@ def dist_func(x1, x2, p):
V = rng.random_sample((d, d))
VI = np.dot(V, V.T)


METRICS_DEFAULT_PARAMS = [
("euclidean", {}),
("cityblock", {}),
("minkowski", dict(p=(1, 1.5, 2, 3))),
("minkowski", dict(p=(0.5, 1, 1.5, 2, 3))),
("chebyshev", {}),
("seuclidean", dict(V=(rng.random_sample(d),))),
("mahalanobis", dict(VI=(VI,))),
("hamming", {}),
("canberra", {}),
("braycurtis", {}),
("minkowski", dict(p=(1, 1.5, 3), w=(rng.random_sample(d),))),
("minkowski", dict(p=(0.5, 1, 1.5, 3), w=(rng.random_sample(d),))),
]


Expand All @@ -76,6 +76,13 @@ def test_cdist(metric_param_grid, X, Y):
# with scipy
rtol_dict = {"rtol": 1e-6}

# TODO: Remove when scipy minimum version >= 1.7.0
# scipy supports 0<p<1 for minkowski metric >= 1.7.0
if metric == "minkowski":
p = kwargs["p"]
if sp_version < parse_version("1.7.0") and p < 1:
pytest.skip("scipy does not support 0<p<1 for minkowski metric < 1.7.0")

D_scipy_cdist = cdist(X, Y, metric, **kwargs)

dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
Expand Down Expand Up @@ -150,6 +157,12 @@ def test_pdist(metric_param_grid, X):
# with scipy
rtol_dict = {"rtol": 1e-6}

# TODO: Remove when scipy minimum version >= 1.7.0
# scipy supports 0<p<1 for minkowski metric >= 1.7.0
if metric == "minkowski":
p = kwargs["p"]
if sp_version < parse_version("1.7.0") and p < 1:
pytest.skip("scipy does not support 0<p<1 for minkowski metric < 1.7.0")
D_scipy_pdist = cdist(X, X, metric, **kwargs)

dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
Expand Down Expand Up @@ -397,3 +410,9 @@ def test_get_metric_bad_dtype():
msg = r"Unexpected dtype .* provided. Please select a dtype from"
with pytest.raises(ValueError, match=msg):
DistanceMetric.get_metric("manhattan", dtype)


def test_minkowski_metric_validate_bad_p_parameter():
    """Check that requesting the "minkowski" metric with `p=0` raises ValueError."""
    # `p=0` sits on the rejected side of the `p <= 0` validation boundary.
    with pytest.raises(ValueError, match="p must be greater than 0"):
        DistanceMetric.get_metric("minkowski", p=0)
19 changes: 19 additions & 0 deletions sklearn/neighbors/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2207,3 +2207,22 @@ def test_predict_dataframe():

knn = neighbors.KNeighborsClassifier(n_neighbors=2).fit(X, y)
knn.predict(X)


def test_nearest_neighbours_works_with_p_less_than_1():
    """Check brute-force NearestNeighbors queries with `metric_params={"p": 0.5}`.

    Fits three 2D points and verifies that both `radius_neighbors` (radius 4)
    and `kneighbors` (3 neighbors), queried with the first point, return the
    indices `[0, 1, 2]`.

    Non-regression test for issue #26548
    """
    X = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
    # Both queries use the first sample, reshaped to a single-row 2D array.
    query = X[0].reshape(1, -1)

    neigh = neighbors.NearestNeighbors(
        n_neighbors=3, algorithm="brute", metric_params={"p": 0.5}
    )
    neigh.fit(X)

    within_radius = neigh.radius_neighbors(query, radius=4, return_distance=False)
    assert_allclose(within_radius[0], [0, 1, 2])

    nearest = neigh.kneighbors(query, return_distance=False)
    assert_allclose(nearest[0], [0, 1, 2])
0