[MRG+1] Removing repeated input checking in Lasso and DictLearning by arthurmensch · Pull Request #5133 · scikit-learn/scikit-learn · GitHub

[MRG+1] Removing repeated input checking in Lasso and DictLearning #5133

Merged
merged 1 commit into scikit-learn:master on Aug 20, 2015

Conversation

arthurmensch
Contributor

I use:

  • a check_input flag in coordinate_descent.enet_path, which is set to False when called from ElasticNet.fit
  • a check_input flag in coordinate_descent.ElasticNet.fit, which is set to False when called from sparse_encode

I changed the condition for overriding the provided Gram matrix in linear_model.base._pre_fit, so that we do not do a pass over the data when fit_intercept and normalize are set to False (in master, we override the Gram matrix if we find that X was changed by centering and rescaling, even when these are disabled, which wastes computation).

I also avoid computing cov in sparse_encode when using coordinate_descent, since the covariance computation is done within the Lasso class.

On the provided plot_online_sparse_pca example, we gain a factor of 2 in performance.
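
For reference, a minimal sketch of the resulting call pattern (the check_input keyword on fit is the one added in this PR; the shapes and hyperparameters below are illustrative only): validate the arrays once up front, then skip the per-call checks inside a hot loop such as the minibatch loop of online dictionary learning.

import numpy as np
from sklearn.linear_model import ElasticNet
from sklearn.utils import check_array

rng = np.random.RandomState(0)
X = rng.randn(200, 50)
y = rng.randn(200)

# One-time validation: Fortran order and float64, as coordinate descent expects.
X = check_array(X, order='F', dtype=np.float64)
y = check_array(y, order='F', dtype=np.float64, ensure_2d=False)

clf = ElasticNet(alpha=0.1, fit_intercept=False, warm_start=True)
for _ in range(10):
    # The data already has the expected dtype and layout, so the repeated
    # validation inside fit would only add overhead; skip it.
    clf.fit(X, y, check_input=False)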

@@ -421,12 +426,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
precompute = (n_samples > n_features)

if precompute is True:
precompute = np.dot(X.T, X)
precompute = np.dot(X.T, X).T
Member

same remark here

@agramfort
Member

Do we really want to add this example?

Travis is not happy.

@@ -0,0 +1,104 @@
"""

@arthurmensch
Contributor Author

This example has no point, I only use it for benchmarking. I am putting it into a separate gist.

@arthurmensch arthurmensch changed the title Removing repeated input checking in Lasso and DictLearning [WIP] Removing repeated input checking in Lasso and DictLearning Aug 18, 2015
@arthurmensch
Contributor Author

I added a pool context manager in dict_learning.py (this should probably go into another PR, although this one actually aims at allowing multiprocessing). Here is what I find:

  • In all cases, it is faster to use the context manager (that was expected).
  • Working on small X / small y (i.e. small batches), it is faster to use the multiprocessing backend, because Lasso.fit performs only a small part (~20%) of its computation with the GIL released. Using threading causes too much congestion, even though its pool management overhead is smaller (1 ms for n_jobs = 2) than that of multiprocessing (20 ms for n_jobs = 1). In this regime, however, using n_jobs > 1 yields at best the same computation time as n_jobs = 1, so I think we should not allow multiprocessing in this case.
  • Working on larger X, y, it is faster to use the threading backend, because the Cython part carries more weight (~60%). The overhead is the same as above, but becomes negligible.

On a side note, I encountered problems with the context manager in multiprocessing mode with large data (see joblib/joblib#229).

We obtain a small improvement working on samples of dimension 409 600 with batch_size = 400, as sparse_encode takes 25% less time. However, in this regime, _update_dict and the auxiliary variable update become more important, which mitigates the improvement: around 10% using 2 cores, and no significant improvement with more cores.

We should thus allow multiprocessing only past a certain threshold of batch_size * n_features in dictionary learning, if we want to allow it at all. I need precise benchmarking to set this threshold.
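
As a hedged illustration of the trade-off described above (this is not code from the PR; shapes, alpha and backend choice are illustrative), a pooled sparse_encode loop with joblib could look like the following, with backend='threading' attractive when the Cython part dominates and plain n_jobs=1 preferable for small batches.

import numpy as np
from joblib import Parallel, delayed
from sklearn.decomposition import sparse_encode

rng = np.random.RandomState(0)
dictionary = rng.randn(30, 64)                     # (n_components, n_features)
batches = [rng.randn(100, 64) for _ in range(20)]  # minibatches of samples

def encode_batch(batch, dictionary):
    return sparse_encode(batch, dictionary, algorithm='lasso_cd', alpha=0.1)

# Using Parallel as a context manager keeps the worker pool alive across
# calls, removing the per-call pool management overhead quoted above.
with Parallel(n_jobs=2, backend='threading') as parallel:
    codes = parallel(delayed(encode_batch)(batch, dictionary)
                     for batch in batches)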

@@ -397,7 +398,8 @@ def fit(self, X, y):
return self


def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
aux_order=None):
Member

I don't understand what "aux" stands for. Is it an acronym? A more explicit name is probably required.

Contributor Author

Basically this flag asks for Fortran-ordered Xy and precompute. auxiliary_matrix_order?
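
For illustration only (the parameter name is still being discussed and this snippet is not from the PR), the flag essentially asks the caller to hand over auxiliary matrices laid out like this:

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(100, 20)
Y = rng.randn(100, 3)                            # multi-target case

# Gram matrix and X^T Y prepared by the caller in Fortran (column-major)
# order, so the Cython coordinate descent can use them without a copy.
precompute = np.asfortranarray(np.dot(X.T, X))   # (n_features, n_features)
Xy = np.asfortranarray(np.dot(X.T, Y))           # (n_features, n_targets)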

@ogrisel
Member
ogrisel commented Aug 19, 2015

In order to move forward, I think we should focus on fixing this PR (address the remaining comments and make the tests pass) and move the introduction of a managed parallel context to a later PR.

@@ -585,7 +599,7 @@ class ElasticNet(LinearModel, RegressorMixin):
def __init__(self, alpha=1.0, l1_ratio=0.5, fit_intercept=True,
normalize=False, precompute=False, max_iter=1000,
copy_X=True, tol=1e-4, warm_start=False, positive=False,
random_state=None, selection='cyclic'):
random_state=None, selection='cyclic', check_input=True):
Member

Other scikit-learn models use check_input only as a per-method parameter rather than as a global constructor parameter. I think we should stay consistent with that pattern.
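
A minimal sketch of that convention (a hypothetical toy estimator, not scikit-learn code): __init__ stores hyperparameters only, and the validation switch lives on the method.

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils import check_X_y

class ToyRegressor(BaseEstimator, RegressorMixin):
    def __init__(self, alpha=1.0):
        self.alpha = alpha                       # hyperparameters only

    def fit(self, X, y, check_input=True):
        if check_input:
            # Callers that already validated their data can opt out here.
            X, y = check_X_y(X, y, dtype=np.float64, order='F')
        self.coef_ = np.zeros(X.shape[1])        # placeholder "fit"
        return self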

@arthurmensch
Contributor Author

OK, I will move multiprocessing to a later PR. For the record, here are the benchmarks I did with multiprocessing.

[benchmark plot: results]

Without log scale:

[benchmark plot: results_linear]

For very large samples we begin to see a small improvement (especially with 5 cores). It is far from outstanding, though.

@arthurmensch
Contributor Author

Comparison with master for 1 core, bypassing checks: green is master, blue is this PR.

[benchmark plot: checks]

Average improvement for a single iteration: x3.47.

The time before the first iteration is a little longer, as we perform the initial input checking within dictionary_learning_online, but this is negligible compared to the single-iteration improvement.

@arthurmensch arthurmensch changed the title [WIP] Removing repeated input checking in Lasso and DictLearning [MRG] Removing repeated input checking in Lasso and DictLearning Aug 19, 2015
@arthurmensch
Contributor Author

I will wait for reviews before opening the multiprocessing PR. Travis does not complain anymore.

I have not looked into the lars code (which is much slower for dictionary learning in scikit-learn, especially with this improvement), but it might be worth making the same changes there.

clf = ElasticNet(selection='cyclic', tol=1e-8)
# Check that no error is raised if data is provided in the right format
clf.fit(X, y, check_input=False)
X = check_array(X, order='F', dtype='float32')
Member

I guess that if you try with the correct dtype but the wrong memory layout:

X = check_array(X, order='C', dtype='float64')

you should not get an error but a wrong result. This is expected, but maybe this should be asserted in the tests as well.
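
A hedged sketch of that extra assertion (not the PR's actual test), written under the assumption stated above that a correct dtype with C-ordered memory is silently accepted when check_input=False; if it turned out to raise instead, the test should assert the ValueError:

import numpy as np
from sklearn.linear_model import ElasticNet
from sklearn.utils import check_array

rng = np.random.RandomState(0)
X = rng.randn(100, 20)
y = np.dot(X, rng.randn(20))                     # non-trivial target

X_f = check_array(X, order='F', dtype='float64')
X_c = check_array(X, order='C', dtype='float64')

clf = ElasticNet(alpha=0.01, fit_intercept=False, tol=1e-8)
coef_ref = clf.fit(X_f, y, check_input=False).coef_.copy()
coef_c = clf.fit(X_c, y, check_input=False).coef_
# Under the assumption above, the unvalidated C-ordered fit runs without
# error but silently yields different (wrong) coefficients.
assert np.any(np.abs(coef_ref - coef_c) > 1e-10)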

@ogrisel
Member
ogrisel commented Aug 19, 2015

This comment has not been addressed: https://github.com/scikit-learn/scikit-learn/pull/5133/files#r37388137

clf = Lasso(alpha=alpha, fit_intercept=False, precompute=gram,
max_iter=max_iter, warm_start=True)
clf = Lasso(alpha=alpha, fit_intercept=False, normalize=False,
precompute=gram, max_iter=max_iter, warm_start=True)
Member

This makes me realise that we always use Lasso with a precomputed Gram matrix in the context of online dictionary learning. Have you checked whether this is actually always a good idea from a performance point of view?

If instead we want to allow the use of coordinate descent with row slices (sample-wise minibatches) of a Fortran-aligned feature array, then we would need to change the Cython code of coordinate descent to deal properly with strided arrays (views) in the daxpy calls if we want to leverage the check_input=False code path safely.

This does not seem to be hard, but maybe this should be investigated in a separate PR.
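
One way to answer the first question would be a quick micro-benchmark along these lines (illustrative sizes and alpha, not part of this PR), comparing a caller-precomputed Gram matrix against plain coordinate descent on a problem shaped like a dictionary learning subproblem:

import time
import numpy as np
from sklearn.linear_model import Lasso

rng = np.random.RandomState(0)
n_samples, n_features = 400, 100                 # e.g. batch_size x n_components
X = np.asfortranarray(rng.randn(n_samples, n_features))
y = np.dot(X, rng.randn(n_features))

gram = np.dot(X.T, X)

for precompute in (gram, False):
    clf = Lasso(alpha=0.1, fit_intercept=False, precompute=precompute,
                max_iter=1000)
    tic = time.time()
    clf.fit(X, y)
    print('precomputed Gram: %s, time: %.4f s'
          % (precompute is not False, time.time() - tic))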

@ogrisel
Member
ogrisel commented Aug 19, 2015

This looks good. Please add an entry in doc/whats_new.rst and mention the 3x speedup you measured.

@ogrisel ogrisel changed the title [MRG] Removing repeated input checking in Lasso and DictLearning [MRG+1] Removing repeated input checking in Lasso and DictLearning Aug 19, 2015
@ogrisel
Member
ogrisel commented Aug 19, 2015

Also please squash this PR into a single commit.

@@ -77,7 +78,7 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
return X, y, X_mean, y_mean, X_std


def center_data(X, y, fit_intercept, normalize=False, copy=True,
def center_data(X, y, fit_intercept, normalize=False, copy=True,
Member

Oops, extra space.

coef_false = clf.coef_
clf.fit(X, y, check_input=True)
coef_true = clf.coef_
assert_true(np.any(coef_true != coef_false))
Member

use assert_array_almost_equal

@agramfort
Member

Besides that, LGTM if Travis is happy.

@arthurmensch
Contributor Author

Done.

ogrisel added a commit that referenced this pull request Aug 20, 2015
[MRG+1] Removing repeated input checking in Lasso and DictLearning
@ogrisel ogrisel merged commit 393d651 into scikit-learn:master Aug 20, 2015
@ogrisel
Member
ogrisel commented Aug 20, 2015

Great, thanks for the optimization @arthurmensch!
