@@ -35,10 +35,19 @@ def _scale_normalize(X):
3535
3636 """
3737 X = make_nonnegative (X )
38- row_diag = np .asarray (1.0 / np .sqrt (X .sum (axis = 1 ))).squeeze ()
39- col_diag = np .asarray (1.0 / np .sqrt (X .sum (axis = 0 ))).squeeze ()
38+ row_diag = np .asarray (1.0 / np .sqrt (X .sum (axis = 1 )))
39+ if row_diag .shape [0 ] != 1 :
40+ row_diag = row_diag .squeeze ()
41+
42+ col_diag = np .asarray (1.0 / np .sqrt (X .sum (axis = 0 )))
43+ if col_diag .ndim == 1 and col_diag .shape [0 ]!= 1 :
44+ col_diag = col_diag .squeeze ()
45+ if col_diag .ndim == 2 and col_diag .shape [0 ]== 1 and col_diag .shape [1 ]!= 1 :
46+ col_diag = col_diag .squeeze ()
47+
4048 row_diag = np .where (np .isnan (row_diag ), 0 , row_diag )
4149 col_diag = np .where (np .isnan (col_diag ), 0 , col_diag )
50+
4251 if issparse (X ):
4352 n_rows , n_cols = X .shape
4453 r = dia_matrix ((row_diag , [0 ]), shape = (n_rows , n_rows ))
@@ -160,6 +169,8 @@ def _svd(self, array, n_components, n_discard):
160169
161170 assert_all_finite (u )
162171 assert_all_finite (vt )
172+ if u .shape [1 ] == 1 and vt .shape [0 ] == 1 :
173+ n_discard = 0
163174 u = u [:, n_discard :]
164175 vt = vt [n_discard :]
165176 return u , vt .T
@@ -282,9 +293,10 @@ def _fit(self, X):
282293 normalized_data , row_diag , col_diag = _scale_normalize (X )
283294 n_sv = 1 + int (np .ceil (np .log2 (self .n_clusters )))
284295 u , v = self ._svd (normalized_data , n_sv , n_discard = 1 )
296+
285297 z = np .vstack ((row_diag [:, np .newaxis ] * u ,
286298 col_diag [:, np .newaxis ] * v ))
287-
299+
288300 _ , labels = self ._k_means (z , self .n_clusters )
289301
290302 n_rows = X .shape [0 ]
0 commit comments