TST add test for sparse matrix handling in clustering · scikit-learn/scikit-learn@494b8e5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 494b8e5

Browse files
committed
TST add test for sparse matrix handling in clustering
Make the sparsity check check everything. Don't test everything; that would be nice but is out of scope :-/ Catch the special case of no core samples in DBSCAN. Add a non-regression test for sparse DBSCAN with no core samples.
1 parent a413f87 commit 494b8e5

File tree

5 files changed

+39
-54
lines changed

5 files changed

+39
-54
lines changed

sklearn/cluster/dbscan_.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def fit(self, X, sample_weight=None):
240240
X = check_array(X, accept_sparse='csr')
241241
clust = dbscan(X, sample_weight=sample_weight, **self.get_params())
242242
self.core_sample_indices_, self.labels_ = clust
243-
self.components_ = X[self.core_sample_indices_].copy()
243+
if len(self.core_sample_indices_):
244+
# fix for scipy sparse indexing issue
245+
self.components_ = X[self.core_sample_indices_].copy()
246+
else:
247+
# no core samples
248+
self.components_ = np.empty((0, X.shape[1]))
244249
return self
245250

246251
def fit_predict(self, X, y=None, sample_weight=None):

sklearn/cluster/mean_shift_.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
from collections import defaultdict
2020
from ..externals import six
21-
from ..utils import extmath, check_random_state, gen_batches
2221
from ..utils.validation import check_is_fitted
22+
from ..utils import extmath, check_random_state, gen_batches, check_array
2323
from ..base import BaseEstimator, ClusterMixin
2424
from ..neighbors import NearestNeighbors
2525
from ..metrics.pairwise import pairwise_distances_argmin
@@ -328,7 +328,7 @@ def fit(self, X):
328328
X : array-like, shape=[n_samples, n_features]
329329
Samples to cluster.
330330
"""
331-
X = np.asarray(X)
331+
X = check_array(X)
332332
self.cluster_centers_, self.labels_ = \
333333
mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds,
334334
min_bin_freq=self.min_bin_freq,

sklearn/cluster/tests/test_dbscan.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pickle
66

77
import numpy as np
8-
from numpy.testing import assert_raises
98

109
from scipy.spatial import distance
1110
from scipy import sparse
@@ -78,6 +77,18 @@ def test_dbscan_sparse():
7877
assert_array_equal(labels_dense, labels_sparse)
7978

8079

80+
def test_dbscan_no_core_samples():
81+
rng = np.random.RandomState(0)
82+
X = rng.rand(40, 10)
83+
X[X < .8] = 0
84+
85+
for X_ in [X, sparse.csr_matrix(X)]:
86+
db = DBSCAN().fit(X_)
87+
assert_array_equal(db.components_, np.empty((0, X_.shape[1])))
88+
assert_array_equal(db.labels_, -1)
89+
assert_equal(db.core_sample_indices_.shape, (0,))
90+
91+
8192
def test_dbscan_callable():
8293
"""Tests the DBSCAN algorithm with a callable metric."""
8394
# Parameters chosen specifically for this task.

sklearn/tests/test_common.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.externals.six.moves import zip
1717
from sklearn.utils.testing import assert_false, clean_warning_registry
1818
from sklearn.utils.testing import all_estimators
19-
from sklearn.utils.testing import assert_greater
19+
from sklearn.utils.testing import assert_greater
2020
from sklearn.utils.testing import assert_in
2121
from sklearn.utils.testing import SkipTest
2222
from sklearn.utils.testing import ignore_warnings
@@ -29,14 +29,13 @@
2929
from sklearn.linear_model.base import LinearClassifierMixin
3030
from sklearn.utils.estimator_checks import (
3131
check_parameters_default_constructible,
32-
check_regressors_classifiers_sparse_data,
32+
check_estimator_sparse_data,
3333
check_transformer,
3434
check_clustering,
3535
check_clusterer_compute_labels_predict,
3636
check_regressors_int,
3737
check_regressors_train,
3838
check_regressors_pickle,
39-
check_transformer_sparse_data,
4039
check_transformer_pickle,
4140
check_transformers_unfitted,
4241
check_estimators_nan_inf,
@@ -100,13 +99,9 @@ def test_non_meta_estimators():
10099
if hasattr(Estimator, 'sparsify'):
101100
yield check_sparsify_coefficients, name, Estimator
102101

103-
104-
def test_estimators_sparse_data():
105-
# All estimators should either deal with sparse data or raise an
106-
# exception with type TypeError and an intelligible error message
107-
estimators = all_estimators(type_filter=['classifier', 'regressor'])
108-
for name, Estimator in estimators:
109-
yield check_regressors_classifiers_sparse_data, name, Estimator
102+
yield check_estimator_sparse_data, name, Estimator
103+
if name not in CROSS_DECOMPOSITION + ['Imputer']:
104+
yield check_estimators_nan_inf, name, Estimator
110105

111106

112107
def test_transformers():
@@ -116,7 +111,6 @@ def test_transformers():
116111
for name, Transformer in transformers:
117112
# All transformers should either deal with sparse data or raise an
118113
# exception with type TypeError and an intelligible error message
119-
yield check_transformer_sparse_data, name, Transformer
120114
yield check_transformer_pickle, name, Transformer
121115
if name not in ['AdditiveChi2Sampler', 'Binarizer', 'Normalizer',
122116
'PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']:

sklearn/utils/estimator_checks.py

Lines changed: 14 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -116,20 +116,27 @@ def _is_32bit():
116116
return struct.calcsize('P') * 8 == 32
117117

118118

119-
def check_regressors_classifiers_sparse_data(name, Estimator):
119+
def check_estimator_sparse_data(name, Estimator):
120120
rng = np.random.RandomState(0)
121121
X = rng.rand(40, 10)
122122
X[X < .8] = 0
123123
X = sparse.csr_matrix(X)
124124
y = (4 * rng.rand(40)).astype(np.int)
125125
# catch deprecation warnings
126126
with warnings.catch_warnings():
127-
estimator = Estimator()
127+
if name in ['Scaler', 'StandardScaler']:
128+
estimator = Estimator(with_mean=False)
129+
else:
130+
estimator = Estimator()
128131
set_fast_parameters(estimator)
129132
# fit and predict
130133
try:
131-
estimator.fit(X, y)
132-
estimator.predict(X)
134+
if is_supervised(estimator):
135+
estimator.fit(X, y)
136+
else:
137+
estimator.fit(X)
138+
if hasattr(estimator, "predict"):
139+
estimator.predict(X)
133140
if hasattr(estimator, 'predict_proba'):
134141
estimator.predict_proba(X)
135142
except TypeError as e:
@@ -245,38 +252,6 @@ def _check_transformer(name, Transformer, X, y):
245252
assert_raises(ValueError, transformer.transform, X.T)
246253

247254

248-
def check_transformer_sparse_data(name, Transformer):
249-
rng = np.random.RandomState(0)
250-
X = rng.rand(40, 10)
251-
X[X < .8] = 0
252-
X = sparse.csr_matrix(X)
253-
y = (4 * rng.rand(40)).astype(np.int)
254-
# catch deprecation warnings
255-
with warnings.catch_warnings(record=True):
256-
if name in ['Scaler', 'StandardScaler']:
257-
transformer = Transformer(with_mean=False)
258-
else:
259-
transformer = Transformer()
260-
261-
set_fast_parameters(transformer)
262-
263-
# fit
264-
try:
265-
transformer.fit(X, y)
266-
except TypeError as e:
267-
if not 'sparse' in repr(e):
268-
print("Estimator %s doesn't seem to fail gracefully on "
269-
"sparse data: error message state explicitly that "
270-
"sparse input is not supported if this is not the case."
271-
% name)
272-
raise
273-
except Exception:
274-
print("Estimator %s doesn't seem to fail gracefully on "
275-
"sparse data: it should raise a TypeError if sparse input "
276-
"is explicitly not supported." % name)
277-
raise
278-
279-
280255
def check_estimators_nan_inf(name, Estimator):
281256
rnd = np.random.RandomState(0)
282257
X_train_finite = rnd.uniform(size=(10, 3))
@@ -567,7 +542,7 @@ def check_estimators_unfitted(name, Estimator):
567542
est = Estimator()
568543

569544
assert_raises(NotFittedError, est.predict, X)
570-
545+
571546
if hasattr(est, 'predict'):
572547
assert_raises(NotFittedError, est.predict, X)
573548

@@ -576,7 +551,7 @@ def check_estimators_unfitted(name, Estimator):
576551

577552
if hasattr(est, 'predict_proba'):
578553
assert_raises(NotFittedError, est.predict_proba, X)
579-
554+
580555
if hasattr(est, 'predict_log_proba'):
581556
assert_raises(NotFittedError, est.predict_log_proba, X)
582557

@@ -991,7 +966,7 @@ def multioutput_estimator_convert_y_2d(name, y):
991966
return y
992967

993968

994-
def check_non_transformer_estimators_n_iter(name, estimator,
969+
def check_non_transformer_estimators_n_iter(name, estimator,
995970
multi_output=False):
996971
# Check that all iterative solvers run for more than one iteration
997972

0 commit comments

Comments
 (0)
0