8000 Fix n_components=None not being handled in MiniBatchDictionaryLearnin… · scikit-learn/scikit-learn@fa184fe · GitHub
[go: up one dir, main page]

Skip to content

Commit fa184fe

Browse files
committed
Fix n_components=None not being handled in MiniBatchDictionaryLearning.partial_fit
1 parent f91dee7 commit fa184fe

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

sklearn/decomposition/dict_learning.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,6 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
412412
SparsePCA
413413
MiniBatchSparsePCA
414414
"""
415-
416415
if method not in ('lars', 'cd'):
417416
raise ValueError('Coding method %r not supported as a fit algorithm.'
418417
% method)
@@ -604,6 +603,8 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
604603
MiniBatchSparsePCA
605604
606605
"""
606+
if n_components is None:
607+
n_components = X.shape[1]
607608

608609
if method not in ('lars', 'cd'):
609610
raise ValueError('Coding method not supported as a fit algorithm.')
@@ -750,7 +751,7 @@ def transform(self, X, y=None):
750751
Transformed data
751752
752753
"""
753-
check_is_fitted(self, 'components_')
754+
check_is_fitted(self, 'components_')
754755

755756
# XXX : kwargs is not documented
756757
X = check_array(X)
@@ -1159,13 +1160,9 @@ def fit(self, X, y=None):
11591160
"""
11601161
random_state = check_random_state(self.random_state)
11611162
X = check_array(X)
1162-
if self.n_components is None:
1163-
n_components = X.shape[1]
1164-
else:
1165-
n_components = self.n_components
11661163

11671164
U, (A, B), self.n_iter_ = dict_learning_online(
1168-
X, n_components, self.alpha,
1165+
X, self.n_components, self.alpha,
11691166
n_iter=self.n_iter, return_code=False,
11701167
method=self.fit_algorithm,
11711168
n_jobs=self.n_jobs, dict_init=self.dict_init,

0 commit comments

Comments
 (0)
0