From 5d0219cfe40b6fa5ac17c84d6977581760753f51 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 22 Jan 2024 17:15:37 +0500 Subject: [PATCH 1/2] Euclidean Specialization for ArgKminClassMode --- .../_argkmin_classmode.pyx.tp | 165 ++++++++++++++++-- .../_dispatcher.py | 11 -- 2 files changed, 153 insertions(+), 23 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp index f9719f6959dfc..20896bb4de585 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_argkmin_classmode.pyx.tp @@ -11,9 +11,11 @@ from sklearn.utils.fixes import threadpool_limits from ._classmode cimport WeightingStrategy {{for name_suffix in ["32", "64"]}} -from ._argkmin cimport ArgKmin{{name_suffix}} + +from ._argkmin cimport ArgKmin{{name_suffix}}, EuclideanArgKmin{{name_suffix}} from ._datasets_pair cimport DatasetsPair{{name_suffix}} + cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}): """ {{name_suffix}}bit implementation of ArgKminClassMode. @@ -52,17 +54,40 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}): No instance _must_ directly be created outside of this class method. """ - # Use a generic implementation that handles most scipy - # metrics by computing the distances between 2 vectors at a time. - pda = ArgKminClassMode{{name_suffix}}( - datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), - k=k, - chunk_size=chunk_size, - strategy=strategy, - weights=weights, - Y_labels=Y_labels, - unique_Y_labels=unique_Y_labels, - ) + if metric in ("euclidean", "sqeuclidean"): + # Specialized implementation of ArgKminClassMode for the Euclidean + # distance for the dense-dense and sparse-sparse cases. + # This implementation computes the distances by chunk using + # a decomposition of the Squared Euclidean distance. 
+ # This specialisation has an improved arithmetic intensity for both + # the dense and sparse settings, allowing in most cases speed-ups of + # several orders of magnitude compared to the generic ArgKminClassMode + # implementation. + # For more information see MiddleTermComputer. + pda = EuclideanArgKminClassMode{{name_suffix}}( + X=X, + Y=Y, + k=k, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, + use_squared_distances=(metric == "sqeuclidean"), + weights=weights, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + else: + # Use a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = ArgKminClassMode{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), + k=k, + chunk_size=chunk_size, + strategy=strategy, + weights=weights, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, + ) # Limit the number of threads in second level of nested parallelism for BLAS # to avoid threads over-subscription (in GEMM for instance). 
@@ -179,4 +204,120 @@ cdef class ArgKminClassMode{{name_suffix}}(ArgKmin{{name_suffix}}): ) return + +cdef class EuclideanArgKminClassMode{{name_suffix}}(EuclideanArgKmin{{name_suffix}}): + """Euclidean Distance-specialisation of ArgKminClassMode{{name_suffix}}.""" + + cdef: + const intp_t[:] Y_labels, + const intp_t[:] unique_Y_labels + float64_t[:, :] class_scores + WeightingStrategy weight_type + + def __init__( + self, + X, + Y, + intp_t k, + Y_labels, + unique_Y_labels, + bint use_squared_distances=False, + weights=None, + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + ): + super().__init__( + X=X, + Y=Y, + k=k, + use_squared_distances=use_squared_distances, + chunk_size=chunk_size, + strategy=strategy, + metric_kwargs=metric_kwargs, + ) + + if weights == "uniform": + self.weight_type = WeightingStrategy.uniform + elif weights == "distance": + self.weight_type = WeightingStrategy.distance + else: + self.weight_type = WeightingStrategy.callable + + self.Y_labels = Y_labels + self.unique_Y_labels = unique_Y_labels + + self.class_scores = np.zeros( + (self.n_samples_X, unique_Y_labels.shape[0]), dtype=np.float64, + ) + + def _finalize_results(self): + probabilities = np.asarray(self.class_scores) + probabilities /= probabilities.sum(axis=1, keepdims=True) + return probabilities + + cdef inline void weighted_histogram_mode( + self, + intp_t sample_index, + intp_t* indices, + float64_t* distances, + ) noexcept nogil: + cdef: + intp_t neighbor_idx, neighbor_class_idx, label_index, multi_output_index + float64_t score_incr = 1 + # TODO: Implement other WeightingStrategy values + bint use_distance_weighting = ( + self.weight_type == WeightingStrategy.distance + ) + + # Iterate through the sample k-nearest neighbours + for neighbor_rank in range(self.k): + # Absolute index of the neighbor_rank-th Nearest Neighbors + # in range [0, n_samples_Y) + if use_distance_weighting: + score_incr = 1 / distances[neighbor_rank] + neighbor_idx = 
indices[neighbor_rank] + neighbor_class_idx = self.Y_labels[neighbor_idx] + self.class_scores[sample_index][neighbor_class_idx] += score_incr + return + + cdef void _parallel_on_X_prange_iter_finalize( + self, + intp_t thread_num, + intp_t X_start, + intp_t X_end, + ) noexcept nogil: + cdef: + intp_t idx, sample_index + for idx in range(X_end - X_start): + # One-pass top-one weighted mode + # Compute the absolute index in [0, n_samples_X) + sample_index = X_start + idx + self.weighted_histogram_mode( + sample_index, + &self.heaps_indices_chunks[thread_num][idx * self.k], + &self.heaps_r_distances_chunks[thread_num][idx * self.k], + ) + return + + cdef void _parallel_on_Y_finalize( + self, + ) noexcept nogil: + cdef: + intp_t sample_index, thread_num + + with nogil, parallel(num_threads=self.chunks_n_threads): + # Deallocating temporary datastructures + for thread_num in prange(self.chunks_n_threads, schedule='static'): + free(self.heaps_r_distances_chunks[thread_num]) + free(self.heaps_indices_chunks[thread_num]) + + for sample_index in prange(self.n_samples_X, schedule='static'): + self.weighted_histogram_mode( + sample_index, + &self.argkmin_indices[sample_index][0], + &self.argkmin_distances[sample_index][0], + ) + return + {{endfor}} diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 1088fa86e7c9c..ddabf1e93b6b5 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -454,17 +454,6 @@ class ArgKminClassMode(BaseDistancesReductionDispatcher): deallocation consistently. """ - @classmethod - def valid_metrics(cls) -> List[str]: - excluded = { - # Euclidean is technically usable for ArgKminClassMode - # but its current implementation would not be competitive. - # TODO: implement Euclidean specialization using GEMM. 
- "euclidean", - "sqeuclidean", - } - return list(set(BaseDistancesReductionDispatcher.valid_metrics()) - excluded) - @classmethod def compute( cls, From dc14320caea1586cac3f80696ff196bd7d222daa Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 25 Jan 2024 17:12:41 +0500 Subject: [PATCH 2/2] Add strategy auto for euclidean metrics --- sklearn/neighbors/_classification.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 26ffa273d0a60..9aca6e8d07f7d 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -357,7 +357,11 @@ def predict_proba(self, X): # https://github.com/scikit-learn/scikit-learn/pull/24076#issuecomment-1445258342 # noqa # TODO: adapt the heuristic for `strategy="auto"` for # `ArgKminClassMode` and use `strategy="auto"`. - strategy="parallel_on_X", + strategy=( + "auto" + if metric in ("euclidean", "sqeuclidean") + else "parallel_on_X" + ), ) return probabilities