MAINT: Return self for fit in Spectral Biclustering and CoClustering · scikit-learn/scikit-learn@4dd790a

Commit 4dd790a

MAINT: Return self for fit in Spectral Biclustering and CoClustering

Closes gh-6126

1 parent 45e3431   commit 4dd790a
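
With fit() returning the fitted estimator, the biclustering models can be constructed, fitted and inspected in a single expression, in line with the rest of the scikit-learn API. A minimal sketch of what the change enables (the make_biclusters data setup is only illustrative, not part of this commit):

    from sklearn.cluster import SpectralCoclustering
    from sklearn.datasets import make_biclusters

    # Small synthetic matrix containing three constant biclusters.
    X, rows, cols = make_biclusters(shape=(30, 20), n_clusters=3, random_state=0)

    # fit() now returns self, so fitting and attribute access can be chained.
    model = SpectralCoclustering(n_clusters=3, random_state=0).fit(X)
    print(model.row_labels_.shape)     # (30,)
    print(model.column_labels_.shape)  # (20,)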

File tree

3 files changed (+32, -9 lines):

    sklearn/cluster/bicluster.py
    sklearn/tests/test_common.py
    sklearn/utils/estimator_checks.py

sklearn/cluster/bicluster.py

Lines changed: 22 additions & 4 deletions

@@ -35,10 +35,19 @@ def _scale_normalize(X):

     """
     X = make_nonnegative(X)
-    row_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=1))).squeeze()
-    col_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=0))).squeeze()
+    row_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=1)))
+    if row_diag.shape[0] != 1:
+        row_diag = row_diag.squeeze()
+
+    col_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=0)))
+    if col_diag.ndim == 1 and col_diag.shape[0] != 1:
+        col_diag = col_diag.squeeze()
+    if col_diag.ndim == 2 and col_diag.shape[1] != 1:
+        col_diag = col_diag.squeeze()
+
     row_diag = np.where(np.isnan(row_diag), 0, row_diag)
     col_diag = np.where(np.isnan(col_diag), 0, col_diag)
+
     if issparse(X):
         n_rows, n_cols = X.shape
         r = dia_matrix((row_diag, [0]), shape=(n_rows, n_rows))

@@ -110,17 +119,21 @@ def _check_parameters(self):
                              " one of {1}.".format(self.svd_method,
                                                    legal_svd_methods))

-    def fit(self, X):
+    def fit(self, X, y=None):
         """Creates a biclustering for X.

         Parameters
         ----------
         X : array-like, shape (n_samples, n_features)

+        Returns
+        -------
+        self : object
+            Returns the instance itself.
         """
         X = check_array(X, accept_sparse='csr', dtype=np.float64)
         self._check_parameters()
-        self._fit(X)
+        return self._fit(X)

     def _svd(self, array, n_components, n_discard):
         """Returns first `n_components` left and right singular

@@ -156,6 +169,8 @@ def _svd(self, array, n_components, n_discard):

         assert_all_finite(u)
         assert_all_finite(vt)
+        if u.shape[1] == 1 and vt.shape[0] == 1:
+            n_discard = 0
         u = u[:, n_discard:]
         vt = vt[n_discard:]
         return u, vt.T

@@ -278,6 +293,7 @@ def _fit(self, X):
         normalized_data, row_diag, col_diag = _scale_normalize(X)
         n_sv = 1 + int(np.ceil(np.log2(self.n_clusters)))
         u, v = self._svd(normalized_data, n_sv, n_discard=1)
+
         z = np.vstack((row_diag[:, np.newaxis] * u,
                        col_diag[:, np.newaxis] * v))


@@ -291,6 +307,7 @@ def _fit(self, X):
                                for c in range(self.n_clusters))
         self.columns_ = np.vstack(self.column_labels_ == c
                                   for c in range(self.n_clusters))
+        return self


 class SpectralBiclustering(BaseSpectral):

@@ -475,6 +492,7 @@ def _fit(self, X):
         self.columns_ = np.vstack(self.column_labels_ == label
                                   for _ in range(n_row_clusters)
                                   for label in range(n_col_clusters))
+        return self

     def _fit_best_piecewise(self, vectors, n_best, n_clusters):
         """Find the ``n_best`` vectors that are best approximated by piecewise

sklearn/tests/test_common.py

Lines changed: 3 additions & 2 deletions

@@ -23,7 +23,9 @@
 from sklearn.utils.testing import _named_check

 import sklearn
+
 from sklearn.cluster.bicluster import BiclusterMixin
+from sklearn.decomposition import ProjectedGradientNMF

 from sklearn.linear_model.base import LinearClassifierMixin
 from sklearn.utils.estimator_checks import (

@@ -63,8 +65,6 @@ def test_non_meta_estimators():
     # input validation etc for non-meta estimators
     estimators = all_estimators()
     for name, Estimator in estimators:
-        if issubclass(Estimator, BiclusterMixin):
-            continue
         if name.startswith("_"):
             continue
         for check in _yield_all_checks(name, Estimator):

@@ -214,6 +214,7 @@ def test_transformer_n_iter():
             check_transformer_n_iter, name), name, estimator


+
 def test_get_params_invariance():
     # Test for estimators that support get_params, that
     # get_params(deep=False) is a subset of get_params(deep=True)
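
Dropping the blanket BiclusterMixin skip means SpectralBiclustering and SpectralCoclustering now run through the same common estimator checks as every other estimator. Roughly the same coverage can be reproduced by hand with check_estimator (a sketch; the exact set of checks depends on the scikit-learn version):

    from sklearn.cluster import SpectralBiclustering, SpectralCoclustering
    from sklearn.utils.estimator_checks import check_estimator

    # Runs the shared API checks (fit returns self, dtype handling,
    # parameter invariance, ...) against each biclustering estimator.
    for Estimator in (SpectralBiclustering, SpectralCoclustering):
        check_estimator(Estimator)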

sklearn/utils/estimator_checks.py

Lines changed: 7 additions & 3 deletions

@@ -288,7 +288,9 @@ def set_testing_parameters(estimator):
     if "decision_function_shape" in params:
         # SVC
         estimator.set_params(decision_function_shape='ovo')
-
+    if "n_best" in params:
+        # BiCluster
+        estimator.set_params(n_best=1)
     if estimator.__class__.__name__ == "SelectFdr":
         # be tolerant of noisy datasets (not actually speed)
         estimator.set_params(alpha=.5)

@@ -335,6 +337,8 @@ def check_estimator_sparse_data(name, Estimator):
     X_csr = sparse.csr_matrix(X)
     y = (4 * rng.rand(40)).astype(np.int)
     for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']:
+        if name == 'SpectralCoclustering':
+            continue
         X = X_csr.asformat(sparse_format)
         # catch deprecation warnings
         with ignore_warnings(category=DeprecationWarning):

@@ -684,7 +688,7 @@ def check_fit_score_takes_y(name, Estimator):
 @ignore_warnings
 def check_estimators_dtypes(name, Estimator):
     rnd = np.random.RandomState(0)
-    X_train_32 = 3 * rnd.uniform(size=(20, 5)).astype(np.float32)
+    X_train_32 = 3 * rnd.uniform(1.0, 2.0, size=(20, 5)).astype(np.float32)
     X_train_64 = X_train_32.astype(np.float64)
     X_train_int_64 = X_train_32.astype(np.int64)
     X_train_int_32 = X_train_32.astype(np.int32)

@@ -1309,7 +1313,7 @@ def check_class_weight_balanced_linear_classifier(name, Classifier):

 @ignore_warnings(category=DeprecationWarning)
 def check_estimators_overwrite_params(name, Estimator):
-    X, y = make_blobs(random_state=0, n_samples=9)
+    X, y = make_blobs(random_state=0, n_samples=9, n_features=3)
     y = multioutput_estimator_convert_y_2d(name, y)
     # some want non-negative input
     X -= X.min()
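
The change to check_estimators_dtypes reads as a consequence of the biclustering normalisation above, which divides by the square root of row and column sums: drawing the training data from [1, 2) instead of the default [0, 1) keeps even the integer-cast copies strictly positive, so no row or column sum can collapse to zero. A quick illustration of that reading (my interpretation, not stated in the commit):

    import numpy as np

    rnd = np.random.RandomState(0)
    old = 3 * rnd.uniform(size=(20, 5)).astype(np.float32)             # values in [0, 3)
    new = 3 * rnd.uniform(1.0, 2.0, size=(20, 5)).astype(np.float32)   # values in [3, 6)

    # After the int casts used by check_estimators_dtypes, the old draw
    # typically contains zeros (hence possible zero sums); the new one cannot.
    print((old.astype(np.int64) == 0).any())   # True (with this seed)
    print((new.astype(np.int64) == 0).any())   # False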

0 commit comments