[MRG+1] Removing repeated input checking in Lasso and DictLearning #5133
Changes from all commits
```diff
@@ -107,10 +107,10 @@ def _sparse_encode(X, dictionary, gram, cov=None, algorithm='lasso_lars',
 
     elif algorithm == 'lasso_cd':
         alpha = float(regularization) / n_features  # account for scaling
-        clf = Lasso(alpha=alpha, fit_intercept=False, precompute=gram,
-                    max_iter=max_iter, warm_start=True)
+        clf = Lasso(alpha=alpha, fit_intercept=False, normalize=False,
+                    precompute=gram, max_iter=max_iter, warm_start=True)
```
**Reviewer:** This makes me realise that we always use `Lasso` with a precomputed Gram matrix in the context of online dictionary learning. Have you checked whether that is actually always a good idea from a performance point of view? If instead we want to allow the use of coordinate descent on row slices (sample-wise minibatches) of a Fortran-aligned feature array, we would need to change the Cython code of coordinate descent to deal properly with strided arrays (views). That does not seem hard, but maybe it should be investigated in a separate PR.
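One rough way to probe that question empirically (a hypothetical micro-benchmark, not part of this PR; the shapes, `alpha`, and variable names are made up) is to time `Lasso` with and without a precomputed Gram matrix on dictionary-learning-shaped inputs:

```python
import time

import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
D_T = np.asfortranarray(rng.randn(1000, 100))  # dictionary.T: (n_features, n_components)
X_T = np.asfortranarray(rng.randn(1000, 50))   # X.T: (n_features, n_samples)
gram = np.dot(D_T.T, D_T)                      # (n_components, n_components)

for precompute, label in ((gram, 'precomputed Gram'), (False, 'plain CD')):
    clf = Lasso(alpha=0.1, fit_intercept=False, precompute=precompute)
    t0 = time.time()
    clf.fit(D_T, X_T)  # one coordinate-descent problem per column of X_T
    print(label, time.time() - t0)
```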
```diff
         clf.coef_ = init
-        clf.fit(dictionary.T, X.T)
+        clf.fit(dictionary.T, X.T, check_input=False)
         new_code = clf.coef_
 
     elif algorithm == 'lars':
```
```diff
@@ -224,8 +224,10 @@ def sparse_encode(X, dictionary, gram=None, cov=None, algorithm='lasso_lars',
     n_components = dictionary.shape[0]
 
     if gram is None and algorithm != 'threshold':
-        gram = np.dot(dictionary, dictionary.T)
-
-    if cov is None:
+        # Transposing product to ensure Fortran ordering
+        gram = np.dot(dictionary, dictionary.T).T
```
**Reviewer:** Why this?

**Author:** This ensures that the Gram matrix is in Fortran order.

**Reviewer:** Please add an inline comment to make that optimization explicit.
```diff
+
+    if cov is None and algorithm != 'lasso_cd':
         copy_cov = False
         cov = np.dot(dictionary, X.T)
```
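To see why the transpose trick is free (a quick check, not from the PR): `np.dot` returns a C-contiguous array, and since the Gram matrix is symmetric, `.T` yields a Fortran-contiguous view of the same values without any copy:

```python
import numpy as np

dictionary = np.random.randn(8, 30)
gram = np.dot(dictionary, dictionary.T)  # C-contiguous result from BLAS
print(gram.flags['C_CONTIGUOUS'])        # True
print(gram.T.flags['F_CONTIGUOUS'])      # True: a zero-copy view
print(np.allclose(gram, gram.T))         # True: Gram matrices are symmetric
```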
```diff
@@ -239,18 +241,27 @@ def sparse_encode(X, dictionary, gram=None, cov=None, algorithm='lasso_lars',
             regularization = 1.
 
     if n_jobs == 1 or algorithm == 'threshold':
-        return _sparse_encode(X, dictionary, gram, cov=cov,
+        code = _sparse_encode(X,
+                              dictionary, gram, cov=cov,
                               algorithm=algorithm,
                               regularization=regularization, copy_cov=copy_cov,
-                              init=init, max_iter=max_iter)
+                              init=init,
+                              max_iter=max_iter)
+        # This ensures that dimensionality of code is always 2,
+        # consistent with the case n_jobs > 1
+        if code.ndim == 1:
+            code = code[np.newaxis, :]
+        return code
 
     # Enter parallel code block
     code = np.empty((n_samples, n_components))
     slices = list(gen_even_slices(n_samples, _get_n_jobs(n_jobs)))
 
     code_views = Parallel(n_jobs=n_jobs)(
         delayed(_sparse_encode)(
-            X[this_slice], dictionary, gram, cov[:, this_slice], algorithm,
+            X[this_slice], dictionary, gram,
+            cov[:, this_slice] if cov is not None else None,
+            algorithm,
             regularization=regularization, copy_cov=copy_cov,
             init=init[this_slice] if init is not None else None,
             max_iter=max_iter)
```
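Background on the `np.newaxis` promotion above (a toy illustration): when a single sample is encoded, the underlying estimator's `coef_` is 1-D, while the parallel branch always stacks 2-D views, so the serial branch normalizes the shape:

```python
import numpy as np

code = np.zeros(5)              # e.g. coef_ from a single-sample encode
if code.ndim == 1:
    code = code[np.newaxis, :]  # promote to (1, n_components)
print(code.shape)               # (1, 5)
```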
```diff
@@ -639,7 +650,6 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     else:
         dictionary = np.r_[dictionary,
                            np.zeros((n_components - r, dictionary.shape[1]))]
-    dictionary = np.ascontiguousarray(dictionary.T)
 
     if verbose == 1:
         print('[dict_learning]', end=' ')
```
```diff
@@ -650,6 +660,10 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     else:
         X_train = X
 
+    dictionary = check_array(dictionary.T, order='F', dtype=np.float64,
+                             copy=False)
+    X_train = check_array(X_train, order='C', dtype=np.float64, copy=False)
+
     batches = gen_batches(n_samples, batch_size)
     batches = itertools.cycle(batches)
```
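A small check of the `copy=False` behavior these two calls rely on (a sketch, not part of the diff): `check_array` should pass an array through untouched when its dtype and memory layout already match, so the validation is essentially free once the data is in the right format:

```python
import numpy as np
from sklearn.utils import check_array

X = np.asfortranarray(np.random.randn(10, 4))
X_checked = check_array(X, order='F', dtype=np.float64, copy=False)
print(X_checked is X)  # True: no copy when dtype and layout already match
```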
```diff
@@ -17,6 +17,7 @@
 from sklearn.utils.testing import assert_greater
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
+from sklearn.utils.testing import assert_warns_message
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import TempMemmap
```

```diff
@@ -25,6 +26,7 @@
     LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet, \
     MultiTaskElasticNetCV, MultiTaskLassoCV, lasso_path, enet_path
 from sklearn.linear_model import LassoLarsCV, lars_path
+from sklearn.utils import check_array
 
 
 def check_warnings():
```
```diff
@@ -628,3 +630,39 @@ def test_sparse_dense_descent_paths():
         _, coefs, _ = path(X, y, fit_intercept=False)
         _, sparse_coefs, _ = path(csr, y, fit_intercept=False)
     assert_array_almost_equal(coefs, sparse_coefs)
+
+
+def test_check_input_false():
+    X, y, _, _ = build_dataset(n_samples=20, n_features=10)
+    X = check_array(X, order='F', dtype='float64')
+    y = check_array(X, order='F', dtype='float64')
+    clf = ElasticNet(selection='cyclic', tol=1e-8)
+    # Check that no error is raised if data is provided in the right format
+    clf.fit(X, y, check_input=False)
+    X = check_array(X, order='F', dtype='float32')
```
**Reviewer:** I guess that if you try with the correct dtype but the wrong memory layout, e.g. `X = check_array(X, order='C', dtype='float64')`, you should not get an error but wrong results. This is expected, but maybe it should be asserted in the tests as well.
```diff
+    clf.fit(X, y, check_input=True)
```
**Reviewer:** Sorry, I did not read this line and the following line correctly...
```diff
+    # Check that an error is raised if data is provided in the wrong format,
+    # because of check bypassing
+    assert_raises(ValueError, clf.fit, X, y, check_input=False)
```
```diff
+
+    # With no input checking, providing X in C order should result in false
+    # computation
+    X = check_array(X, order='C', dtype='float64')
```
**Reviewer:** Maybe insert an inline comment before this line to explain the following assertions.
```diff
+    clf.fit(X, y, check_input=False)
+    coef_false = clf.coef_
+    clf.fit(X, y, check_input=True)
+    coef_true = clf.coef_
+    assert_raises(AssertionError, assert_array_almost_equal,
+                  coef_true, coef_false)
+
+
+
```
**Reviewer:** Extra line.
```diff
+def test_overrided_gram_matrix():
+    X, y, _, _ = build_dataset(n_samples=20, n_features=10)
+    Gram = X.T.dot(X)
+    clf = ElasticNet(selection='cyclic', tol=1e-8, precompute=Gram,
+                     fit_intercept=True)
+    assert_warns_message(UserWarning,
+                         "Gram matrix was provided but X was centered"
+                         " to fit intercept, "
+                         "or X was normalized : recomputing Gram matrix.",
+                         clf.fit, X, y)
```
**Reviewer:** This should be three lines only.
**Author:** Addressed.
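For reference, a sketch of how the warning exercised by `test_overrided_gram_matrix` would surface in user code (assuming the recompute-and-warn behavior this PR adds; the data and shapes are arbitrary):

```python
import warnings

import numpy as np
from sklearn.linear_model import ElasticNet

rng = np.random.RandomState(0)
X = rng.randn(20, 10)
y = rng.randn(20)

# Centering X to fit an intercept invalidates the user-supplied Gram matrix,
# so the estimator recomputes it and warns.
clf = ElasticNet(precompute=X.T.dot(X), fit_intercept=True)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clf.fit(X, y)
print(caught[0].message)
```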