@@ -412,6 +412,8 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
412
412
SparsePCA
413
413
MiniBatchSparsePCA
414
414
"""
415
+ if n_components is None :
416
+ n_components = X .shape [1 ]
415
417
416
418
if method not in ('lars' , 'cd' ):
417
419
raise ValueError ('Coding method %r not supported as a fit algorithm.'
@@ -750,7 +752,7 @@ def transform(self, X, y=None):
750
752
Transformed data
751
753
752
754
"""
753
- check_is_fitted (self , 'components_' )
755
+ check_is_fitted (self , 'components_' )
754
756
755
757
# XXX : kwargs is not documented
756
758
X = check_array (X )
@@ -1159,13 +1161,9 @@ def fit(self, X, y=None):
1159
1161
"""
1160
1162
random_state = check_random_state (self .random_state )
1161
1163
X = check_array (X )
1162
- if self .n_components is None :
1163
- n_components = X .shape [1 ]
1164
- else :
1165
- n_components = self .n_components
1166
1164
1167
1165
U , (A , B ), self .n_iter_ = dict_learning_online (
1168
- X , n_components , self .alpha ,
1166
+ X , self . n_components , self .alpha ,
1169
1167
n_iter = self .n_iter , return_code = False ,
1170
1168
method = self .fit_algorithm ,
1171
1169
n_jobs = self .n_jobs , dict_init = self .dict_init ,
0 commit comments