-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
EFF Speed-up MiniBatchDictionaryLearning by avoiding multiple validation #25493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
67ea69d
850bce7
a7e05fe
9915fc1
fd88b7b
f070f21
8ce15ba
35d4f1f
3f6c2b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
from ..utils.validation import check_is_fitted | ||
from ..utils.parallel import delayed, Parallel | ||
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars | ||
from .._config import config_context | ||
|
||
|
||
def _check_positive_coding(method, positive): | ||
|
@@ -2381,9 +2382,10 @@ def fit(self, X, y=None): | |
for i, batch in zip(range(n_steps), batches): | ||
X_batch = X_train[batch] | ||
|
||
batch_cost = self._minibatch_step( | ||
X_batch, dictionary, self._random_state, i | ||
) | ||
with config_context(assume_finite=True, skip_parameter_validation=True): | ||
batch_cost = self._minibatch_step( | ||
X_batch, dictionary, self._random_state, i | ||
) | ||
Reviewer comment: #25490 only gets rid of 1 layer of validation (there are 4 validations in total), so even if it gets merged, we'll still need the context manager around the `_minibatch_step` call.
||
|
||
if self._check_convergence( | ||
X_batch, batch_cost, dictionary, old_dict, n_samples, i, n_steps | ||
|
@@ -2463,7 +2465,8 @@ def partial_fit(self, X, y=None): | |
else: | ||
dictionary = self.components_ | ||
|
||
self._minibatch_step(X, dictionary, self._random_state, self.n_steps_) | ||
with config_context(assume_finite=True, skip_parameter_validation=True): | ||
self._minibatch_step(X, dictionary, self._random_state, self.n_steps_) | ||
|
||
self.components_ = dictionary | ||
self.n_steps_ += 1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's important to state that `check_array` still runs: