10BC0 Attempt to fix tests · scikit-learn/scikit-learn@5085d2b · GitHub
[go: up one dir, main page]

Skip to content

Commit 5085d2b

Browse files
committed
Attempt to fix tests
1 parent f2988bd commit 5085d2b

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

sklearn/cluster/bicluster.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,19 @@ def _scale_normalize(X):
3535
3636
"""
3737
X = make_nonnegative(X)
38-
row_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=1))).squeeze()
39-
col_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=0))).squeeze()
38+
row_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=1)))
39+
if row_diag.shape[0] != 1:
40+
row_diag = row_diag.squeeze()
41+
42+
col_diag = np.asarray(1.0 / np.sqrt(X.sum(axis=0)))
43+
if col_diag.ndim == 1 and col_diag.shape[0]!=1 :
44+
col_diag = col_diag.squeeze()
45+
if col_diag.ndim == 2 and col_diag.shape[0]==1 and col_diag.shape[1]!=1 :
46+
col_diag = col_diag.squeeze()
47+
4048
row_diag = np.where(np.isnan(row_diag), 0, row_diag)
4149
col_diag = np.where(np.isnan(col_diag), 0, col_diag)
50+
4251
if issparse(X):
4352
n_rows, n_cols = X.shape
4453
r = dia_matrix((row_diag, [0]), shape=(n_rows, n_rows))
@@ -160,6 +169,8 @@ def _svd(self, array, n_components, n_discard):
160169

161170
assert_all_finite(u)
162171
assert_all_finite(vt)
172+
if u.shape[1] == 1 and vt.shape[0] == 1:
173+
n_discard = 0
163174
u = u[:, n_discard:]
164175
vt = vt[n_discard:]
165176
return u, vt.T
@@ -282,9 +293,10 @@ def _fit(self, X):
282293
normalized_data, row_diag, col_diag = _scale_normalize(X)
283294
n_sv = 1 + int(np.ceil(np.log2(self.n_clusters)))
284295
u, v = self._svd(normalized_data, n_sv, n_discard=1)
296+
285297
z = np.vstack((row_diag[:, np.newaxis] * u,
286298
col_diag[:, np.newaxis] * v))
287-
299+
288300
_, labels = self._k_means(z, self.n_clusters)
289301

290302
n_rows = X.shape[0]

sklearn/utils/estimator_checks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def check_estimator_sparse_data(name, Estimator):
336336
X_csr = sparse.csr_matrix(X)
337337
y = (4 * rng.rand(40)).astype(np.int)
338338
for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']:
339+
if name == 'SpectralCoclustering':
340+
continue
339341
X = X_csr.asformat(sparse_format)
340342
# catch deprecation warnings
341343
with warnings.catch_warnings():
@@ -683,7 +685,7 @@ def check_fit_score_takes_y(name, Estimator):
683685
@ignore_warnings
684686
def check_estimators_dtypes(name, Estimator):
685687
rnd = np.random.RandomState(0)
686-
X_train_32 = 3 * rnd.uniform(size< 51E2 span class="pl-c1">=(20, 5)).astype(np.float32)
688+
X_train_32 = 3 * rnd.uniform(1.0, 2.0, size=(20, 5)).astype(np.float32)
687689
X_train_64 = X_train_32.astype(np.float64)
688690
X_train_int_64 = X_train_32.astype(np.int64)
689691
X_train_int_32 = X_train_32.astype(np.int32)

0 commit comments

Comments
 (0)
0