From 6b2700a0445774718e6ada21bc6aec68cdda8a64 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Tue, 11 Jul 2023 18:31:26 +0500 Subject: [PATCH 01/10] PERF Implement PairwiseDistancesReduction backend for RadiusNeighbors.predict_proba --- .gitignore | 1 + setup.cfg | 1 + setup.py | 6 + .../_pairwise_distances_reduction/__init__.py | 2 + .../_dispatcher.py | 170 ++++++++++++++ .../_radius_neighbors_classmode.pyx.tp | 221 ++++++++++++++++++ sklearn/neighbors/_classification.py | 148 +++++++----- 7 files changed, 488 insertions(+), 61 deletions(-) create mode 100644 sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp diff --git a/.gitignore b/.gitignore index f4601a15655a5..cbcbb4a394a51 100644 --- a/.gitignore +++ b/.gitignore @@ -99,6 +99,7 @@ sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pxd sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx +sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx # Default JupyterLite content jupyterlite_contents diff --git a/setup.cfg b/setup.cfg index d91a27344c575..3c8db85e48b86 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ ignore = sklearn/metrics/_pairwise_distances_reduction/_middle_term_computer.pyx sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pxd sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx + sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx [codespell] diff --git a/setup.py b/setup.py index 5af738f5f841f..82ac6b165d1b2 100755 --- a/setup.py +++ b/setup.py @@ -295,6 +295,12 @@ def check_package_status(package, min_version): "include_np": True, "extra_compile_args": ["-std=c++11"], }, + { + "sources": ["_radius_neighbors_classmode.pyx.tp"], + "language": "c++", + "include_np": True, + "extra_compile_args": ["-std=c++11"], + }, ], "preprocessing": [ {"sources": ["_csr_polynomial_expansion.pyx"]}, diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 68972de0a1a51..9352cab82652a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -91,6 +91,7 @@ ArgKminClassMode, BaseDistancesReductionDispatcher, RadiusNeighbors, + RadiusNeighborsClassMode, sqeuclidean_row_norms, ) @@ -98,6 +99,7 @@ "BaseDistancesReductionDispatcher", "ArgKmin", "RadiusNeighbors", + "RadiusNeighborsClassMode", "ArgKminClassMode", "sqeuclidean_row_norms", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 42f9e38aa2265..4c7b59634b188 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -19,6 +19,10 @@ RadiusNeighbors32, RadiusNeighbors64, ) +from ._radius_neighbors_classmode import ( + RadiusNeighborsClassMode32, + RadiusNeighborsClassMode64, +) def sqeuclidean_row_norms(X, num_threads): @@ -612,3 +616,169 @@ def compute( "Only float64 or float32 datasets pairs are supported at this time, " f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." ) + + +class RadiusNeighborsClassMode(BaseDistancesReductionDispatcher): + """Compute radius-based neighbors of row vectors of X on the ones of + Y with labels. 
+ + For each row-vector X[i] of the queries X, find all the indices j of + row-vectors in Y such that: + + dist(X[i], Y[j]) <= radius + + RadiusNeighborsClassMode is typically used to perform bruteforce + radius neighbors queries when the weighted mode of the labels for + the nearest neighbors within the specified radius are required, + such as in `predict` methods. + + This class is not meant to be instantiated, one should only use + its :meth:`compute` classmethod which handles allocation and + deallocation consistently. + """ + + @classmethod + def is_usable_for(cls, X, Y, metric) -> bool: + """Return True if the dispatcher can be used for the given parameters. + + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + The input array to be labelled. + + Y : ndarray of shape (n_samples_Y, n_features) + The input array whose labels are provided through the `labels` + parameter. + + metric : str, default='euclidean' + The distance metric to use. For a list of available metrics, see + the documentation of :class:`~sklearn.metrics.DistanceMetric`. + Currently does not support `'precomputed'`. + + Returns + ------- + True if the PairwiseDistancesReduction can be used, else False. + """ + return ( + RadiusNeighbors.is_usable_for(X, Y, metric) + # TODO: Support CSR matrices. + and not issparse(X) + and not issparse(Y) + # TODO: implement Euclidean specialization with GEMM. + and metric not in ("euclidean", "sqeuclidean") + ) + + @classmethod + def compute( + cls, + X, + Y, + radius, + weights, + labels, + unique_labels, + outlier_label, + metric="euclidean", + chunk_size=None, + metric_kwargs=None, + strategy=None, + ): + """Return the results of the reduction for the given arguments. + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + The input array to be labelled. + Y : ndarray of shape (n_samples_Y, n_features) + The input array whose labels are provided through the `labels` + parameter. + radius : float + The radius defining the neighborhood. + weights : ndarray + The weights applied over the `labels` of `Y` when computing the + weighted mode of the labels. + labels : ndarray + An array containing the index of the class membership of the + associated samples in `Y`. This is used in labeling `X`. + unique_classes : ndarray + An array containing all unique class labels. + outlier_label : int, default=None + Label for outlier samples (samples with no neighbors in given + radius). + metric : str, default='euclidean' + The distance metric to use. For a list of available metrics, see + the documentation of :class:`~sklearn.metrics.DistanceMetric`. + Currently does not support `'precomputed'`. + chunk_size : int, default=None, + The number of vectors per chunk. If None (default) looks-up in + scikit-learn configuration for `pairwise_dist_chunk_size`, + and use 256 if it is not set. + metric_kwargs : dict, default=None + Keyword arguments to pass to specified metric function. + strategy : str, {'auto', 'parallel_on_X', 'parallel_on_Y'}, default=None + The chunking strategy defining which dataset parallelization are made on. + For both strategies the computations happens with two nested loops, + respectively on chunks of X and chunks of Y. + Strategies differs on which loop (outer or inner) is made to run + in parallel with the Cython `prange` construct: + - 'parallel_on_X' dispatches chunks of X uniformly on threads. + Each thread then iterates on all the chunks of Y. 
This strategy is + embarrassingly parallel and comes with no datastructures + synchronisation. + - 'parallel_on_Y' dispatches chunks of Y uniformly on threads. + Each thread processes all the chunks of X in turn. This strategy is + a sequence of embarrassingly parallel subtasks (the inner loop on Y + chunks) with intermediate datastructures synchronisation at each + iteration of the sequential outer loop on X chunks. + - 'auto' relies on a simple heuristic to choose between + 'parallel_on_X' and 'parallel_on_Y': when `X.shape[0]` is large enough, + 'parallel_on_X' is usually the most efficient strategy. + When `X.shape[0]` is small but `Y.shape[0]` is large, 'parallel_on_Y' + brings more opportunity for parallelism and is therefore more efficient + despite the synchronization step at each iteration of the outer loop + on chunks of `X`. + - None (default) looks-up in scikit-learn configuration for + `pairwise_dist_parallel_strategy`, and use 'auto' if it is not set. + Returns + ------- + probabilities : ndarray of shape (n_samples_X, n_classes) + An array containing the class probabilities for each sample. + """ + if weights not in {"uniform", "distance"}: + raise ValueError( + "Only the 'uniform' or 'distance' weights options are supported" + f" at this time. Got: {weights=}." + ) + if X.dtype == Y.dtype == np.float64: + return RadiusNeighborsClassMode64.compute( + X=X, + Y=Y, + radius=radius, + weights=weights, + class_membership=np.array(labels, dtype=np.intp), + unique_labels=np.array(unique_labels, dtype=np.intp), + outlier_label=outlier_label, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + if X.dtype == Y.dtype == np.float32: + return RadiusNeighborsClassMode32.compute( + X=X, + Y=Y, + radius=radius, + weights=weights, + class_membership=np.array(labels, dtype=np.intp), + unique_labels=np.array(unique_labels, dtype=np.intp), + outlier_label=outlier_label, + metric=metric, + chunk_size=chunk_size, + metric_kwargs=metric_kwargs, + strategy=strategy, + ) + + raise ValueError( + "Only float64 or float32 datasets pairs are supported at this time, " + f"got: X.dtype={X.dtype} and Y.dtype={Y.dtype}." + ) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp new file mode 100644 index 0000000000000..1b0e20caf02be --- /dev/null +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -0,0 +1,221 @@ +import warnings + +from cython cimport floating, final, integral +from cython.operator cimport dereference as deref +from cython.parallel cimport parallel, prange +from libcpp.vector cimport vector + +from ...utils._typedefs cimport intp_t, float64_t + +import numpy as np +from scipy.sparse import issparse +from sklearn.utils.fixes import threadpool_limits + +cpdef enum WeightingStrategy: + uniform = 0 + distance = 1 + callable = 2 + + +{{for name_suffix in ["32", "64"]}} +from ._radius_neighbors cimport RadiusNeighbors{{name_suffix}} +from ._datasets_pair cimport DatasetsPair{{name_suffix}} + +cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix}}): + """ + {{name_suffix}}bit implementation of RadiusNeighborsClassMode. 
+ """ + cdef: + const intp_t[:] class_membership + const intp_t[:] unique_labels + intp_t outlier_index + bint outlier_label_exists + vector[intp_t] outliers + float64_t[:, :] class_scores + WeightingStrategy weight_type + object outlier_label + + @classmethod + def compute( + cls, + X, + Y, + float64_t radius, + weights, + class_membership, + unique_labels, + outlier_label=None, + str metric="euclidean", + chunk_size=None, + dict metric_kwargs=None, + str strategy=None, + ): + # Use a generic implementation that handles most scipy + # metrics by computing the distances between 2 vectors at a time. + pda = RadiusNeighborsClassMode{{name_suffix}}( + datasets_pair=DatasetsPair{{name_suffix}}.get_for(X, Y, metric, metric_kwargs), + radius=radius, + chunk_size=chunk_size, + strategy=strategy, + weights=weights, + class_membership=class_membership, + unique_labels=unique_labels, + outlier_label=outlier_label, + ) + + # Limit the number of threads in second level of nested parallelism for BLAS + # to avoid threads over-subscription (in GEMM for instance). + with threadpool_limits(limits=1, user_api="blas"): + if pda.execute_in_parallel_on_Y: + pda._parallel_on_Y() + else: + pda._parallel_on_X() + + return pda._finalize_results() + + def __init__( + self, + DatasetsPair{{name_suffix}} datasets_pair, + const intp_t[:] class_membership, + const intp_t[:] unique_labels, + float64_t radius, + chunk_size=None, + strategy=None, + weights=None, + outlier_label=None, + ): + super().__init__( + datasets_pair=datasets_pair, + chunk_size=chunk_size, + strategy=strategy, + radius=radius, + ) + + if weights == "uniform": + self.weight_type = WeightingStrategy.uniform + elif weights == "distance": + self.weight_type = WeightingStrategy.distance + else: + self.weight_type = WeightingStrategy.callable + + self.class_membership = class_membership + self.unique_labels = unique_labels + self.outlier_index = -1 + self.outlier_label_exists = False + self.outlier_label = outlier_label + + cdef intp_t idx + if outlier_label is not None: + self.outlier_label_exists = True + for idx in range(self.unique_labels.shape[0]): + if self.unique_labels[idx] == outlier_label: + self.outlier_index = idx + + # Map from set of unique labels to their indices in `class_scores` + # Buffer used in building a histogram for one-pass weighted mode + self.class_scores = np.zeros( + (self.n_samples_X, unique_labels.shape[0]), dtype=np.float64, + ) + + + cdef inline void weighted_histogram_mode( + self, + intp_t sample_index, + intp_t k, + intp_t* indices, + float64_t* distances, + ) noexcept nogil: + cdef: + intp_t neighbor_idx, neighbor_class_idx, label_index + float64_t score_incr = 1 + bint use_distance_weighting = ( + self.weight_type == WeightingStrategy.distance + ) + + if k == 0: + self.outliers.push_back(sample_index) + if self.outlier_index >= 0: + self.class_scores[sample_index][self.outlier_index] = score_incr + + return + + # Iterate over the neighbors. This can be different for + # each of the samples as they are based on the radius. 
+ for neighbor_rank in range(k): + if use_distance_weighting: + score_incr = 1 / distances[neighbor_rank] + + neighbor_idx = indices[neighbor_rank] + neighbor_class_idx = self.class_membership[neighbor_idx] + self.class_scores[sample_index][neighbor_class_idx] += score_incr + + return + + @final + cdef void _parallel_on_X_prange_iter_finalize( + self, + intp_t thread_num, + intp_t X_start, + intp_t X_end, + ) noexcept nogil: + cdef: + intp_t idx, sample_index + + for idx in range(X_start, X_end): + sample_index = X_start + idx + self.weighted_histogram_mode( + sample_index=sample_index, + k=deref(self.neigh_indices)[idx].size(), + indices=deref(self.neigh_indices)[idx].data(), + distances=deref(self.neigh_distances)[idx].data(), + ) + + return + + cdef void _parallel_on_Y_finalize( + self, + ) noexcept nogil: + cdef: + intp_t idx + + with nogil, parallel(num_threads=self.effective_n_threads): + # Merge vectors used in threads into the main ones. + # This is done in parallel sample-wise (no need for locks). + for idx in prange(self.n_samples_X, schedule='static'): + self._merge_vectors(idx, self.chunks_n_threads) + + for idx in prange(self.n_samples_X, schedule='static'): + self.weighted_histogram_mode( + sample_index=idx, + k=deref(self.neigh_indices)[idx].size(), + indices=deref(self.neigh_indices)[idx].data(), + distances=deref(self.neigh_distances)[idx].data(), + ) + + return + + def _finalize_results(self): + if not self.outlier_label_exists and self.outliers.size() > 0: + raise ValueError( + "No neighbors found for test samples %r, " + "you can try using larger radius, " + "giving a label for outliers, " + "or considering removing them from your dataset." + % self.outliers + ) + + if self.outliers.size() > 0 and self.outlier_index < 0: + warnings.warn( + "Outlier label %s is not in training " + "classes. All class probabilities of " + "outliers will be assigned with 0." + % self.outlier_label + ) + + probabilities = np.asarray(self.class_scores) + normalizer = probabilities.sum(axis=1, keepdims=True) + normalizer[normalizer == 0.0] = 1.0 + probabilities /= normalizer + return probabilities + +{{endfor}} diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 7f6242ab29001..d802fcb56f305 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -15,7 +15,10 @@ from sklearn.neighbors._base import _check_precomputed from ..base import ClassifierMixin, _fit_context -from ..metrics._pairwise_distances_reduction import ArgKminClassMode +from ..metrics._pairwise_distances_reduction import ( + ArgKminClassMode, + RadiusNeighborsClassMode, +) from ..utils._param_validation import StrOptions from ..utils.extmath import weighted_mode from ..utils.fixes import _mode @@ -707,75 +710,98 @@ def predict_proba(self, X): The class probabilities of the input samples. Classes are ordered by lexicographic order. 
""" - + check_is_fitted(self, "_fit_method") n_queries = _num_samples(X) - neigh_dist, neigh_ind = self.radius_neighbors(X) - outlier_mask = np.zeros(n_queries, dtype=bool) - outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind] - outliers = np.flatnonzero(outlier_mask) - inliers = np.flatnonzero(~outlier_mask) - - classes_ = self.classes_ - _y = self._y - if not self.outputs_2d_: - _y = self._y.reshape((-1, 1)) - classes_ = [self.classes_] + metric, metric_kwargs = _adjusted_metric( + metric=self.metric, metric_kwargs=self.metric_params, p=self.p + ) - if self.outlier_label_ is None and outliers.size > 0: - raise ValueError( - "No neighbors found for test samples %r, " - "you can try using larger radius, " - "giving a label for outliers, " - "or considering removing them from your dataset." % outliers + if ( + self.weights == "uniform" + and self._fit_method == "brute" + and not self.outputs_2d_ + and RadiusNeighborsClassMode.is_usable_for(X, self._fit_X, metric) + ): + probabilities = RadiusNeighborsClassMode.compute( + X=X, + Y=self._fit_X, + radius=self.radius, + weights=self.weights, + labels=self._y, + unique_labels=self.classes_, + outlier_label=self.outlier_label, + metric=metric, + metric_kwargs=metric_kwargs, + strategy="parallel_on_X", ) + else: + neigh_dist, neigh_ind = self.radius_neighbors(X) + outlier_mask = np.zeros(n_queries, dtype=bool) + outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind] + outliers = np.flatnonzero(outlier_mask) + inliers = np.flatnonzero(~outlier_mask) + + classes_ = self.classes_ + _y = self._y + if not self.outputs_2d_: + _y = self._y.reshape((-1, 1)) + classes_ = [self.classes_] + + if self.outlier_label_ is None and outliers.size > 0: + raise ValueError( + "No neighbors found for test samples %r, " + "you can try using larger radius, " + "giving a label for outliers, " + "or considering removing them from your dataset." % outliers + ) - weights = _get_weights(neigh_dist, self.weights) - if weights is not None: - weights = weights[inliers] - - probabilities = [] - # iterate over multi-output, measure probabilities of the k-th output. - for k, classes_k in enumerate(classes_): - pred_labels = np.zeros(len(neigh_ind), dtype=object) - pred_labels[:] = [_y[ind, k] for ind in neigh_ind] + weights = _get_weights(neigh_dist, self.weights) + if weights is not None: + weights = weights[inliers] - proba_k = np.zeros((n_queries, classes_k.size)) - proba_inl = np.zeros((len(inliers), classes_k.size)) + probabilities = [] + # iterate over multi-output, measure probabilities of the k-th output. 
+ for k, classes_k in enumerate(classes_): + pred_labels = np.zeros(len(neigh_ind), dtype=object) + pred_labels[:] = [_y[ind, k] for ind in neigh_ind] - # samples have different size of neighbors within the same radius - if weights is None: - for i, idx in enumerate(pred_labels[inliers]): - proba_inl[i, :] = np.bincount(idx, minlength=classes_k.size) - else: - for i, idx in enumerate(pred_labels[inliers]): - proba_inl[i, :] = np.bincount( - idx, weights[i], minlength=classes_k.size - ) - proba_k[inliers, :] = proba_inl + proba_k = np.zeros((n_queries, classes_k.size)) + proba_inl = np.zeros((len(inliers), classes_k.size)) - if outliers.size > 0: - _outlier_label = self.outlier_label_[k] - label_index = np.flatnonzero(classes_k == _outlier_label) - if label_index.size == 1: - proba_k[outliers, label_index[0]] = 1.0 + # samples have different size of neighbors within the same radius + if weights is None: + for i, idx in enumerate(pred_labels[inliers]): + proba_inl[i, :] = np.bincount(idx, minlength=classes_k.size) else: - warnings.warn( - "Outlier label {} is not in training " - "classes. All class probabilities of " - "outliers will be assigned with 0." - "".format(self.outlier_label_[k]) - ) - - # normalize 'votes' into real [0,1] probabilities - normalizer = proba_k.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba_k /= normalizer - - probabilities.append(proba_k) - - if not self.outputs_2d_: - probabilities = probabilities[0] + for i, idx in enumerate(pred_labels[inliers]): + proba_inl[i, :] = np.bincount( + idx, weights[i], minlength=classes_k.size + ) + proba_k[inliers, :] = proba_inl + + if outliers.size > 0: + _outlier_label = self.outlier_label_[k] + label_index = np.flatnonzero(classes_k == _outlier_label) + if label_index.size == 1: + proba_k[outliers, label_index[0]] = 1.0 + else: + warnings.warn( + "Outlier label {} is not in training " + "classes. All class probabilities of " + "outliers will be assigned with 0." 
+ "".format(self.outlier_label_[k]) + ) + + # normalize 'votes' into real [0,1] probabilities + normalizer = proba_k.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + proba_k /= normalizer + + probabilities.append(proba_k) + + if not self.outputs_2d_: + probabilities = probabilities[0] return probabilities From 0ad31271c8fcbf6d2df2591eab8f193a117aabf8 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 13 Jul 2023 15:11:04 +0500 Subject: [PATCH 02/10] FIX index issue in _parallel_on_X_prange_iter_finalize --- .../_radius_neighbors_classmode.pyx.tp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index 1b0e20caf02be..1fb1c83a9fc9a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -159,12 +159,11 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} intp_t X_end, ) noexcept nogil: cdef: - intp_t idx, sample_index + intp_t idx for idx in range(X_start, X_end): - sample_index = X_start + idx self.weighted_histogram_mode( - sample_index=sample_index, + sample_index=idx, k=deref(self.neigh_indices)[idx].size(), indices=deref(self.neigh_indices)[idx].data(), distances=deref(self.neigh_distances)[idx].data(), From c46e590ea652c7331e2b29ea6fff9d77b2d8f3e8 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 14 Jul 2023 14:25:03 +0500 Subject: [PATCH 03/10] Add tests for wrong method usages --- .../test_pairwise_distances_reduction.py | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 5fcf980fbe39b..cd53426732a53 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -16,6 +16,7 @@ ArgKminClassMode, BaseDistancesReductionDispatcher, RadiusNeighbors, + RadiusNeighborsClassMode, sqeuclidean_row_norms, ) from sklearn.utils._testing import ( @@ -851,6 +852,116 @@ def test_radius_neighbors_factory_method_wrong_usages(): ) +def test_radius_neighbors_classmode_factory_method_wrong_usages(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + radius = 5 + metric = "manhattan" + weights = "uniform" + labels = rng.randint(low=0, high=10, size=100) + unique_labels = np.unique(labels) + + msg = ( + "Only float64 or float32 datasets pairs are supported at this time, " + "got: X.dtype=float32 and Y.dtype=float64" + ) + with pytest.raises(ValueError, match=msg): + RadiusNeighborsClassMode.compute( + X=X.astype(np.float32), + Y=Y, + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + msg = ( + "Only float64 or float32 datasets pairs are supported at this time, " + "got: X.dtype=float64 and Y.dtype=int32" + ) + with pytest.raises(ValueError, match=msg): + RadiusNeighborsClassMode.compute( + X=X, + Y=Y.astype(np.int32), + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + with pytest.raises(ValueError, match="radius == -1.0, must be >= 0."): + RadiusNeighborsClassMode.compute( + X=X, + Y=Y, + radius=-1, + metric=metric, + weights=weights, + 
labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + with pytest.raises(ValueError, match="Unrecognized metric"): + RadiusNeighborsClassMode.compute( + X=X, + Y=Y, + radius=-1, + metric="wrong_metric", + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + with pytest.raises( + ValueError, match=r"Buffer has wrong number of dimensions \(expected 2, got 1\)" + ): + RadiusNeighborsClassMode.compute( + X=np.array([1.0, 2.0]), + Y=Y, + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + with pytest.raises(ValueError, match="ndarray is not C-contiguous"): + RadiusNeighborsClassMode.compute( + X=np.asfortranarray(X), + Y=Y, + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + non_existent_weights_strategy = "non_existent_weights_strategy" + msg = ( + "Only the 'uniform' or 'distance' weights options are supported at this time. " + f"Got: weights='{non_existent_weights_strategy}'." + ) + with pytest.raises(ValueError, match=msg): + RadiusNeighborsClassMode.compute( + X=X, + Y=Y, + radius=radius, + metric="wrong_metric", + weights=non_existent_weights_strategy, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + ) + + @pytest.mark.parametrize( "n_samples_X, n_samples_Y", [(100, 100), (500, 100), (100, 500)] ) From 04dd4b26d5dbf7e72386d86e79ba2a5e4c674046 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 14 Jul 2023 17:14:05 +0500 Subject: [PATCH 04/10] * Use a mem view and numpy array to handle outliers * Add a test to check results are consistent with both strategies --- .../_radius_neighbors_classmode.pyx.tp | 17 +++++---- .../test_pairwise_distances_reduction.py | 35 +++++++++++++++++++ 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index 1fb1c83a9fc9a..33ba222a8a1c0 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -3,7 +3,6 @@ import warnings from cython cimport floating, final, integral from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange -from libcpp.vector cimport vector from ...utils._typedefs cimport intp_t, float64_t @@ -30,10 +29,11 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} const intp_t[:] unique_labels intp_t outlier_index bint outlier_label_exists - vector[intp_t] outliers + bint outliers_exist + unsigned char[::1] outliers + object outlier_label float64_t[:, :] class_scores WeightingStrategy weight_type - object outlier_label @classmethod def compute( @@ -102,7 +102,9 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} self.unique_labels = unique_labels self.outlier_index = -1 self.outlier_label_exists = False + self.outliers_exist = False self.outlier_label = outlier_label + self.outliers = np.zeros(self.n_samples_X, dtype=np.bool_) cdef intp_t idx if outlier_label is not None: @@ -133,7 +135,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} ) if k == 0: - self.outliers.push_back(sample_index) + self.outliers_exist = True + self.outliers[sample_index] = True if self.outlier_index >= 
0: self.class_scores[sample_index][self.outlier_index] = score_incr @@ -194,16 +197,16 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} return def _finalize_results(self): - if not self.outlier_label_exists and self.outliers.size() > 0: + if self.outliers_exist and not self.outlier_label_exists: raise ValueError( "No neighbors found for test samples %r, " "you can try using larger radius, " "giving a label for outliers, " "or considering removing them from your dataset." - % self.outliers + % np.where(self.outliers)[0].tolist() ) - if self.outliers.size() > 0 and self.outlier_index < 0: + if self.outliers_exist and self.outlier_index < 0: warnings.warn( "Outlier label %s is not in training " "classes. All class probabilities of " diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index cd53426732a53..91327adbbbb2e 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1466,3 +1466,38 @@ def test_argkmin_classmode_strategy_consistent(): strategy="parallel_on_Y", ) assert_array_equal(results_X, results_Y) + + +def test_radius_neighbors_classmode_strategy_consistent(): + rng = np.random.RandomState(1) + X = rng.rand(100, 10) + Y = rng.rand(100, 10) + radius = 5 + metric = "manhattan" + + weights = "uniform" + labels = rng.randint(low=0, high=10, size=100) + unique_labels = np.unique(labels) + results_X = RadiusNeighborsClassMode.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + strategy="parallel_on_X", + ) + results_Y = RadiusNeighborsClassMode.compute( + X=X, + Y=Y, + radius=radius, + metric=metric, + weights=weights, + labels=labels, + unique_labels=unique_labels, + outlier_label=None, + strategy="parallel_on_Y", + ) + assert_allclose(results_X, results_Y) From 495e97ed8875abc8aeba5b079ab1b23072f4acbf Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 14 Jul 2023 17:26:48 +0500 Subject: [PATCH 05/10] Remove tolist() when printing outliers in error --- .../_radius_neighbors_classmode.pyx.tp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index 33ba222a8a1c0..e6bbbb9ebc1e8 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -203,7 +203,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} "you can try using larger radius, " "giving a label for outliers, " "or considering removing them from your dataset." - % np.where(self.outliers)[0].tolist() + % np.where(self.outliers)[0] ) if self.outliers_exist and self.outlier_index < 0: From 65d8b966277040d612c0a96801f574d65e210c03 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Thu, 20 Jul 2023 14:53:22 +0500 Subject: [PATCH 06/10] Add changelog --- doc/whats_new/v1.4.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c2b7d19404af9..a7d373b5814c8 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -120,6 +120,14 @@ Changelog object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin Jalali`_. 
+:mod:`sklearn.neighbors` +........................ + +- |Enhancement| The performance of :meth:`neighbors.RadiusNeighborsClassifier.predict` + and of :meth:`neighbors.RadiusNeighborsClassifier.predict_proba` has been improved + when `radius` is large and `algorithm="brute"` with non Euclidean metrics. + :pr:`26828` by :user:`Omar Salman `. + :mod:`sklearn.tree` ................... From 31548431518d0e09207870fc37647afd786a187b Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 7 Aug 2023 15:18:46 +0500 Subject: [PATCH 07/10] Applied PR suggestions --- doc/whats_new/v1.4.rst | 4 +- .../_pairwise_distances_reduction/__init__.py | 2 +- .../_dispatcher.py | 58 +++++--- .../_radius_neighbors_classmode.pyx.tp | 46 +++--- .../test_pairwise_distances_reduction.py | 44 +++--- sklearn/neighbors/_classification.py | 139 ++++++++++-------- 6 files changed, 161 insertions(+), 132 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index a7d373b5814c8..d5ecf51db2066 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -123,9 +123,9 @@ Changelog :mod:`sklearn.neighbors` ........................ -- |Enhancement| The performance of :meth:`neighbors.RadiusNeighborsClassifier.predict` +- |Efficiency| The performance of :meth:`neighbors.RadiusNeighborsClassifier.predict` and of :meth:`neighbors.RadiusNeighborsClassifier.predict_proba` has been improved - when `radius` is large and `algorithm="brute"` with non Euclidean metrics. + when `radius` is large and `algorithm="brute"` with non-Euclidean metrics. :pr:`26828` by :user:`Omar Salman `. :mod:`sklearn.tree` diff --git a/sklearn/metrics/_pairwise_distances_reduction/__init__.py b/sklearn/metrics/_pairwise_distances_reduction/__init__.py index 9352cab82652a..5c9366945e8cc 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/__init__.py +++ b/sklearn/metrics/_pairwise_distances_reduction/__init__.py @@ -99,7 +99,7 @@ "BaseDistancesReductionDispatcher", "ArgKmin", "RadiusNeighbors", - "RadiusNeighborsClassMode", "ArgKminClassMode", + "RadiusNeighborsClassMode", "sqeuclidean_row_norms", ] diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index c8659815c131e..26c46068164fc 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -620,8 +620,8 @@ def compute( class RadiusNeighborsClassMode(BaseDistancesReductionDispatcher): - """Compute radius-based neighbors of row vectors of X on the ones of - Y with labels. + """Compute radius-based class modes of row vectors of X using the + those of Y. For each row-vector X[i] of the queries X, find all the indices j of row-vectors in Y such that: @@ -638,6 +638,24 @@ class RadiusNeighborsClassMode(BaseDistancesReductionDispatcher): deallocation consistently. """ + @classmethod + def valid_metrics(cls) -> List[str]: + excluded = { + # PyFunc cannot be supported because it necessitates interacting with + # the CPython interpreter to call user defined functions. + "pyfunc", + "mahalanobis", # is numerically unstable + # In order to support discrete distance metrics, we need to have a + # stable simultaneous sort which preserves the order of the indices + # because there generally is a lot of occurrences for a given values + # of distances in this case. + # TODO: implement a stable simultaneous_sort. 
+ "hamming", + "euclidean", + *BOOL_METRICS, + } + return sorted(set(METRIC_MAPPING64.keys()) - excluded) + @classmethod def is_usable_for(cls, X, Y, metric) -> bool: """Return True if the dispatcher can be used for the given parameters. @@ -648,7 +666,7 @@ def is_usable_for(cls, X, Y, metric) -> bool: The input array to be labelled. Y : ndarray of shape (n_samples_Y, n_features) - The input array whose labels are provided through the `labels` + The input array whose labels are provided through the `Y_labels` parameter. metric : str, default='euclidean' @@ -662,11 +680,7 @@ def is_usable_for(cls, X, Y, metric) -> bool: """ return ( RadiusNeighbors.is_usable_for(X, Y, metric) - # TODO: Support CSR matrices. - and not issparse(X) - and not issparse(Y) - # TODO: implement Euclidean specialization with GEMM. - and metric not in ("euclidean", "sqeuclidean") + and metric in cls.valid_metrics() ) @classmethod @@ -676,8 +690,8 @@ def compute( Y, radius, weights, - labels, - unique_labels, + Y_labels, + unique_Y_labels, outlier_label, metric="euclidean", chunk_size=None, @@ -690,21 +704,25 @@ def compute( X : ndarray of shape (n_samples_X, n_features) The input array to be labelled. Y : ndarray of shape (n_samples_Y, n_features) - The input array whose labels are provided through the `labels` - parameter. + The input array whose class membership is provided through + the `Y_labels` parameter. radius : float The radius defining the neighborhood. weights : ndarray - The weights applied over the `labels` of `Y` when computing the + The weights applied to the `Y_labels` when computing the weighted mode of the labels. - labels : ndarray + Y_labels : ndarray An array containing the index of the class membership of the associated samples in `Y`. This is used in labeling `X`. - unique_classes : ndarray + unique_Y_labels : ndarray An array containing all unique class labels. outlier_label : int, default=None Label for outlier samples (samples with no neighbors in given - radius). + radius). In the default case when the value is None if any + outlier is detected, a ValueError will be raised. The outlier + label should be selected from among the unique 'Y' labels. If + it is specified with a different value a warning will be raised + and all class probabilities of outliers will be assigned to be 0. metric : str, default='euclidean' The distance metric to use. For a list of available metrics, see the documentation of :class:`~sklearn.metrics.DistanceMetric`. 
@@ -755,8 +773,8 @@ def compute( Y=Y, radius=radius, weights=weights, - class_membership=np.array(labels, dtype=np.intp), - unique_labels=np.array(unique_labels, dtype=np.intp), + Y_labels=np.array(Y_labels, dtype=np.intp), + unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp), outlier_label=outlier_label, metric=metric, chunk_size=chunk_size, @@ -770,8 +788,8 @@ def compute( Y=Y, radius=radius, weights=weights, - class_membership=np.array(labels, dtype=np.intp), - unique_labels=np.array(unique_labels, dtype=np.intp), + Y_labels=np.array(Y_labels, dtype=np.intp), + unique_Y_labels=np.array(unique_Y_labels, dtype=np.intp), outlier_label=outlier_label, metric=metric, chunk_size=chunk_size, diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index e6bbbb9ebc1e8..19f560f7f8e7a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -25,9 +25,9 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} {{name_suffix}}bit implementation of RadiusNeighborsClassMode. """ cdef: - const intp_t[:] class_membership - const intp_t[:] unique_labels - intp_t outlier_index + const intp_t[:] Y_labels + const intp_t[:] unique_Y_labels + intp_t outlier_label_index bint outlier_label_exists bint outliers_exist unsigned char[::1] outliers @@ -42,8 +42,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} Y, float64_t radius, weights, - class_membership, - unique_labels, + Y_labels, + unique_Y_labels, outlier_label=None, str metric="euclidean", chunk_size=None, @@ -58,8 +58,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} chunk_size=chunk_size, strategy=strategy, weights=weights, - class_membership=class_membership, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=outlier_label, ) @@ -76,8 +76,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} def __init__( self, DatasetsPair{{name_suffix}} datasets_pair, - const intp_t[:] class_membership, - const intp_t[:] unique_labels, + const intp_t[:] Y_labels, + const intp_t[:] unique_Y_labels, float64_t radius, chunk_size=None, strategy=None, @@ -98,25 +98,24 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} else: self.weight_type = WeightingStrategy.callable - self.class_membership = class_membership - self.unique_labels = unique_labels - self.outlier_index = -1 - self.outlier_label_exists = False + self.Y_labels = Y_labels + self.unique_Y_labels = unique_Y_labels + self.outlier_label_index = -1 self.outliers_exist = False self.outlier_label = outlier_label + self.outlier_label_exists = outlier_label is not None self.outliers = np.zeros(self.n_samples_X, dtype=np.bool_) cdef intp_t idx - if outlier_label is not None: - self.outlier_label_exists = True - for idx in range(self.unique_labels.shape[0]): - if self.unique_labels[idx] == outlier_label: - self.outlier_index = idx + if self.outlier_label_exists: + for idx in range(self.unique_Y_labels.shape[0]): + if self.unique_Y_labels[idx] == outlier_label: + self.outlier_label_index = idx # Map from set of unique labels to their indices in `class_scores` # Buffer used in building a histogram for one-pass weighted mode self.class_scores = np.zeros( - 
(self.n_samples_X, unique_labels.shape[0]), dtype=np.float64, + (self.n_samples_X, unique_Y_labels.shape[0]), dtype=np.float64, ) @@ -137,8 +136,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} if k == 0: self.outliers_exist = True self.outliers[sample_index] = True - if self.outlier_index >= 0: - self.class_scores[sample_index][self.outlier_index] = score_incr + if self.outlier_label_index >= 0: + self.class_scores[sample_index][self.outlier_label_index] = score_incr return @@ -149,7 +148,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} score_incr = 1 / distances[neighbor_rank] neighbor_idx = indices[neighbor_rank] - neighbor_class_idx = self.class_membership[neighbor_idx] + neighbor_class_idx = self.Y_labels[neighbor_idx] self.class_scores[sample_index][neighbor_class_idx] += score_incr return @@ -174,6 +173,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} return + @final cdef void _parallel_on_Y_finalize( self, ) noexcept nogil: @@ -206,7 +206,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} % np.where(self.outliers)[0] ) - if self.outliers_exist and self.outlier_index < 0: + if self.outliers_exist and self.outlier_label_index < 0: warnings.warn( "Outlier label %s is not in training " "classes. All class probabilities of " diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 91327adbbbb2e..fdf820fa53169 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -859,8 +859,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius = 5 metric = "manhattan" weights = "uniform" - labels = rng.randint(low=0, high=10, size=100) - unique_labels = np.unique(labels) + Y_labels = rng.randint(low=0, high=10, size=100) + unique_Y_labels = np.unique(Y_labels) msg = ( "Only float64 or float32 datasets pairs are supported at this time, " @@ -873,8 +873,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=radius, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -889,8 +889,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=radius, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -901,8 +901,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=-1, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -913,8 +913,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=-1, metric="wrong_metric", weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -927,8 +927,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=radius, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -939,8 +939,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=radius, metric=metric, weights=weights, - labels=labels, - 
unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -956,8 +956,8 @@ def test_radius_neighbors_classmode_factory_method_wrong_usages(): radius=radius, metric="wrong_metric", weights=non_existent_weights_strategy, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, ) @@ -1476,16 +1476,16 @@ def test_radius_neighbors_classmode_strategy_consistent(): metric = "manhattan" weights = "uniform" - labels = rng.randint(low=0, high=10, size=100) - unique_labels = np.unique(labels) + Y_labels = rng.randint(low=0, high=10, size=100) + unique_Y_labels = np.unique(Y_labels) results_X = RadiusNeighborsClassMode.compute( X=X, Y=Y, radius=radius, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, strategy="parallel_on_X", ) @@ -1495,8 +1495,8 @@ def test_radius_neighbors_classmode_strategy_consistent(): radius=radius, metric=metric, weights=weights, - labels=labels, - unique_labels=unique_labels, + Y_labels=Y_labels, + unique_Y_labels=unique_Y_labels, outlier_label=None, strategy="parallel_on_Y", ) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index 3851e01302279..2bbe5706ec548 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -467,6 +467,10 @@ class RadiusNeighborsClassifier(RadiusNeighborsMixin, ClassifierMixin, Neighbors - 'most_frequent' : assign the most frequent label of y to outliers. - None : when any outlier is detected, ValueError will be raised. + The outlier label should be selected from among the unique 'Y' labels. + If it is specified with a different value a warning will be raised and + all class probabilities of outliers will be assigned to be 0. + metric_params : dict, default=None Additional keyword arguments for the metric function. @@ -728,80 +732,87 @@ def predict_proba(self, X): Y=self._fit_X, radius=self.radius, weights=self.weights, - labels=self._y, - unique_labels=self.classes_, + Y_labels=self._y, + unique_Y_labels=self.classes_, outlier_label=self.outlier_label, metric=metric, metric_kwargs=metric_kwargs, strategy="parallel_on_X", + # `strategy="parallel_on_X"` has in practice be shown + # to be more efficient than `strategy="parallel_on_Y`` + # on many combination of datasets. + # Hence, we choose to enforce it here. + # For more information, see: + # https://github.com/scikit-learn/scikit-learn/pull/24076#issuecomment-1445258342 # noqa ) - else: - neigh_dist, neigh_ind = self.radius_neighbors(X) - outlier_mask = np.zeros(n_queries, dtype=bool) - outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind] - outliers = np.flatnonzero(outlier_mask) - inliers = np.flatnonzero(~outlier_mask) - - classes_ = self.classes_ - _y = self._y - if not self.outputs_2d_: - _y = self._y.reshape((-1, 1)) - classes_ = [self.classes_] - - if self.outlier_label_ is None and outliers.size > 0: - raise ValueError( - "No neighbors found for test samples %r, " - "you can try using larger radius, " - "giving a label for outliers, " - "or considering removing them from your dataset." 
% outliers - ) + return probabilities - weights = _get_weights(neigh_dist, self.weights) - if weights is not None: - weights = weights[inliers] + neigh_dist, neigh_ind = self.radius_neighbors(X) + outlier_mask = np.zeros(n_queries, dtype=bool) + outlier_mask[:] = [len(nind) == 0 for nind in neigh_ind] + outliers = np.flatnonzero(outlier_mask) + inliers = np.flatnonzero(~outlier_mask) - probabilities = [] - # iterate over multi-output, measure probabilities of the k-th output. - for k, classes_k in enumerate(classes_): - pred_labels = np.zeros(len(neigh_ind), dtype=object) - pred_labels[:] = [_y[ind, k] for ind in neigh_ind] + classes_ = self.classes_ + _y = self._y + if not self.outputs_2d_: + _y = self._y.reshape((-1, 1)) + classes_ = [self.classes_] + + if self.outlier_label_ is None and outliers.size > 0: + raise ValueError( + "No neighbors found for test samples %r, " + "you can try using larger radius, " + "giving a label for outliers, " + "or considering removing them from your dataset." % outliers + ) + + weights = _get_weights(neigh_dist, self.weights) + if weights is not None: + weights = weights[inliers] + + probabilities = [] + # iterate over multi-output, measure probabilities of the k-th output. + for k, classes_k in enumerate(classes_): + pred_labels = np.zeros(len(neigh_ind), dtype=object) + pred_labels[:] = [_y[ind, k] for ind in neigh_ind] - proba_k = np.zeros((n_queries, classes_k.size)) - proba_inl = np.zeros((len(inliers), classes_k.size)) + proba_k = np.zeros((n_queries, classes_k.size)) + proba_inl = np.zeros((len(inliers), classes_k.size)) + + # samples have different size of neighbors within the same radius + if weights is None: + for i, idx in enumerate(pred_labels[inliers]): + proba_inl[i, :] = np.bincount(idx, minlength=classes_k.size) + else: + for i, idx in enumerate(pred_labels[inliers]): + proba_inl[i, :] = np.bincount( + idx, weights[i], minlength=classes_k.size + ) + proba_k[inliers, :] = proba_inl - # samples have different size of neighbors within the same radius - if weights is None: - for i, idx in enumerate(pred_labels[inliers]): - proba_inl[i, :] = np.bincount(idx, minlength=classes_k.size) + if outliers.size > 0: + _outlier_label = self.outlier_label_[k] + label_index = np.flatnonzero(classes_k == _outlier_label) + if label_index.size == 1: + proba_k[outliers, label_index[0]] = 1.0 else: - for i, idx in enumerate(pred_labels[inliers]): - proba_inl[i, :] = np.bincount( - idx, weights[i], minlength=classes_k.size - ) - proba_k[inliers, :] = proba_inl - - if outliers.size > 0: - _outlier_label = self.outlier_label_[k] - label_index = np.flatnonzero(classes_k == _outlier_label) - if label_index.size == 1: - proba_k[outliers, label_index[0]] = 1.0 - else: - warnings.warn( - "Outlier label {} is not in training " - "classes. All class probabilities of " - "outliers will be assigned with 0." - "".format(self.outlier_label_[k]) - ) - - # normalize 'votes' into real [0,1] probabilities - normalizer = proba_k.sum(axis=1)[:, np.newaxis] - normalizer[normalizer == 0.0] = 1.0 - proba_k /= normalizer - - probabilities.append(proba_k) - - if not self.outputs_2d_: - probabilities = probabilities[0] + warnings.warn( + "Outlier label {} is not in training " + "classes. All class probabilities of " + "outliers will be assigned with 0." 
+ "".format(self.outlier_label_[k]) + ) + + # normalize 'votes' into real [0,1] probabilities + normalizer = proba_k.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + proba_k /= normalizer + + probabilities.append(proba_k) + + if not self.outputs_2d_: + probabilities = probabilities[0] return probabilities From 32026e7769919b1e0c8897abc9b9d34a930625a6 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Wed, 9 Aug 2023 08:49:25 +0500 Subject: [PATCH 08/10] Import WeightingStrategy from the common _classmode --- .../_radius_neighbors_classmode.pyx.tp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index 19f560f7f8e7a..81af8b187c028 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -9,11 +9,7 @@ from ...utils._typedefs cimport intp_t, float64_t import numpy as np from scipy.sparse import issparse from sklearn.utils.fixes import threadpool_limits - -cpdef enum WeightingStrategy: - uniform = 0 - distance = 1 - callable = 2 +from ._classmode cimport WeightingStrategy {{for name_suffix in ["32", "64"]}} From daf0443916e05da38ae378933c8af0faa3f70102 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 11 Aug 2023 12:10:34 +0500 Subject: [PATCH 09/10] Address PR suggestions --- .../_dispatcher.py | 44 +++---------------- .../_radius_neighbors_classmode.pyx.tp | 15 +++---- .../test_pairwise_distances_reduction.py | 7 +-- sklearn/neighbors/_classification.py | 2 +- 4 files changed, 17 insertions(+), 51 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py index 12a72dbfc8bc3..8efed77496f24 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py +++ b/sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py @@ -645,47 +645,13 @@ class RadiusNeighborsClassMode(BaseDistancesReductionDispatcher): @classmethod def valid_metrics(cls) -> List[str]: excluded = { - # PyFunc cannot be supported because it necessitates interacting with - # the CPython interpreter to call user defined functions. - "pyfunc", - "mahalanobis", # is numerically unstable - # In order to support discrete distance metrics, we need to have a - # stable simultaneous sort which preserves the order of the indices - # because there generally is a lot of occurrences for a given values - # of distances in this case. - # TODO: implement a stable simultaneous_sort. - "hamming", + # Euclidean is technically usable for RadiusNeighborsClassMode + # but it would not be competitive. + # TODO: implement Euclidean specialization using GEMM. "euclidean", - *BOOL_METRICS, + "sqeuclidean", } - return sorted(set(METRIC_MAPPING64.keys()) - excluded) - - @classmethod - def is_usable_for(cls, X, Y, metric) -> bool: - """Return True if the dispatcher can be used for the given parameters. - - Parameters - ---------- - X : ndarray of shape (n_samples_X, n_features) - The input array to be labelled. - - Y : ndarray of shape (n_samples_Y, n_features) - The input array whose labels are provided through the `Y_labels` - parameter. - - metric : str, default='euclidean' - The distance metric to use. For a list of available metrics, see - the documentation of :class:`~sklearn.metrics.DistanceMetric`. 
- Currently does not support `'precomputed'`. - - Returns - ------- - True if the PairwiseDistancesReduction can be used, else False. - """ - return ( - RadiusNeighbors.is_usable_for(X, Y, metric) - and metric in cls.valid_metrics() - ) + return sorted(set(BaseDistancesReductionDispatcher.valid_metrics()) - excluded) @classmethod def compute( diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index 81af8b187c028..eeb086d91733a 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -3,13 +3,12 @@ import warnings from cython cimport floating, final, integral from cython.operator cimport dereference as deref from cython.parallel cimport parallel, prange - +from ._classmode cimport WeightingStrategy from ...utils._typedefs cimport intp_t, float64_t import numpy as np from scipy.sparse import issparse -from sklearn.utils.fixes import threadpool_limits -from ._classmode cimport WeightingStrategy +from ...utils.fixes import threadpool_limits {{for name_suffix in ["32", "64"]}} @@ -118,7 +117,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} cdef inline void weighted_histogram_mode( self, intp_t sample_index, - intp_t k, + intp_t sample_n_neighbors, intp_t* indices, float64_t* distances, ) noexcept nogil: @@ -129,7 +128,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} self.weight_type == WeightingStrategy.distance ) - if k == 0: + if sample_n_neighbors == 0: self.outliers_exist = True self.outliers[sample_index] = True if self.outlier_label_index >= 0: @@ -139,7 +138,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} # Iterate over the neighbors. This can be different for # each of the samples as they are based on the radius. 
- for neighbor_rank in range(k): + for neighbor_rank in range(sample_n_neighbors): if use_distance_weighting: score_incr = 1 / distances[neighbor_rank] @@ -162,7 +161,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} for idx in range(X_start, X_end): self.weighted_histogram_mode( sample_index=idx, - k=deref(self.neigh_indices)[idx].size(), + sample_n_neighbors=deref(self.neigh_indices)[idx].size(), indices=deref(self.neigh_indices)[idx].data(), distances=deref(self.neigh_distances)[idx].data(), ) @@ -185,7 +184,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} for idx in prange(self.n_samples_X, schedule='static'): self.weighted_histogram_mode( sample_index=idx, - k=deref(self.neigh_indices)[idx].size(), + sample_n_neighbors=deref(self.neigh_indices)[idx].size(), indices=deref(self.neigh_indices)[idx].data(), distances=deref(self.neigh_distances)[idx].data(), ) diff --git a/sklearn/metrics/tests/test_pairwise_distances_reduction.py b/sklearn/metrics/tests/test_pairwise_distances_reduction.py index 32b776ebf49fe..de9783a70cc0d 100644 --- a/sklearn/metrics/tests/test_pairwise_distances_reduction.py +++ b/sklearn/metrics/tests/test_pairwise_distances_reduction.py @@ -1468,7 +1468,8 @@ def test_argkmin_classmode_strategy_consistent(): assert_array_equal(results_X, results_Y) -def test_radius_neighbors_classmode_strategy_consistent(): +@pytest.mark.parametrize("outlier_label", [None, 0, 3, 6, 9]) +def test_radius_neighbors_classmode_strategy_consistent(outlier_label): rng = np.random.RandomState(1) X = rng.rand(100, 10) Y = rng.rand(100, 10) @@ -1486,7 +1487,7 @@ def test_radius_neighbors_classmode_strategy_consistent(): weights=weights, Y_labels=Y_labels, unique_Y_labels=unique_Y_labels, - outlier_label=None, + outlier_label=outlier_label, strategy="parallel_on_X", ) results_Y = RadiusNeighborsClassMode.compute( @@ -1497,7 +1498,7 @@ def test_radius_neighbors_classmode_strategy_consistent(): weights=weights, Y_labels=Y_labels, unique_Y_labels=unique_Y_labels, - outlier_label=None, + outlier_label=outlier_label, strategy="parallel_on_Y", ) assert_allclose(results_X, results_Y) diff --git a/sklearn/neighbors/_classification.py b/sklearn/neighbors/_classification.py index d09f54ed84e33..34a1b80e17862 100644 --- a/sklearn/neighbors/_classification.py +++ b/sklearn/neighbors/_classification.py @@ -743,7 +743,7 @@ def predict_proba(self, X): # on many combination of datasets. # Hence, we choose to enforce it here. 
# For more information, see: - # https://github.com/scikit-learn/scikit-learn/pull/24076#issuecomment-1445258342 # noqa + # https://github.com/scikit-learn/scikit-learn/pull/26828/files#r1282398471 # noqa ) return probabilities From a1e722d2ce05bb055b7c89377f8e11917dcf1b2b Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Mon, 25 Sep 2023 13:06:53 +0500 Subject: [PATCH 10/10] Updates: PR suggestions --- .../_radius_neighbors_classmode.pyx.tp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp index eeb086d91733a..25067b43cd20c 100644 --- a/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp +++ b/sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp @@ -20,14 +20,14 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} {{name_suffix}}bit implementation of RadiusNeighborsClassMode. """ cdef: - const intp_t[:] Y_labels - const intp_t[:] unique_Y_labels + const intp_t[::1] Y_labels + const intp_t[::1] unique_Y_labels intp_t outlier_label_index bint outlier_label_exists bint outliers_exist unsigned char[::1] outliers object outlier_label - float64_t[:, :] class_scores + float64_t[:, ::1] class_scores WeightingStrategy weight_type @classmethod @@ -71,8 +71,8 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} def __init__( self, DatasetsPair{{name_suffix}} datasets_pair, - const intp_t[:] Y_labels, - const intp_t[:] unique_Y_labels, + const intp_t[::1] Y_labels, + const intp_t[::1] unique_Y_labels, float64_t radius, chunk_size=None, strategy=None, @@ -98,11 +98,10 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} self.outlier_label_index = -1 self.outliers_exist = False self.outlier_label = outlier_label - self.outlier_label_exists = outlier_label is not None self.outliers = np.zeros(self.n_samples_X, dtype=np.bool_) cdef intp_t idx - if self.outlier_label_exists: + if self.outlier_label is not None: for idx in range(self.unique_Y_labels.shape[0]): if self.unique_Y_labels[idx] == outlier_label: self.outlier_label_index = idx @@ -192,7 +191,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix} return def _finalize_results(self): - if self.outliers_exist and not self.outlier_label_exists: + if self.outliers_exist and self.outlier_label is None: raise ValueError( "No neighbors found for test samples %r, " "you can try using larger radius, "
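
The new backend can be exercised end-to-end with a short script, similar in spirit to the consistency tests added in this series. The sketch below is illustrative only: it assumes a scikit-learn build with these patches applied, `RadiusNeighborsClassMode` is a private dispatcher (not public API), and the dataset, radius and metric values are arbitrary choices rather than part of the series::

    import numpy as np
    from sklearn.metrics._pairwise_distances_reduction import RadiusNeighborsClassMode
    from sklearn.neighbors import RadiusNeighborsClassifier

    rng = np.random.RandomState(0)
    X_train = rng.rand(200, 10)
    X_test = rng.rand(50, 10)
    y_train = rng.randint(low=0, high=3, size=200)

    # Direct call to the private dispatcher added in this series.
    proba = RadiusNeighborsClassMode.compute(
        X=X_test,
        Y=X_train,
        radius=5.0,
        metric="manhattan",  # non-Euclidean metrics take the new backend
        weights="uniform",
        Y_labels=y_train,
        unique_Y_labels=np.unique(y_train),
        outlier_label=None,
        strategy="parallel_on_X",
    )

    # For brute-force, uniform-weight, single-output problems the public
    # estimator dispatches to the same backend, so the probabilities agree.
    clf = RadiusNeighborsClassifier(
        radius=5.0, algorithm="brute", metric="manhattan", weights="uniform"
    ).fit(X_train, y_train)
    np.testing.assert_allclose(proba, clf.predict_proba(X_test))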