FIX Revert `{Ball,KD}Tree.valid_metrics` to public class attributes (… · punndcoder28/scikit-learn@6c8dd3f · GitHub

Commit 6c8dd3f

jjerphan authored and ogrisel committed
FIX Revert {Ball,KD}Tree.valid_metrics to public class attributes (scikit-learn#26754)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 18581b6 commit 6c8dd3f
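
For quick orientation before the per-file diffs, here is a minimal usage sketch of the behaviour this commit restores. It assumes a build with this change applied; the membership checks rely on the metric lists shown in the doc/modules/neighbors.rst hunk below, and the printed values are illustrative only.

    # `valid_metrics` is a public class attribute again: a plain list of metric names.
    from sklearn.neighbors import BallTree, KDTree

    print("euclidean" in KDTree.valid_metrics)    # True
    print("haversine" in BallTree.valid_metrics)  # True

    # In scikit-learn 1.3.0 it was a classmethod and had to be called instead:
    # KDTree.valid_metrics()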

File tree

8 files changed, +39 −34 lines


doc/modules/neighbors.rst

Lines changed: 3 additions & 3 deletions
@@ -139,9 +139,9 @@ including specification of query strategies, distance metrics, etc. For a list
 of valid metrics use :meth:`KDTree.valid_metrics` and :meth:`BallTree.valid_metrics`:
 
     >>> from sklearn.neighbors import KDTree, BallTree
-    >>> KDTree.valid_metrics()
+    >>> KDTree.valid_metrics
     ['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity']
-    >>> BallTree.valid_metrics()
+    >>> BallTree.valid_metrics
     ['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity', 'seuclidean', 'mahalanobis', 'hamming', 'canberra', 'braycurtis', 'jaccard', 'dice', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc']
 
 .. _classification:
@@ -480,7 +480,7 @@ A list of valid metrics for any of the above algorithms can be obtained by using
 ``valid_metric`` attribute. For example, valid metrics for ``KDTree`` can be generated by:
 
     >>> from sklearn.neighbors import KDTree
-    >>> print(sorted(KDTree.valid_metrics()))
+    >>> print(sorted(KDTree.valid_metrics))
     ['chebyshev', 'cityblock', 'euclidean', 'infinity', 'l1', 'l2', 'manhattan', 'minkowski', 'p']

doc/whats_new/v1.3.rst

Lines changed: 17 additions & 0 deletions
@@ -2,6 +2,23 @@
 
 .. currentmodule:: sklearn
 
+.. _changes_1_3_1:
+
+Version 1.3.1
+=============
+
+**TODO: set date**
+
+Changelog
+---------
+
+:mod:`sklearn.neighbors`
+........................
+
+- |Fix| Reintroduce :attr:`sklearn.neighbors.BallTree.valid_metrics` and
+  :attr:`sklearn.neighbors.KDTree.valid_metrics` as public class attributes.
+  :pr:`26754` by :user:`Julien Jerphanion <jjerphan>`.
+
 .. _changes_1_3:
 
 Version 1.3.0

sklearn/cluster/_hdbscan/hdbscan.py

Lines changed: 4 additions & 6 deletions
@@ -56,7 +56,7 @@
 from ._reachability import mutual_reachability_graph
 from ._tree import HIERARCHY_dtype, labelling_at_cut, tree_to_labels
 
-FAST_METRICS = set(KDTree.valid_metrics() + BallTree.valid_metrics())
+FAST_METRICS = set(KDTree.valid_metrics + BallTree.valid_metrics)
 
 # Encodings are arbitrary but must be strictly negative.
 # The current encodings are chosen as extensions to the -1 noise label.
@@ -768,14 +768,12 @@ def fit(self, X, y=None):
             n_jobs=self.n_jobs,
             **self._metric_params,
         )
-        if self.algorithm == "kdtree" and self.metric not in KDTree.valid_metrics():
+        if self.algorithm == "kdtree" and self.metric not in KDTree.valid_metrics:
             raise ValueError(
                 f"{self.metric} is not a valid metric for a KDTree-based algorithm."
                 " Please select a different metric."
             )
-        elif (
-            self.algorithm == "balltree" and self.metric not in BallTree.valid_metrics()
-        ):
+        elif self.algorithm == "balltree" and self.metric not in BallTree.valid_metrics:
             raise ValueError(
                 f"{self.metric} is not a valid metric for a BallTree-based algorithm."
                 " Please select a different metric."
@@ -805,7 +803,7 @@ def fit(self, X, y=None):
             # We can't do much with sparse matrices ...
            mst_func = _hdbscan_brute
            kwargs["copy"] = self.copy
-        elif self.metric in KDTree.valid_metrics():
+        elif self.metric in KDTree.valid_metrics:
            # TODO: Benchmark KD vs Ball Tree efficiency
            mst_func = _hdbscan_prims
            kwargs["algo"] = "kd_tree"

sklearn/cluster/tests/test_hdbscan.py

Lines changed: 2 additions & 2 deletions
@@ -165,7 +165,7 @@ def test_hdbscan_algorithms(algo, metric):
         metric_params=metric_params,
     )
 
-    if metric not in ALGOS_TREES[algo].valid_metrics():
+    if metric not in ALGOS_TREES[algo].valid_metrics:
         with pytest.raises(ValueError):
             hdb.fit(X)
     elif metric == "wminkowski":
@@ -424,7 +424,7 @@ def test_hdbscan_tree_invalid_metric():
 
     # The set of valid metrics for KDTree at the time of writing this test is a
     # strict subset of those supported in BallTree
-    metrics_not_kd = list(set(BallTree.valid_metrics()) - set(KDTree.valid_metrics()))
+    metrics_not_kd = list(set(BallTree.valid_metrics) - set(KDTree.valid_metrics))
     if len(metrics_not_kd) > 0:
         with pytest.raises(ValueError, match=msg):
             HDBSCAN(algorithm="kdtree", metric=metrics_not_kd[0]).fit(X)

sklearn/neighbors/_base.py

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@
     SCIPY_METRICS += ["matching"]
 
 VALID_METRICS = dict(
-    ball_tree=BallTree._valid_metrics,
-    kd_tree=KDTree._valid_metrics,
+    ball_tree=BallTree.valid_metrics,
+    kd_tree=KDTree.valid_metrics,
     # The following list comes from the
     # sklearn.metrics.pairwise doc string
     brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)),

sklearn/neighbors/_binary_tree.pxi

Lines changed: 6 additions & 16 deletions
@@ -236,9 +236,10 @@ metric : str or DistanceMetric64 object, default='minkowski'
     Metric to use for distance computation. Default is "minkowski", which
     results in the standard Euclidean distance when p = 2.
     A list of valid metrics for {BinaryTree} is given by
-    :meth:`{BinaryTree}.valid_metrics`.
+    :attr:`{BinaryTree}.valid_metrics`.
     See the documentation of `scipy.spatial.distance
-    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
+    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
+    the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
     more information on any distance metric.
 
     Additional keywords are passed to the distance metric class.
@@ -249,6 +250,8 @@ Attributes
 ----------
 data : memory view
     The training data
+valid_metrics: list of str
+    List of valid distance metrics.
 
 Examples
 --------
@@ -792,7 +795,7 @@ cdef class BinaryTree:
     cdef int n_splits
     cdef int n_calls
 
-    _valid_metrics = VALID_METRIC_IDS
+    valid_metrics = VALID_METRIC_IDS
 
     # Use cinit to initialize all arrays to empty: this will prevent memory
     # errors and seg-faults in rare cases where __init__ is not called
@@ -979,19 +982,6 @@ cdef class BinaryTree:
             self.node_bounds.base,
         )
 
-    @classmethod
-    def valid_metrics(cls):
-        """Get list of valid distance metrics.
-
-        .. versionadded:: 1.3
-
-        Returns
-        -------
-        valid_metrics: list of str
-            List of valid distance metrics.
-        """
-        return cls._valid_metrics
-
     cdef inline float64_t dist(self, float64_t* x1, float64_t* x2,
                                intp_t size) except -1 nogil:
         """Compute the distance between arrays x1 and x2"""

sklearn/neighbors/_kde.py

Lines changed: 3 additions & 3 deletions
@@ -173,12 +173,12 @@ def _choose_algorithm(self, algorithm, metric):
         # algorithm to compute the result.
         if algorithm == "auto":
             # use KD Tree if possible
-            if metric in KDTree.valid_metrics():
+            if metric in KDTree.valid_metrics:
                 return "kd_tree"
-            elif metric in BallTree.valid_metrics():
+            elif metric in BallTree.valid_metrics:
                 return "ball_tree"
         else:  # kd_tree or ball_tree
-            if metric not in TREE_DICT[algorithm].valid_metrics():
+            if metric not in TREE_DICT[algorithm].valid_metrics:
                 raise ValueError(
                     "invalid metric for {0}: '{1}'".format(TREE_DICT[algorithm], metric)
                 )
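
The "auto" branch above is a simple preference order: use a KD-tree when the metric allows it, otherwise fall back to a ball tree. A hedged sketch of that logic against the public attributes; `choose_tree` is an illustrative helper, not part of scikit-learn's API.

    from sklearn.neighbors import BallTree, KDTree


    def choose_tree(metric):
        # Prefer a KD-tree when the metric supports it, otherwise fall back to a ball tree.
        if metric in KDTree.valid_metrics:
            return "kd_tree"
        if metric in BallTree.valid_metrics:
            return "ball_tree"
        raise ValueError(f"invalid metric: {metric!r}")


    print(choose_tree("euclidean"))  # "kd_tree"
    print(choose_tree("haversine"))  # "ball_tree"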

sklearn/neighbors/tests/test_kde.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def test_kde_algorithm_metric_choice(algorithm, metric):
 
     kde = KernelDensity(algorithm=algorithm, metric=metric)
 
-    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics():
+    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics:
         with pytest.raises(ValueError, match="invalid metric"):
             kde.fit(X)
     else:
@@ -164,7 +164,7 @@ def test_kde_sample_weights():
     test_points = rng.rand(n_samples_test, d)
     for algorithm in ["auto", "ball_tree", "kd_tree"]:
         for metric in ["euclidean", "minkowski", "manhattan", "chebyshev"]:
-            if algorithm != "kd_tree" or metric in KDTree.valid_metrics():
+            if algorithm != "kd_tree" or metric in KDTree.valid_metrics:
                 kde = KernelDensity(algorithm=algorithm, metric=metric)
 
                 # Test that adding a constant sample weight has no effect
