MAINT Make `ArgKminClassMode` accept sparse datasets (#27018) · scikit-learn/scikit-learn@f5c4999
Commit f5c4999

jjerphan and Micky774 authored
MAINT Make ArgKminClassMode accept sparse datasets (#27018)
Co-authored-by: Meekail Zain <Micky774@users.noreply.github.com>
1 parent 06a6b93 commit f5c4999

2 files changed: +14 -29 lines changed


doc/whats_new/v1.4.rst

Lines changed: 5 additions & 0 deletions
@@ -170,6 +170,11 @@ Changelog
 :mod:`sklearn.neighbors`
 ........................
 
+- |Efficiency| :meth:`sklearn.neighbors.KNeighborsRegressor.predict` and
+  :meth:`sklearn.neighbors.KNeighborsRegressor.predict_proba` now efficiently support
+  pairs of dense and sparse datasets.
+  :pr:`27018` by :user:`Julien Jerphanion <jjerphan>`.
+
 - |Fix| Neighbors based estimators now correctly work when `metric="minkowski"` and the
   metric parameter `p` is in the range `0 < p < 1`, regardless of the `dtype` of `X`.
   :pr:`26760` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.
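
To illustrate what "pairs of dense and sparse datasets" means in practice, here is a minimal usage sketch. It assumes scikit-learn, SciPy and NumPy are installed, and it uses `KNeighborsClassifier` because `predict_proba` is a classifier method, whereas the changelog entry names `KNeighborsRegressor`. The data, the neighbor count and the `"manhattan"` metric are arbitrary choices for illustration. Whether the `ArgKminClassMode` code path is actually taken depends on the installed version and on the inputs, so the sketch only checks that a sparse query against a dense training set gives the same probabilities as a dense query.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.neighbors import KNeighborsClassifier

# Arbitrary toy data: dense training set, small query set.
rng = np.random.RandomState(0)
X_train = rng.rand(50, 4)
y_train = rng.randint(0, 3, size=50)
X_query = rng.rand(5, 4)

# A non-Euclidean metric is used here because the commit keeps
# "euclidean" and "sqeuclidean" out of ArgKminClassMode.
clf = KNeighborsClassifier(n_neighbors=5, algorithm="brute", metric="manhattan")
clf.fit(X_train, y_train)

# Same query points, once as a dense array and once as a CSR matrix.
proba_dense = clf.predict_proba(X_query)
proba_sparse = clf.predict_proba(csr_matrix(X_query))
```

Both calls should return one row of class probabilities per query sample, whichever internal backend handles the request.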

sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py

Lines changed: 9 additions & 29 deletions
@@ -451,35 +451,15 @@ class ArgKminClassMode(BaseDistancesReductionDispatcher):
     """
 
     @classmethod
-    def is_usable_for(cls, X, Y, metric) -> bool:
-        """Return True if the dispatcher can be used for the given parameters.
-
-        Parameters
-        ----------
-        X : ndarray of shape (n_samples_X, n_features)
-            The input array to be labelled.
-
-        Y : ndarray of shape (n_samples_Y, n_features)
-            The input array whose labels are provided through the `Y_labels`
-            parameter.
-
-        metric : str, default='euclidean'
-            The distance metric to use. For a list of available metrics, see
-            the documentation of :class:`~sklearn.metrics.DistanceMetric`.
-            Currently does not support `'precomputed'`.
-
-        Returns
-        -------
-        True if the PairwiseDistancesReduction can be used, else False.
-        """
-        return (
-            ArgKmin.is_usable_for(X, Y, metric)
-            # TODO: Support CSR matrices.
-            and not issparse(X)
-            and not issparse(Y)
-            # TODO: implement Euclidean specialization with GEMM.
-            and metric not in ("euclidean", "sqeuclidean")
-        )
+    def valid_metrics(cls) -> List[str]:
+        excluded = {
+            # Euclidean is technically usable for ArgKminClassMode
+            # but its current implementation would not be competitive.
+            # TODO: implement Euclidean specialization using GEMM.
+            "euclidean",
+            "sqeuclidean",
+        }
+        return list(set(BaseDistancesReductionDispatcher.valid_metrics()) - excluded)
 
     @classmethod
     def compute(