File tree Expand file tree Collapse file tree 2 files changed +33
-2
lines changed
_pairwise_distances_reduction Expand file tree Collapse file tree 2 files changed +33
-2
lines changed Original file line number Diff line number Diff line change @@ -110,7 +110,7 @@ def is_valid_sparse_matrix(X):
110110 X .indices .dtype == X .indptr .dtype == np .int32
111111 )
112112
113- return (
113+ is_usable = (
114114 get_config ().get ("enable_cython_pairwise_dist" , True )
115115 and (is_numpy_c_ordered (X ) or is_valid_sparse_matrix (X ))
116116 and (is_numpy_c_ordered (Y ) or is_valid_sparse_matrix (Y ))
@@ -119,6 +119,24 @@ def is_valid_sparse_matrix(X):
119119 and metric in cls .valid_metrics ()
120120 )
121121
122+ # The other joblib-based back-end might be more efficient on fused sparse-dense
123+ # pairs of datasets with metric="(sq)euclidean" for some configurations because
124+ # it uses the Squared Euclidean matrix decomposition, i.e.:
125+ #
126+ # ||X_c_i - Y_c_j||² = ||X_c_i||² - 2 X_c_i.Y_c_j^T + ||Y_c_j||²
127+ #
127+ # calling efficient sparse-dense routines for matrix-vector multiplication
128+ # implemented in SciPy, which we do not use here yet.
130+ # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa
131+ # TODO: implement specialisation for (sq)euclidean on fused sparse-dense
132+ # using sparse-dense routines for matrix-vector multiplications.
133+ fused_sparse_dense_euclidean_case_guard = not (
134+ (is_valid_sparse_matrix (X ) or is_valid_sparse_matrix (Y ))
135+ and "euclidean" in metric
136+ )
137+
138+ return is_usable and fused_sparse_dense_euclidean_case_guard
139+
122140 @classmethod
123141 @abstractmethod
124142 def compute (
Original file line number Diff line number Diff line change @@ -518,7 +518,7 @@ def test_pairwise_distances_reduction_is_usable_for():
518518 Y = rng .rand (100 , 10 )
519519 X_csr = csr_matrix (X )
520520 Y_csr = csr_matrix (Y )
521- metric = "euclidean "
521+ metric = "manhattan "
522522
523523 # Must be usable for all possible pairs of {dense, sparse} datasets
524524 assert BaseDistanceReductionDispatcher .is_usable_for (X , Y , metric )
@@ -551,6 +551,19 @@ def test_pairwise_distances_reduction_is_usable_for():
551551 np .asfortranarray (X ), Y , metric
552552 )
553553
554+ # We prefer not to use those implementations for fused sparse-dense when
555+ # metric="(sq)euclidean" because they are not yet the most efficient ones on
556+ # all configurations of datasets.
557+ # See: https://github.com/scikit-learn/scikit-learn/pull/23585#issuecomment-1247996669 # noqa
558+ # TODO: implement specialisation for (sq)euclidean on fused sparse-dense
559+ # using sparse-dense routines for matrix-vector multiplications.
560+ assert not BaseDistanceReductionDispatcher .is_usable_for (
561+ X_csr , Y , metric = "euclidean"
562+ )
563+ assert not BaseDistanceReductionDispatcher .is_usable_for (
564+ X_csr , Y_csr , metric = "sqeuclidean"
565+ )
566+
554567 # CSR matrices without non-zero elements aren't currently supported
555568 # TODO: support CSR matrices without non-zero elements
556569 X_csr_0_nnz = csr_matrix (X * 0 )
You can’t perform that action at this time.
0 commit comments