ENH Support for sparse matrices added to `sklearn.metrics.silhouette_samples` · Veghit/scikit-learn@3377338 · GitHub

Commit 3377338

awinml (Ashwin Mathur) authored and committed

ENH Support for sparse matrices added to sklearn.metrics.silhouette_samples (scikit-learn#24677)

Co-authored-by: Sahil Gupta <sahil@Sahils-MBP.lan>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent bebe3b4 · commit 3377338

File tree

3 files changed: +102 -27 lines changed

doc/whats_new/v1.3.rst
sklearn/metrics/cluster/_unsupervised.py
sklearn/metrics/cluster/tests/test_unsupervised.py

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
@@ -282,6 +282,11 @@ Changelog
 - |Fix| :func:`metric.manhattan_distances` now supports readonly sparse datasets.
   :pr:`25432` by :user:`Julien Jerphanion <jjerphan>`.

+- |Enhancement| :func:`metrics.silhouette_samples` now accepts a sparse
+  matrix of pairwise distances between samples, or a feature array.
+  :pr:`18723` by :user:`Sahil Gupta <sahilgupta2105>` and
+  :pr:`24677` by :user:`Ashwin Mathur <awinml>`.
+
 - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
   not normalized, instead of actually normalizing them in the metric. Starting from
   1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
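
As a quick illustration of the silhouette enhancement above, here is a minimal usage sketch (assuming scikit-learn 1.3 or later with this change applied; variable names are illustrative only):

    import numpy as np
    from scipy.sparse import csr_matrix
    from sklearn.datasets import load_iris
    from sklearn.metrics import pairwise_distances, silhouette_samples

    X, y = load_iris(return_X_y=True)

    # Sparse precomputed distances: the zero diagonal stays implicit in CSR,
    # so the non-zero-diagonal check still passes.
    D_sparse = csr_matrix(pairwise_distances(X))
    scores_precomputed = silhouette_samples(D_sparse, y, metric="precomputed")

    # A sparse feature array also works; distances are computed internally.
    scores_features = silhouette_samples(csr_matrix(X), y, metric="euclidean")

    print(scores_precomputed[:5], scores_features[:5])
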

sklearn/metrics/cluster/_unsupervised.py

Lines changed: 47 additions & 22 deletions
@@ -9,6 +9,7 @@
 import functools

 import numpy as np
+from scipy.sparse import issparse

 from ...utils import check_random_state
 from ...utils import check_X_y
@@ -122,31 +123,53 @@ def _silhouette_reduce(D_chunk, start, labels, label_freqs):

     Parameters
     ----------
-    D_chunk : array-like of shape (n_chunk_samples, n_samples)
-        Precomputed distances for a chunk.
+    D_chunk : {array-like, sparse matrix} of shape (n_chunk_samples, n_samples)
+        Precomputed distances for a chunk. If a sparse matrix is provided,
+        only CSR format is accepted.
     start : int
         First index in the chunk.
     labels : array-like of shape (n_samples,)
         Corresponding cluster labels, encoded as {0, ..., n_clusters-1}.
     label_freqs : array-like
         Distribution of cluster labels in ``labels``.
     """
+    n_chunk_samples = D_chunk.shape[0]
     # accumulate distances from each sample to each cluster
-    clust_dists = np.zeros((len(D_chunk), len(label_freqs)), dtype=D_chunk.dtype)
-    for i in range(len(D_chunk)):
-        clust_dists[i] += np.bincount(
-            labels, weights=D_chunk[i], minlength=len(label_freqs)
-        )
+    cluster_distances = np.zeros(
+        (n_chunk_samples, len(label_freqs)), dtype=D_chunk.dtype
+    )

-    # intra_index selects intra-cluster distances within clust_dists
-    intra_index = (np.arange(len(D_chunk)), labels[start : start + len(D_chunk)])
-    # intra_clust_dists are averaged over cluster size outside this function
-    intra_clust_dists = clust_dists[intra_index]
+    if issparse(D_chunk):
+        if D_chunk.format != "csr":
+            raise TypeError(
+                "Expected CSR matrix. Please pass sparse matrix in CSR format."
+            )
+        for i in range(n_chunk_samples):
+            indptr = D_chunk.indptr
+            indices = D_chunk.indices[indptr[i] : indptr[i + 1]]
+            sample_weights = D_chunk.data[indptr[i] : indptr[i + 1]]
+            sample_labels = np.take(labels, indices)
+            cluster_distances[i] += np.bincount(
+                sample_labels, weights=sample_weights, minlength=len(label_freqs)
+            )
+    else:
+        for i in range(n_chunk_samples):
+            sample_weights = D_chunk[i]
+            sample_labels = labels
+            cluster_distances[i] += np.bincount(
+                sample_labels, weights=sample_weights, minlength=len(label_freqs)
+            )
+
+    # intra_index selects intra-cluster distances within cluster_distances
+    end = start + n_chunk_samples
+    intra_index = (np.arange(n_chunk_samples), labels[start:end])
+    # intra_cluster_distances are averaged over cluster size outside this function
+    intra_cluster_distances = cluster_distances[intra_index]
     # of the remaining distances we normalise and extract the minimum
-    clust_dists[intra_index] = np.inf
-    clust_dists /= label_freqs
-    inter_clust_dists = clust_dists.min(axis=1)
-    return intra_clust_dists, inter_clust_dists
+    cluster_distances[intra_index] = np.inf
+    cluster_distances /= label_freqs
+    inter_cluster_distances = cluster_distances.min(axis=1)
+    return intra_cluster_distances, inter_cluster_distances


 def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
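
For readers unfamiliar with the CSR layout used in the new branch of `_silhouette_reduce`: row i of a CSR matrix is described by the `indptr[i]:indptr[i + 1]` slices of `indices` (column ids) and `data` (stored values), and `np.bincount` folds those values into per-cluster sums. A standalone sketch of that accumulation pattern on toy data (not the scikit-learn implementation itself; variable names are hypothetical):

    import numpy as np
    from scipy.sparse import csr_matrix

    # Toy chunk of precomputed distances (3 samples x 4 samples) and cluster labels.
    D_chunk = csr_matrix(np.array([
        [0.0, 1.0, 4.0, 5.0],
        [1.0, 0.0, 3.0, 6.0],
        [4.0, 3.0, 0.0, 2.0],
    ]))
    labels = np.array([0, 0, 1, 1])  # cluster id of each column/sample
    n_clusters = 2

    cluster_distances = np.zeros((D_chunk.shape[0], n_clusters))
    indptr = D_chunk.indptr
    for i in range(D_chunk.shape[0]):
        # Columns and values stored for row i only: zero distances stay implicit.
        cols = D_chunk.indices[indptr[i]:indptr[i + 1]]
        vals = D_chunk.data[indptr[i]:indptr[i + 1]]
        # Sum this row's distances per cluster of the corresponding column.
        cluster_distances[i] += np.bincount(labels[cols], weights=vals,
                                            minlength=n_clusters)

    print(cluster_distances)
    # row 0: 1.0 to cluster 0 (its own cluster), 9.0 to cluster 1, and so on.
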
@@ -174,9 +197,11 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):

     Parameters
     ----------
-    X : array-like of shape (n_samples_a, n_samples_a) if metric == \
+    X : {array-like, sparse matrix} of shape (n_samples_a, n_samples_a) if metric == \
             "precomputed" or (n_samples_a, n_features) otherwise
-        An array of pairwise distances between samples, or a feature array.
+        An array of pairwise distances between samples, or a feature array. If
+        a sparse matrix is provided, CSR format should be favoured, as it
+        avoids an additional copy.

     labels : array-like of shape (n_samples,)
         Label values for each sample.
@@ -209,7 +234,7 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
     .. [2] `Wikipedia entry on the Silhouette Coefficient
            <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
     """
-    X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"])
+    X, labels = check_X_y(X, labels, accept_sparse=["csr"])

     # Check for non-zero diagonal entries in precomputed distance matrix
     if metric == "precomputed":
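
Narrowing `accept_sparse` to `["csr"]` means `check_X_y` converts any other sparse format to CSR up front, at the cost of an extra copy; that is why the docstring now recommends passing CSR directly. A small check of that behaviour (illustrative sketch only):

    import numpy as np
    from scipy.sparse import csc_matrix
    from sklearn.utils import check_X_y

    X = csc_matrix(np.eye(4))
    y = np.array([0, 0, 1, 1])

    # A non-CSR sparse matrix is converted to the first accepted format.
    X_checked, y_checked = check_X_y(X, y, accept_sparse=["csr"])
    print(X_checked.format)  # "csr"
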
@@ -219,10 +244,10 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
         )
         if X.dtype.kind == "f":
             atol = np.finfo(X.dtype).eps * 100
-            if np.any(np.abs(np.diagonal(X)) > atol):
-                raise ValueError(error_msg)
-        elif np.any(np.diagonal(X) != 0):  # integral dtype
-            raise ValueError(error_msg)
+            if np.any(np.abs(X.diagonal()) > atol):
+                raise ValueError(error_msg)
+        elif np.any(X.diagonal() != 0):  # integral dtype
+            raise ValueError(error_msg)

     le = LabelEncoder()
     labels = le.fit_transform(labels)
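
The switch from `np.diagonal(X)` to `X.diagonal()` is what lets the zero-diagonal check run on both dense arrays and scipy sparse matrices: sparse matrices expose a `.diagonal()` method but are not plain ndarrays. A toy illustration (hypothetical data):

    import numpy as np
    from scipy.sparse import csr_matrix

    D = csr_matrix(np.array([[0.0, 1.0], [1.0, 0.0]]))

    print(D.diagonal())  # [0. 0.] -- works for sparse and dense alike
    # np.diagonal(D) would raise an error here, since D is not a plain ndarray.
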

sklearn/metrics/cluster/tests/test_unsupervised.py

Lines changed: 50 additions & 5 deletions
@@ -1,14 +1,17 @@
 import warnings

 import numpy as np
-import scipy.sparse as sp
 import pytest
-from scipy.sparse import csr_matrix
+
+from numpy.testing import assert_allclose
+from scipy.sparse import csr_matrix, csc_matrix, dok_matrix, lil_matrix
+from scipy.sparse import issparse

 from sklearn import datasets
 from sklearn.utils._testing import assert_array_equal
 from sklearn.metrics.cluster import silhouette_score
 from sklearn.metrics.cluster import silhouette_samples
+from sklearn.metrics.cluster._unsupervised import _silhouette_reduce
 from sklearn.metrics import pairwise_distances
 from sklearn.metrics.cluster import calinski_harabasz_score
 from sklearn.metrics.cluster import davies_bouldin_score
@@ -19,11 +22,12 @@ def test_silhouette():
     dataset = datasets.load_iris()
     X_dense = dataset.data
     X_csr = csr_matrix(X_dense)
-    X_dok = sp.dok_matrix(X_dense)
-    X_lil = sp.lil_matrix(X_dense)
+    X_csc = csc_matrix(X_dense)
+    X_dok = dok_matrix(X_dense)
+    X_lil = lil_matrix(X_dense)
     y = dataset.target

-    for X in [X_dense, X_csr, X_dok, X_lil]:
+    for X in [X_dense, X_csr, X_csc, X_dok, X_lil]:
         D = pairwise_distances(X, metric="euclidean")
         # Given that the actual labels are used, we can assume that S would be
         # positive.
@@ -282,6 +286,47 @@ def test_silhouette_nonzero_diag(dtype):
         silhouette_samples(dists, labels, metric="precomputed")


+@pytest.mark.parametrize("to_sparse", (csr_matrix, csc_matrix, dok_matrix, lil_matrix))
+def test_silhouette_samples_precomputed_sparse(to_sparse):
+    """Check that silhouette_samples works for sparse matrices correctly."""
+    X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
+    y = [0, 0, 0, 0, 1, 1, 1, 1]
+    pdist_dense = pairwise_distances(X)
+    pdist_sparse = to_sparse(pdist_dense)
+    assert issparse(pdist_sparse)
+    output_with_sparse_input = silhouette_samples(pdist_sparse, y, metric="precomputed")
+    output_with_dense_input = silhouette_samples(pdist_dense, y, metric="precomputed")
+    assert_allclose(output_with_sparse_input, output_with_dense_input)
+
+
+@pytest.mark.parametrize("to_sparse", (csr_matrix, csc_matrix, dok_matrix, lil_matrix))
+def test_silhouette_samples_euclidean_sparse(to_sparse):
+    """Check that silhouette_samples works for sparse matrices correctly."""
+    X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
+    y = [0, 0, 0, 0, 1, 1, 1, 1]
+    pdist_dense = pairwise_distances(X)
+    pdist_sparse = to_sparse(pdist_dense)
+    assert issparse(pdist_sparse)
+    output_with_sparse_input = silhouette_samples(pdist_sparse, y)
+    output_with_dense_input = silhouette_samples(pdist_dense, y)
+    assert_allclose(output_with_sparse_input, output_with_dense_input)
+
+
+@pytest.mark.parametrize("to_non_csr_sparse", (csc_matrix, dok_matrix, lil_matrix))
+def test_silhouette_reduce(to_non_csr_sparse):
+    """Check for non-CSR input to private method `_silhouette_reduce`."""
+    X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
+    pdist_dense = pairwise_distances(X)
+    pdist_sparse = to_non_csr_sparse(pdist_dense)
+    y = [0, 0, 0, 0, 1, 1, 1, 1]
+    label_freqs = np.bincount(y)
+    with pytest.raises(
+        TypeError,
+        match="Expected CSR matrix. Please pass sparse matrix in CSR format.",
+    ):
+        _silhouette_reduce(pdist_sparse, start=0, labels=y, label_freqs=label_freqs)
+
+
 def assert_raises_on_only_one_label(func):
     """Assert message when there is only one label"""
     rng = np.random.RandomState(seed=0)

0 commit comments
