10000 Attempted fix for #597 · scikit-learn-contrib/hdbscan@55f919e · GitHub
[go: up one dir, main page]

Skip to content

Commit 55f919e

Browse files
committed
Attempted fix for #597
1 parent ac7e6fe commit 55f919e

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

hdbscan/hdbscan_.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,22 @@
3737
from .plots import CondensedTree, SingleLinkageTree, MinimumSpanningTree
3838
from .prediction import PredictionData
3939

40-
FAST_METRICS = KDTree.valid_metrics + BallTree.valid_metrics + ["cosine", "arccos"]
40+
KDTREE_VALID_METRICS = ["euclidean", "l2", "minkowski", "p", "manhattan", "cityblock", "l1", "chebyshev", "infinity"]
41+
BALLTREE_VALID_METRICS = KDTREE_VALID_METRICS + [
42+
"braycurtis",
43+
"canberra",
44+
"dice",
45+
"hamming",
46+
"haversine",
47+
"jaccard",
48+
"mahalanobis",
49+
"rogerstanimoto",
50+
"russellrao",
51+
"seuclidean",
52+
"sokalmichener",
53+
"sokalsneath",
54+
]
55+
FAST_METRICS = KDTREE_VALID_METRICS + BALLTREE_VALID_METRICS + ["cosine", "arccos"]
4156

4257
# Author: Leland McInnes <leland.mcinnes@gmail.com>
4358
# Steve Astels <sastels@gmail.com>
@@ -742,19 +757,19 @@ def hdbscan(
742757
_hdbscan_generic
743758
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
744759
elif algorithm == "prims_kdtree":
745-
if metric not in KDTree.valid_metrics:
760+
if metric not in KDTREE_VALID_METRICS:
746761
raise ValueError("Cannot use Prim's with KDTree for this" " metric!")
747762
(single_linkage_tree, result_min_span_tree) = memory.cache(
748763
_hdbscan_prims_kdtree
749764
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
750765
elif algorithm == "prims_balltree":
751-
if metric not in BallTree.valid_metrics:
766+
if metric not in BALLTREE_VALID_METRICS:
752767
raise ValueError("Cannot use Prim's with BallTree for this" " metric!")
753768
(single_linkage_tree, result_min_span_tree) = memory.cache(
754769
_hdbscan_prims_balltree
755770
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
756771
elif algorithm == "boruvka_kdtree":
757-
if metric not in BallTree.valid_metrics:
772+
if metric not in BALLTREE_VALID_METRICS:
758773
raise ValueError("Cannot use Boruvka with KDTree for this" " metric!")
759774
(single_linkage_tree, result_min_span_tree) = memory.cache(
760775
_hdbscan_boruvka_kdtree
@@ -771,7 +786,7 @@ def hdbscan(
771786
**kwargs
772787
)
773788
elif algorithm == "boruvka_balltree":
774-
if metric not in BallTree.valid_metrics:
789+
if metric not in BALLTREE_VALID_METRICS:
775790
raise ValueError("Cannot use Boruvka with BallTree for this" " metric!")
776791
if (X.shape[0] // leaf_size) > 16000:
777792
warn(
@@ -802,7 +817,7 @@ def hdbscan(
802817
(single_linkage_tree, result_min_span_tree) = memory.cache(
803818
_hdbscan_generic
804819
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
805-
elif metric in KDTree.valid_metrics:
820+
elif metric in KDTREE_VALID_METRICS:
806821
# TO DO: Need heuristic to decide when to go to boruvka;
807822
# still debugging for now
808823
if X.shape[1] > 60:
@@ -1237,9 +1252,9 @@ def generate_prediction_data(self):
12371252

12381253
if self.metric in FAST_METRICS:
12391254
min_samples = self.min_samples or self.min_cluster_size
1240-
if self.metric in KDTree.valid_metrics:
1255+
if self.metric in KDTREE_VALID_METRICS:
12411256
tree_type = "kdtree"
1242-
elif self.metric in BallTree.valid_metrics:
1257+
elif self.metric in BALLTREE_VALID_METRICS:
12431258
tree_type = "balltree"
12441259
else:
12451260
warn("Metric {} not supported for prediction data!".format(self.metric))

0 commit comments

Comments
 (0)
0