From 00f0b60b1c8c98de19506c4345aa7d61e15c0105 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 24 Jul 2023 10:42:18 -0400 Subject: [PATCH 1/4] Minor iter -- partial progress --- sklearn/cluster/_hdbscan/_linkage.pyx | 9 +++++---- sklearn/cluster/_hdbscan/hdbscan.py | 2 +- sklearn/cluster/tests/test_hdbscan.py | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sklearn/cluster/_hdbscan/_linkage.pyx b/sklearn/cluster/_hdbscan/_linkage.pyx index 03e91ac8d6833..bbddd826a8aa9 100644 --- a/sklearn/cluster/_hdbscan/_linkage.pyx +++ b/sklearn/cluster/_hdbscan/_linkage.pyx @@ -33,9 +33,10 @@ cimport numpy as cnp from libc.float cimport DBL_MAX +from cython cimport floating import numpy as np -from ...metrics._dist_metrics cimport DistanceMetric64 +from ...metrics._dist_metrics cimport DistanceMetric from ...cluster._hierarchical_fast cimport UnionFind from ...cluster._hdbscan._tree cimport HIERARCHY_t from ...cluster._hdbscan._tree import HIERARCHY_dtype @@ -109,9 +110,9 @@ cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_mutual_reachability( cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix( - const float64_t[:, ::1] raw_data, - const float64_t[::1] core_distances, - DistanceMetric64 dist_metric, + const floating[:, ::1] raw_data, + const floating[::1] core_distances, + DistanceMetric dist_metric, float64_t alpha=1.0 ): """Compute the Minimum Spanning Tree (MST) representation of the mutual- diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index 57de8962250b1..5c08b6a3c09f2 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -701,7 +701,7 @@ def fit(self, X, y=None): X, accept_sparse=["csr", "lil"], force_all_finite=False, - dtype=np.float64, + dtype=(np.float64, np.float32), ) self._raw_data = X all_finite = True diff --git a/sklearn/cluster/tests/test_hdbscan.py b/sklearn/cluster/tests/test_hdbscan.py index c0c281ce31475..03c6af6505400 
100644 --- a/sklearn/cluster/tests/test_hdbscan.py +++ b/sklearn/cluster/tests/test_hdbscan.py @@ -69,13 +69,13 @@ def test_outlier_data(outlier_type): assert_array_equal(clean_model.labels_, model.labels_[clean_indices]) -def test_hdbscan_distance_matrix(): +def test_hdbscan_distance_matrix(global_dtype): """ Tests that HDBSCAN works with precomputed distance matrices, and throws the appropriate errors when needed. """ D = euclidean_distances(X) - D_original = D.copy() + D_original = D.copy().astype(global_dtype) labels = HDBSCAN(metric="precomputed", copy=True).fit_predict(D) assert_allclose(D, D_original) @@ -118,12 +118,12 @@ def test_hdbscan_sparse_distance_matrix(sparse_constructor): assert n_clusters == n_clusters_true -def test_hdbscan_feature_array(): +def test_hdbscan_feature_array(global_dtype): """ Tests that HDBSCAN works with feature array, including an arbitrary goodness of fit check. Note that the check is a simple heuristic. """ - labels = HDBSCAN().fit_predict(X) + labels = HDBSCAN().fit_predict(X.astype(global_dtype)) n_clusters = len(set(labels) - OUTLIER_SET) assert n_clusters == n_clusters_true From 80de242991e4d922e83ba1dc12733f689f0502c5 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 24 Jul 2023 12:21:02 -0400 Subject: [PATCH 2/4] Initial templating --- .gitignore | 1 + setup.py | 2 +- sklearn/cluster/_hdbscan/_linkage.pyx | 131 +++++++++- sklearn/cluster/_hdbscan/_linkage.pyx.tp | 295 +++++++++++++++++++++++ 4 files changed, 425 insertions(+), 4 deletions(-) create mode 100644 sklearn/cluster/_hdbscan/_linkage.pyx.tp diff --git a/.gitignore b/.gitignore index f4601a15655a5..91efdf3d9dfeb 100644 --- a/.gitignore +++ b/.gitignore @@ -99,6 +99,7 @@ sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx 
+sklearn/cluster/_hdbscan/_linkage.pyx # Default JupyterLite content jupyterlite_contents diff --git a/setup.py b/setup.py index 5af738f5f841f..d081b384c8171 100755 --- a/setup.py +++ b/setup.py @@ -209,7 +209,7 @@ def check_package_status(package, min_version): {"sources": ["_k_means_minibatch.pyx"], "include_np": True}, ], "cluster._hdbscan": [ - {"sources": ["_linkage.pyx"], "include_np": True}, + {"sources": ["_linkage.pyx.tp"], "include_np": True}, {"sources": ["_reachability.pyx"], "include_np": True}, {"sources": ["_tree.pyx"], "include_np": True}, ], diff --git a/sklearn/cluster/_hdbscan/_linkage.pyx b/sklearn/cluster/_hdbscan/_linkage.pyx index bbddd826a8aa9..831e635c38e07 100644 --- a/sklearn/cluster/_hdbscan/_linkage.pyx +++ b/sklearn/cluster/_hdbscan/_linkage.pyx @@ -1,3 +1,7 @@ +# WARNING: Do not edit this file directly. +# It is automatically generated from 'sklearn/cluster/_hdbscan/_linkage.pyx.tp'. +# Changes must be made there. + # Minimum spanning tree single linkage implementation for hdbscan # Authors: Leland McInnes # Steve Astels @@ -36,11 +40,11 @@ from libc.float cimport DBL_MAX from cython cimport floating import numpy as np -from ...metrics._dist_metrics cimport DistanceMetric +from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32, DistanceMetric64 from ...cluster._hierarchical_fast cimport UnionFind from ...cluster._hdbscan._tree cimport HIERARCHY_t from ...cluster._hdbscan._tree import HIERARCHY_dtype -from ...utils._typedefs cimport intp_t, float64_t, int64_t, uint8_t +from ...utils._typedefs cimport float32_t, float64_t, intp_t, int64_t, uint8_t cdef extern from "numpy/arrayobject.h": intp_t * PyArray_SHAPE(cnp.PyArrayObject *) @@ -108,12 +112,133 @@ cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_mutual_reachability( return mst - cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix( const floating[:, ::1] raw_data, const floating[::1] core_distances, DistanceMetric dist_metric, float64_t 
alpha=1.0 +): + if floating is double: + return mst_from_data_matrix64(raw_data, core_distances, dist_metric, alpha) + else: + return mst_from_data_matrix32(raw_data, core_distances, dist_metric, alpha) + +cdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix64( + const float64_t[:, ::1] raw_data, + const float64_t[::1] core_distances, + DistanceMetric64 dist_metric, + float64_t alpha=1.0 +): + """Compute the Minimum Spanning Tree (MST) representation of the mutual- + reachability graph generated from the provided `raw_data` and + `core_distances` using Prim's algorithm. + + Parameters + ---------- + raw_data : ndarray of shape (n_samples, n_features) + Input array of data samples. + + core_distances : ndarray of shape (n_samples,) + An array containing the core-distance calculated for each corresponding + sample. + + dist_metric : DistanceMetric + The distance metric to use when calculating pairwise distances for + determining mutual-reachability. + + Returns + ------- + mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype + The MST representation of the mutual-reahability graph. The MST is + represented as a collecteion of edges. 
+ """ + + cdef: + uint8_t[::1] in_tree + float64_t[::1] min_reachability + int64_t[::1] current_sources + cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst + + int64_t current_node, source_node, new_node, next_node_source + int64_t i, j, n_samples, num_features + + float64_t current_node_core_dist, new_reachability, mutual_reachability_distance + float64_t next_node_min_reach, pair_distance, next_node_core_dist + + n_samples = raw_data.shape[0] + num_features = raw_data.shape[1] + + mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) + + in_tree = np.zeros(n_samples, dtype=np.uint8) + min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) + current_sources = np.ones(n_samples, dtype=np.int64) + + current_node = 0 + + for i in range(0, n_samples - 1): + + in_tree[current_node] = 1 + + current_node_core_dist = core_distances[current_node] + + new_reachability = DBL_MAX + source_node = 0 + new_node = 0 + + for j in range(n_samples): + if in_tree[j]: + continue + + next_node_min_reach = min_reachability[j] + next_node_source = current_sources[j] + + pair_distance = dist_metric.dist( + &raw_data[current_node, 0], + &raw_data[j, 0], + num_features + ) + + pair_distance /= alpha + + next_node_core_dist = core_distances[j] + mutual_reachability_distance = max( + current_node_core_dist, + next_node_core_dist, + pair_distance + ) + if mutual_reachability_distance > next_node_min_reach: + if next_node_min_reach < new_reachability: + new_reachability = next_node_min_reach + source_node = next_node_source + new_node = j + continue + + if mutual_reachability_distance < next_node_min_reach: + min_reachability[j] = mutual_reachability_distance + current_sources[j] = current_node + if mutual_reachability_distance < new_reachability: + new_reachability = mutual_reachability_distance + source_node = current_node + new_node = j + else: + if next_node_min_reach < new_reachability: + new_reachability = next_node_min_reach + source_node = next_node_source + new_node = j 
+ + mst[i].current_node = source_node + mst[i].next_node = new_node + mst[i].distance = new_reachability + current_node = new_node + + return mst + +cdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix32( + const float32_t[:, ::1] raw_data, + const float32_t[::1] core_distances, + DistanceMetric32 dist_metric, + float64_t alpha=1.0 ): """Compute the Minimum Spanning Tree (MST) representation of the mutual- reachability graph generated from the provided `raw_data` and diff --git a/sklearn/cluster/_hdbscan/_linkage.pyx.tp b/sklearn/cluster/_hdbscan/_linkage.pyx.tp new file mode 100644 index 0000000000000..63f0c4c1aabe1 --- /dev/null +++ b/sklearn/cluster/_hdbscan/_linkage.pyx.tp @@ -0,0 +1,295 @@ +# Minimum spanning tree single linkage implementation for hdbscan +# Authors: Leland McInnes +# Steve Astels +# Meekail Zain +# Copyright (c) 2015, Leland McInnes +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software without +# specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +{{py: + +implementation_specific_values = [ + # Values are the following ones: + # + # name_suffix, INPUT_DTYPE_t, INPUT_DTYPE + ('64', 'float64_t', 'np.float64'), + ('32', 'float32_t', 'np.float32') +] + +}} + +cimport numpy as cnp +from libc.float cimport DBL_MAX +from cython cimport floating + +import numpy as np +from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32, DistanceMetric64 +from ...cluster._hierarchical_fast cimport UnionFind +from ...cluster._hdbscan._tree cimport HIERARCHY_t +from ...cluster._hdbscan._tree import HIERARCHY_dtype +from ...utils._typedefs cimport float32_t, float64_t, intp_t, int64_t, uint8_t + +cdef extern from "numpy/arrayobject.h": + intp_t * PyArray_SHAPE(cnp.PyArrayObject *) + +# Numpy structured dtype representing a single ordered edge in Prim's algorithm +MST_edge_dtype = np.dtype([ + ("current_node", np.int64), + ("next_node", np.int64), + ("distance", np.float64), +]) + +# Packed shouldn't make a difference since they're all 8-byte quantities, +# but it's included just to be safe. +ctypedef packed struct MST_edge_t: + int64_t current_node + int64_t next_node + float64_t distance + +cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_mutual_reachability( + cnp.ndarray[float64_t, ndim=2] mutual_reachability +): + """Compute the Minimum Spanning Tree (MST) representation of the mutual- + reachability graph using Prim's algorithm. 
+ + Parameters + ---------- + mutual_reachability : ndarray of shape (n_samples, n_samples) + Array of mutual-reachabilities between samples. + + Returns + ------- + mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype + The MST representation of the mutual-reachability graph. The MST is + represented as a collection of edges. + """ + cdef: + # Note: we utilize ndarray's over memory-views to make use of numpy + # binary indexing and sub-selection below. + cnp.ndarray[int64_t, ndim=1, mode='c'] current_labels + cnp.ndarray[float64_t, ndim=1, mode='c'] min_reachability, left, right + cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst + + cnp.ndarray[uint8_t, mode='c'] label_filter + + int64_t n_samples = PyArray_SHAPE(<cnp.PyArrayObject*> mutual_reachability)[0] + int64_t current_node, new_node_index, new_node, i + + mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) + current_labels = np.arange(n_samples, dtype=np.int64) + current_node = 0 + min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) + for i in range(0, n_samples - 1): + label_filter = current_labels != current_node + current_labels = current_labels[label_filter] + left = min_reachability[label_filter] + right = mutual_reachability[current_node][current_labels] + min_reachability = np.minimum(left, right) + + new_node_index = np.argmin(min_reachability) + new_node = current_labels[new_node_index] + mst[i].current_node = current_node + mst[i].next_node = new_node + mst[i].distance = min_reachability[new_node_index] + current_node = new_node + + return mst + +cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix( + const floating[:, ::1] raw_data, + const floating[::1] core_distances, + DistanceMetric dist_metric, + float64_t alpha=1.0 +): + if floating is double: + return mst_from_data_matrix64(raw_data, core_distances, dist_metric, alpha) + else: + return mst_from_data_matrix32(raw_data, core_distances, dist_metric, alpha) + +{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in 
implementation_specific_values}} + +cdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix{{name_suffix}}( + const {{INPUT_DTYPE_t}}[:, ::1] raw_data, + const {{INPUT_DTYPE_t}}[::1] core_distances, + DistanceMetric{{name_suffix}} dist_metric, + float64_t alpha=1.0 +): + """Compute the Minimum Spanning Tree (MST) representation of the mutual- + reachability graph generated from the provided `raw_data` and + `core_distances` using Prim's algorithm. + + Parameters + ---------- + raw_data : ndarray of shape (n_samples, n_features) + Input array of data samples. + + core_distances : ndarray of shape (n_samples,) + An array containing the core-distance calculated for each corresponding + sample. + + dist_metric : DistanceMetric + The distance metric to use when calculating pairwise distances for + determining mutual-reachability. + + Returns + ------- + mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype + The MST representation of the mutual-reachability graph. The MST is + represented as a collection of edges. 
+ """ + + cdef: + uint8_t[::1] in_tree + float64_t[::1] min_reachability + int64_t[::1] current_sources + cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst + + int64_t current_node, source_node, new_node, next_node_source + int64_t i, j, n_samples, num_features + + float64_t current_node_core_dist, new_reachability, mutual_reachability_distance + float64_t next_node_min_reach, pair_distance, next_node_core_dist + + n_samples = raw_data.shape[0] + num_features = raw_data.shape[1] + + mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) + + in_tree = np.zeros(n_samples, dtype=np.uint8) + min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) + current_sources = np.ones(n_samples, dtype=np.int64) + + current_node = 0 + + for i in range(0, n_samples - 1): + + in_tree[current_node] = 1 + + current_node_core_dist = core_distances[current_node] + + new_reachability = DBL_MAX + source_node = 0 + new_node = 0 + + for j in range(n_samples): + if in_tree[j]: + continue + + next_node_min_reach = min_reachability[j] + next_node_source = current_sources[j] + + pair_distance = dist_metric.dist( + &raw_data[current_node, 0], + &raw_data[j, 0], + num_features + ) + + pair_distance /= alpha + + next_node_core_dist = core_distances[j] + mutual_reachability_distance = max( + current_node_core_dist, + next_node_core_dist, + pair_distance + ) + if mutual_reachability_distance > next_node_min_reach: + if next_node_min_reach < new_reachability: + new_reachability = next_node_min_reach + source_node = next_node_source + new_node = j + continue + + if mutual_reachability_distance < next_node_min_reach: + min_reachability[j] = mutual_reachability_distance + current_sources[j] = current_node + if mutual_reachability_distance < new_reachability: + new_reachability = mutual_reachability_distance + source_node = current_node + new_node = j + else: + if next_node_min_reach < new_reachability: + new_reachability = next_node_min_reach + source_node = next_node_source + new_node = j 
+ + mst[i].current_node = source_node + mst[i].next_node = new_node + mst[i].distance = new_reachability + current_node = new_node + + return mst +{{endfor}} + +cpdef cnp.ndarray[HIERARCHY_t, ndim=1, mode="c"] make_single_linkage(const MST_edge_t[::1] mst): + """Construct a single-linkage tree from an MST. + + Parameters + ---------- + mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype + The MST representation of the mutual-reachability graph. The MST is + represented as a collection of edges. + + Returns + ------- + single_linkage : ndarray of shape (n_samples - 1,), dtype=HIERARCHY_dtype + The single-linkage tree (dendrogram) built from the MST. Each + element of the array represents the following: + + - left node/cluster + - right node/cluster + - distance + - new cluster size + """ + cdef: + cnp.ndarray[HIERARCHY_t, ndim=1, mode="c"] single_linkage + + # Note mst.shape[0] is one fewer than the number of samples + int64_t n_samples = mst.shape[0] + 1 + intp_t current_node_cluster, next_node_cluster + int64_t current_node, next_node, i + float64_t distance + UnionFind U = UnionFind(n_samples) + + single_linkage = np.zeros(n_samples - 1, dtype=HIERARCHY_dtype) + + for i in range(n_samples - 1): + + current_node = mst[i].current_node + next_node = mst[i].next_node + distance = mst[i].distance + + current_node_cluster = U.fast_find(current_node) + next_node_cluster = U.fast_find(next_node) + + single_linkage[i].left_node = current_node_cluster + single_linkage[i].right_node = next_node_cluster + single_linkage[i].value = distance + single_linkage[i].cluster_size = U.size[current_node_cluster] + U.size[next_node_cluster] + + U.union(current_node_cluster, next_node_cluster) + + return single_linkage From 2273fe5bddda52b69b5b15a23a0b5225874c4047 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 24 Jul 2023 12:45:58 -0400 Subject: [PATCH 3/4] Updated validation paths and added TODO --- sklearn/cluster/_hdbscan/hdbscan.py | 10 ++++++---- 1 file changed, 6 
insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index 5c08b6a3c09f2..988b63d2634bc 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -337,10 +337,10 @@ def _hdbscan_prims( n_jobs=n_jobs, p=None, ).fit(X) - + # TODO: Resume when {KD, Ball}Tree support 32-bit neighbors_distances, _ = nbrs.kneighbors(X, min_samples, return_distance=True) core_distances = np.ascontiguousarray(neighbors_distances[:, -1]) - dist_metric = DistanceMetric.get_metric(metric, **metric_params) + dist_metric = DistanceMetric.get_metric(metric, dtype=X.dtype, **metric_params) # Mutual reachability distance is implicit in mst_from_data_matrix min_spanning_tree = mst_from_data_matrix(X, core_distances, dist_metric, alpha) @@ -735,7 +735,7 @@ def fit(self, X, y=None): X = self._validate_data( X, accept_sparse=["csr", "lil"], - dtype=np.float64, + dtype=(np.float64, np.float32), ) else: # Only non-sparse, precomputed distance matrices are handled here @@ -743,7 +743,9 @@ def fit(self, X, y=None): # Perform data validation after removing infinite values (numpy.inf) # from the given distance matrix. 
- X = self._validate_data(X, force_all_finite=False, dtype=np.float64) + X = self._validate_data( + X, force_all_finite=False, dtype=(np.float64, np.float32) + ) if np.isnan(X).any(): # TODO: Support np.nan in Cython implementation for precomputed # dense HDBSCAN From 0a8e9fd405612b00746c1824ff44a75f98550748 Mon Sep 17 00:00:00 2001 From: Meekail Zain Date: Mon, 24 Jul 2023 13:39:35 -0400 Subject: [PATCH 4/4] Removed now-templated file --- sklearn/cluster/_hdbscan/_linkage.pyx | 396 -------------------------- 1 file changed, 396 deletions(-) delete mode 100644 sklearn/cluster/_hdbscan/_linkage.pyx diff --git a/sklearn/cluster/_hdbscan/_linkage.pyx b/sklearn/cluster/_hdbscan/_linkage.pyx deleted file mode 100644 index 831e635c38e07..0000000000000 --- a/sklearn/cluster/_hdbscan/_linkage.pyx +++ /dev/null @@ -1,396 +0,0 @@ -# WARNING: Do not edit this file directly. -# It is automatically generated from 'sklearn/cluster/_hdbscan/_linkage.pyx.tp'. -# Changes must be made there. - -# Minimum spanning tree single linkage implementation for hdbscan -# Authors: Leland McInnes -# Steve Astels -# Meekail Zain -# Copyright (c) 2015, Leland McInnes -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: - -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. - -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. - -# 3. Neither the name of the copyright holder nor the names of its contributors -# may be used to endorse or promote products derived from this software without -# specific prior written permission. 
- -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. - -cimport numpy as cnp -from libc.float cimport DBL_MAX -from cython cimport floating - -import numpy as np -from ...metrics._dist_metrics cimport DistanceMetric, DistanceMetric32, DistanceMetric64 -from ...cluster._hierarchical_fast cimport UnionFind -from ...cluster._hdbscan._tree cimport HIERARCHY_t -from ...cluster._hdbscan._tree import HIERARCHY_dtype -from ...utils._typedefs cimport float32_t, float64_t, intp_t, int64_t, uint8_t - -cdef extern from "numpy/arrayobject.h": - intp_t * PyArray_SHAPE(cnp.PyArrayObject *) - -# Numpy structured dtype representing a single ordered edge in Prim's algorithm -MST_edge_dtype = np.dtype([ - ("current_node", np.int64), - ("next_node", np.int64), - ("distance", np.float64), -]) - -# Packed shouldn't make a difference since they're all 8-byte quantities, -# but it's included just to be safe. -ctypedef packed struct MST_edge_t: - int64_t current_node - int64_t next_node - float64_t distance - -cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_mutual_reachability( - cnp.ndarray[float64_t, ndim=2] mutual_reachability -): - """Compute the Minimum Spanning Tree (MST) representation of the mutual- - reachability graph using Prim's algorithm. 
- - Parameters - ---------- - mutual_reachability : ndarray of shape (n_samples, n_samples) - Array of mutual-reachabilities between samples. - - Returns - ------- - mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype - The MST representation of the mutual-reahability graph. The MST is - represented as a collecteion of edges. - """ - cdef: - # Note: we utilize ndarray's over memory-views to make use of numpy - # binary indexing and sub-selection below. - cnp.ndarray[int64_t, ndim=1, mode='c'] current_labels - cnp.ndarray[float64_t, ndim=1, mode='c'] min_reachability, left, right - cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst - - cnp.ndarray[uint8_t, mode='c'] label_filter - - int64_t n_samples = PyArray_SHAPE(<cnp.PyArrayObject*> mutual_reachability)[0] - int64_t current_node, new_node_index, new_node, i - - mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) - current_labels = np.arange(n_samples, dtype=np.int64) - current_node = 0 - min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) - for i in range(0, n_samples - 1): - label_filter = current_labels != current_node - current_labels = current_labels[label_filter] - left = min_reachability[label_filter] - right = mutual_reachability[current_node][current_labels] - min_reachability = np.minimum(left, right) - - new_node_index = np.argmin(min_reachability) - new_node = current_labels[new_node_index] - mst[i].current_node = current_node - mst[i].next_node = new_node - mst[i].distance = min_reachability[new_node_index] - current_node = new_node - - return mst - -cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix( - const floating[:, ::1] raw_data, - const floating[::1] core_distances, - DistanceMetric dist_metric, - float64_t alpha=1.0 -): - if floating is double: - return mst_from_data_matrix64(raw_data, core_distances, dist_metric, alpha) - else: - return mst_from_data_matrix32(raw_data, core_distances, dist_metric, alpha) - -cdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] 
mst_from_data_matrix64( - const float64_t[:, ::1] raw_data, - const float64_t[::1] core_distances, - DistanceMetric64 dist_metric, - float64_t alpha=1.0 -): - """Compute the Minimum Spanning Tree (MST) representation of the mutual- - reachability graph generated from the provided `raw_data` and - `core_distances` using Prim's algorithm. - - Parameters - ---------- - raw_data : ndarray of shape (n_samples, n_features) - Input array of data samples. - - core_distances : ndarray of shape (n_samples,) - An array containing the core-distance calculated for each corresponding - sample. - - dist_metric : DistanceMetric - The distance metric to use when calculating pairwise distances for - determining mutual-reachability. - - Returns - ------- - mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype - The MST representation of the mutual-reahability graph. The MST is - represented as a collecteion of edges. - """ - - cdef: - uint8_t[::1] in_tree - float64_t[::1] min_reachability - int64_t[::1] current_sources - cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst - - int64_t current_node, source_node, new_node, next_node_source - int64_t i, j, n_samples, num_features - - float64_t current_node_core_dist, new_reachability, mutual_reachability_distance - float64_t next_node_min_reach, pair_distance, next_node_core_dist - - n_samples = raw_data.shape[0] - num_features = raw_data.shape[1] - - mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) - - in_tree = np.zeros(n_samples, dtype=np.uint8) - min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) - current_sources = np.ones(n_samples, dtype=np.int64) - - current_node = 0 - - for i in range(0, n_samples - 1): - - in_tree[current_node] = 1 - - current_node_core_dist = core_distances[current_node] - - new_reachability = DBL_MAX - source_node = 0 - new_node = 0 - - for j in range(n_samples): - if in_tree[j]: - continue - - next_node_min_reach = min_reachability[j] - next_node_source = current_sources[j] - - 
pair_distance = dist_metric.dist( - &raw_data[current_node, 0], - &raw_data[j, 0], - num_features - ) - - pair_distance /= alpha - - next_node_core_dist = core_distances[j] - mutual_reachability_distance = max( - current_node_core_dist, - next_node_core_dist, - pair_distance - ) - if mutual_reachability_distance > next_node_min_reach: - if next_node_min_reach < new_reachability: - new_reachability = next_node_min_reach - source_node = next_node_source - new_node = j - continue - - if mutual_reachability_distance < next_node_min_reach: - min_reachability[j] = mutual_reachability_distance - current_sources[j] = current_node - if mutual_reachability_distance < new_reachability: - new_reachability = mutual_reachability_distance - source_node = current_node - new_node = j - else: - if next_node_min_reach < new_reachability: - new_reachability = next_node_min_reach - source_node = next_node_source - new_node = j - - mst[i].current_node = source_node - mst[i].next_node = new_node - mst[i].distance = new_reachability - current_node = new_node - - return mst - -cdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix32( - const float32_t[:, ::1] raw_data, - const float32_t[::1] core_distances, - DistanceMetric32 dist_metric, - float64_t alpha=1.0 -): - """Compute the Minimum Spanning Tree (MST) representation of the mutual- - reachability graph generated from the provided `raw_data` and - `core_distances` using Prim's algorithm. - - Parameters - ---------- - raw_data : ndarray of shape (n_samples, n_features) - Input array of data samples. - - core_distances : ndarray of shape (n_samples,) - An array containing the core-distance calculated for each corresponding - sample. - - dist_metric : DistanceMetric - The distance metric to use when calculating pairwise distances for - determining mutual-reachability. - - Returns - ------- - mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype - The MST representation of the mutual-reahability graph. 
The MST is - represented as a collecteion of edges. - """ - - cdef: - uint8_t[::1] in_tree - float64_t[::1] min_reachability - int64_t[::1] current_sources - cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst - - int64_t current_node, source_node, new_node, next_node_source - int64_t i, j, n_samples, num_features - - float64_t current_node_core_dist, new_reachability, mutual_reachability_distance - float64_t next_node_min_reach, pair_distance, next_node_core_dist - - n_samples = raw_data.shape[0] - num_features = raw_data.shape[1] - - mst = np.empty(n_samples - 1, dtype=MST_edge_dtype) - - in_tree = np.zeros(n_samples, dtype=np.uint8) - min_reachability = np.full(n_samples, fill_value=np.infty, dtype=np.float64) - current_sources = np.ones(n_samples, dtype=np.int64) - - current_node = 0 - - for i in range(0, n_samples - 1): - - in_tree[current_node] = 1 - - current_node_core_dist = core_distances[current_node] - - new_reachability = DBL_MAX - source_node = 0 - new_node = 0 - - for j in range(n_samples): - if in_tree[j]: - continue - - next_node_min_reach = min_reachability[j] - next_node_source = current_sources[j] - - pair_distance = dist_metric.dist( - &raw_data[current_node, 0], - &raw_data[j, 0], - num_features - ) - - pair_distance /= alpha - - next_node_core_dist = core_distances[j] - mutual_reachability_distance = max( - current_node_core_dist, - next_node_core_dist, - pair_distance - ) - if mutual_reachability_distance > next_node_min_reach: - if next_node_min_reach < new_reachability: - new_reachability = next_node_min_reach - source_node = next_node_source - new_node = j - continue - - if mutual_reachability_distance < next_node_min_reach: - min_reachability[j] = mutual_reachability_distance - current_sources[j] = current_node - if mutual_reachability_distance < new_reachability: - new_reachability = mutual_reachability_distance - source_node = current_node - new_node = j - else: - if next_node_min_reach < new_reachability: - new_reachability = 
next_node_min_reach - source_node = next_node_source - new_node = j - - mst[i].current_node = source_node - mst[i].next_node = new_node - mst[i].distance = new_reachability - current_node = new_node - - return mst - -cpdef cnp.ndarray[HIERARCHY_t, ndim=1, mode="c"] make_single_linkage(const MST_edge_t[::1] mst): - """Construct a single-linkage tree from an MST. - - Parameters - ---------- - mst : ndarray of shape (n_samples - 1,), dtype=MST_edge_dtype - The MST representation of the mutual-reahability graph. The MST is - represented as a collecteion of edges. - - Returns - ------- - single_linkage : ndarray of shape (n_samples - 1,), dtype=HIERARCHY_dtype - The single-linkage tree tree (dendrogram) built from the MST. Each - of the array represents the following: - - - left node/cluster - - right node/cluster - - distance - - new cluster size - """ - cdef: - cnp.ndarray[HIERARCHY_t, ndim=1, mode="c"] single_linkage - - # Note mst.shape[0] is one fewer than the number of samples - int64_t n_samples = mst.shape[0] + 1 - intp_t current_node_cluster, next_node_cluster - int64_t current_node, next_node, i - float64_t distance - UnionFind U = UnionFind(n_samples) - - single_linkage = np.zeros(n_samples - 1, dtype=HIERARCHY_dtype) - - for i in range(n_samples - 1): - - current_node = mst[i].current_node - next_node = mst[i].next_node - distance = mst[i].distance - - current_node_cluster = U.fast_find(current_node) - next_node_cluster = U.fast_find(next_node) - - single_linkage[i].left_node = current_node_cluster - single_linkage[i].right_node = next_node_cluster - single_linkage[i].value = distance - single_linkage[i].cluster_size = U.size[current_node_cluster] + U.size[next_node_cluster] - - U.union(current_node_cluster, next_node_cluster) - - return single_linkage