[MRG+2] Make CD use fused types by yenchenlin · Pull Request #6913 · scikit-learn/scikit-learn · GitHub

[MRG+2] Make CD use fused types #6913


Merged · 50 commits merged into scikit-learn:master from the cd-fused-types branch on Aug 25, 2016

Conversation

@yenchenlin (Contributor) commented Jun 21, 2016

According to #5464, the current implementation of ElasticNet and Lasso in scikit-learn constrains the input to be np.float64, which wastes memory.

This PR makes the CD algorithms support fused types when fitting np.float32 dense data, and therefore avoids a redundant data copy.

  • Make inline helper functions support fused types
  • Make dense CD ElasticNet support fused types
  • Add a warning when alpha is close to zero and X is np.float32
  • Add tests
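
For illustration, the intended user-visible effect can be sanity-checked with a small script like the one below (a sketch, not part of the PR; the printed dtype is expected to stay float32 on this branch, while master upcasts to float64):

import numpy as np
from sklearn.linear_model import ElasticNet

X = np.random.rand(1000, 20).astype(np.float32)
y = np.random.rand(1000).astype(np.float32)

clf = ElasticNet(alpha=0.1)
clf.fit(X, y)

# with fused-type CD the fit runs in single precision, so the learned
# coefficients keep the input dtype instead of being upcast
print(clf.coef_.dtype)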

UPDATE 7/7

Here are the memory profiling results when fitting np.float32 data:

  • master: [memory profiling screenshot]
  • this branch: [memory profiling screenshot, float32]

UPDATE 7/12

Here are the memory profiling results when fitting sparse np.float32 data:

  • master: [memory profiling screenshot, float64]
  • this branch: [memory profiling screenshot, float32]

@agramfort (Member)

Do you expect to merge just this, or is it WIP?

@yenchenlin (Contributor Author)

@agramfort Yeah, I am thinking of merging just these as a starting point.

@yenchenlin yenchenlin changed the title Make helper functions in cd use fused types [MRG] Make helper functions in cd use fused types Jun 21, 2016
@jnothman (Member) commented Jun 21, 2016

I'm not sure what value there is in merging it separately from something we can benchmark. For instance, you've fused fmax and fsign, but reused the libc fabs, which explicitly operates on a double (as opposed to fabsl). Are you sure we benefit from fused implementations of max and sign?

So I think we want to review this as a whole.
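
For reference, fused-type versions of such helpers might look roughly like the sketch below (illustrative only, not the PR's actual code; the extern declarations are included just to keep the example self-contained):

from cython cimport floating   # fused type covering float and double

cdef extern from "math.h":
    float fabsf(float x) nogil
    double fabs(double x) nogil

cdef inline floating fused_fmax(floating x, floating y) nogil:
    if x > y:
        return x
    return y

cdef inline floating fused_fsign(floating f) nogil:
    if f == 0:
        return 0
    elif f > 0:
        return 1
    return -1

cdef inline floating fused_fabs(floating x) nogil:
    # the branch is resolved at compile time, so float32 values never
    # take a round trip through double
    if floating is float:
        return fabsf(x)
    return fabs(x)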

@yenchenlin (Contributor Author) commented Jun 23, 2016

Thanks @jnothman @agramfort.
Ah yeah, you are right. I am working on it! 💪

@yenchenlin yenchenlin changed the title [MRG] Make helper functions in cd use fused types [WIP] Make CD use fused types Jun 23, 2016
@yenchenlin yenchenlin force-pushed the cd-fused-types branch 2 times, most recently from 0bab8d3 to bb6ec9b Compare July 1, 2016 23:49
@yenchenlin (Contributor Author) commented Jul 1, 2016

UPDATE 7/7

It is now working! (Previous status: "Currently it is still not working.")

Here is my test script:

import numpy as np
from sklearn.linear_model.coordinate_descent import ElasticNet

# @profile comes from memory_profiler; run the script with
#   python -m memory_profiler this_script.py
@profile
def fit_est():
    clf.fit(X, y)

np.random.seed(5)
X = np.random.rand(2000000, 40)
X = np.float32(X)
y = np.random.rand(2000000)
y = np.float32(y)
T = np.random.rand(5, 40)
T = np.float32(T)

clf = ElasticNet(alpha=1e-7, l1_ratio=1.0, precompute=False)
fit_est()
pred = clf.predict(T)
print(pred)

@yenchenlin yenchenlin force-pushed the cd-fused-types branch 4 times, most recently from 8f79269 to 61a7938 Compare July 2, 2016 14:10

Diff excerpt under review:

    # np.dot(R.T, y)
    gap += (alpha * l1_norm - const * ddot(
    if floating is double:
Member

Are you absolutely certain we can't do this with fused types? What prohibits it?

Contributor Author

This algorithm uses lots of C pointers such as <DOUBLE*>; we can't do <floating*>.

Member

You're saying Cython disallows fused type pointers, or typecasts? Is that incapability documented?

Could we use typecasts if we were working with typed memoryviews?
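
Cython does support fused-type pointers in cdef function signatures; a minimal sketch of dispatching to the matching-precision CBLAS call might look like this (assumed names and extern declarations, not the code that ended up being merged):

from cython cimport floating

cdef extern from "cblas.h":
    float cblas_sdot(int N, float *X, int incX, float *Y, int incY) nogil
    double cblas_ddot(int N, double *X, int incX, double *Y, int incY) nogil

cdef floating dot(int n, floating *x, int incx, floating *y, int incy) nogil:
    # `floating is float` is evaluated when Cython generates the two
    # specializations, so each one calls the routine of matching precision
    # and no <DOUBLE*> casts or data copies are needed
    if floating is float:
        return cblas_sdot(n, x, incx, y, incy)
    return cblas_ddot(n, x, incx, y, incy)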

@jnothman (Member) commented Jul 4, 2016

Could you use line-based memory profiling to see where that sharp increase in memory consumption is coming from?

@jnothman (Member) commented Jul 4, 2016

Sorry, that was silly; only appropriate if the bad memory usage is in Python code.

Diff excerpt under review:

-    fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
-    normalize and not np.allclose(X_scale, np.ones(n_features))):
+    fit_intercept and not np.allclose(X_offset, np.zeros(n_features))
+    or normalize and not np.allclose(X_scale, np.ones(n_features))):
Member

I don't see why this is an improvement, or why it's in this PR.

@yenchenlin yenchenlin force-pushed the cd-fused-types branch 2 times, most recently from 92e2aad to 82f6f9c Compare July 7, 2016 14:55
@yenchenlin yenchenlin changed the title [WIP] Make CD use fused types [WIP] Make dense CD use fused types Jul 7, 2016
@yenchenlin (Contributor Author)

Hello @jnothman & @MechCoder, thanks a lot for your patience and comments.

I've updated the PR description (including the new memory profiling results) and the code, and addressed the comments you gave before.

The remaining to-do tasks are, in my opinion, the ones listed in the main description of this PR.

@yenchenlin yenchenlin changed the title [WIP] Make dense CD use fused types [MRG] Make dense CD use fused types Jul 8, 2016
@yenchenlin (Contributor Author) commented Jul 8, 2016

I've added the tests and the user warning for potential non-convergence when fitting np.float32 data with a small alpha.

However, the CI looks weird; any idea?
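
For illustration, the kind of check such a test can perform is sketched below (data sizes and tolerance are assumptions, not the exact test code): fit the same data in float32 and float64 and require the coefficients to agree.

import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.linear_model import ElasticNet

rng = np.random.RandomState(0)
X64 = rng.rand(200, 10)
y64 = rng.rand(200)
X32 = X64.astype(np.float32)
y32 = y64.astype(np.float32)

coef64 = ElasticNet(alpha=0.1).fit(X64, y64).coef_
coef32 = ElasticNet(alpha=0.1).fit(X32, y32).coef_

# single- and double-precision fits should agree up to float32 accuracy
assert_array_almost_equal(coef32, coef64, decimal=4)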

@yenchenlin (Contributor Author)

@jnothman done!

@jnothman (Member)

Have you forgotten to push those last changes? whats_new does not appear to be updated, nor does the warning.

@@ -474,7 +474,8 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
             warnings.warn('Objective did not converge.' +
                           ' You might want' +
                           ' to increase the number of iterations.' +
-                          ' Fitting data with alpha near zero, e.g., 1e-8,' +
+                          ' Fitting float32 data with alpha near zero,' +
Member

But this is only relevant if the data is float32, no?

Contributor Author
@yenchenlin yenchenlin Aug 24, 2016

Actually, I think that when fitting with a really small alpha, e.g., 1e-20, even float64 data may not converge.

Member

Sure, so make the warning as relevant and useful as possible to a user that triggers it.

Contributor Author

So is simply removing the float32 enough 😛?

Member
@jnothman jnothman Aug 24, 2016

Not really, because alpha=1e-8 isn't ordinarily too small for normalized float64. Either remove the reference to the alpha value or check appropriate conditions for the message; then it will be a much more meaningful message. Also, "alpha near zero" would usually be "very small alpha".

Contributor Author

I have to admit that I'm not very sure about all the factors that can cause convergence issues, and thus don't dare to pick a specific reference value.

Or we could remove the reference to the alpha value?

Member

Remove any specific value. Just say near zero


Contributor Author

done!

@@ -470,7 +473,10 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
         if dual_gap_ > eps_:
             warnings.warn('Objective did not converge.' +
                           ' You might want' +
-                          ' to increase the number of iterations',
+                          ' to increase the number of iterations.' +
+                          ' Fitting data with alpha near zero,' +
Member

Did we not agree to add this message only when alpha is less than some heuristic value?

Contributor Author
@yenchenlin yenchenlin Aug 25, 2016

It seems hard to determine a heuristic value 😭

There are too many factors.
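
For context, the conditional form discussed above could look roughly like the sketch below (illustrative only: the threshold is a made-up placeholder, which is precisely the value the thread could not agree on, and the merged code keeps an unconditional message):

import warnings
import numpy as np

# hypothetical cutoff; the discussion concluded no single value is defensible
SMALL_ALPHA = 1e-6

def maybe_warn(dual_gap, eps, alpha, dtype):
    if dual_gap > eps:
        msg = ('Objective did not converge. You might want to increase '
               'the number of iterations.')
        if dtype == np.float32 and alpha < SMALL_ALPHA:
            msg += (' Fitting float32 data with a very small alpha may '
                    'cause precision problems.')
        warnings.warn(msg)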

@MechCoder (Member) commented Aug 24, 2016

Your benchmarks in the description of the pull request suggest non-trivial speed gains. Do the speed gains still hold?

@yenchenlin (Contributor Author)

@MechCoder yes!

  • float32: [benchmark screenshot]
  • float64: [benchmark screenshot]
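
For reference, a rough way to reproduce such a timing comparison (an illustrative script, not the benchmark behind the screenshots):

import time
import numpy as np
from sklearn.linear_model import ElasticNet

rng = np.random.RandomState(0)
X64 = rng.rand(500000, 40)
y64 = rng.rand(500000)
X32, y32 = X64.astype(np.float32), y64.astype(np.float32)

for name, X, y in [('float64', X64, y64), ('float32', X32, y32)]:
    t0 = time.time()
    ElasticNet(alpha=0.01, l1_ratio=1.0).fit(X, y)
    print('%s: %.2f s' % (name, time.time() - t0))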

@MechCoder (Member)

Awesome! Merging with master and thanks a lot for your perseverance.

@MechCoder MechCoder merged commit 084ef97 into scikit-learn:master Aug 25, 2016
@yenchenlin (Contributor Author) commented Aug 25, 2016

😭😭 😭

🍻🍻🍻

@MechCoder (Member)

It would be worth adding a note to src/cblas/README.txt explaining what changes have to be made in order to call additional cblas functions internally. Maybe @fabianp can do that?

@jnothman (Member)

Hurrah! Thanks @fabianp for rescuing this. And to @MechCoder for inviting that saviour. And to @yenchenlin for winning.

@MechCoder (Member) commented Aug 26, 2016

And to you the dark knight? :p

@jnothman (Member) commented Sep 6, 2016

Just a heads up that I'm a little concerned that these changes to Lasso changed its behaviour (for float64 data). It seems to have resulted in a test failure at #6717 (comment). For the dummy data I've tried so far, behaviour isn't changed, so this needs more verification.

TomDLT pushed a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016
…learn#6913)

ElasticNet and Lasso no longer implicitly convert float32 dtype input to float64 internally.

* Make helper functions in cd use fused types

* Import cblas float functions

* Make enet_coordinate_descent support fused types

* Make dense case work

* Refactor format

* Remove redundant change

* Add cblas files

* Avoid redundant code

* Remove redundant c files and import

* Recover unnecessary change

* Update comment

* Make coef_ type consistent

* Test float32 input

* Add user warning when fitting float32 data with small alpha

* Fix bug

* Change variable to floating type

* Make cd sparse support fused types

* Make CD support fused types when data is sparse

* Add referenced src files

* Avoid duplicated code

* Avoid type casting

* Fix indentation in test

* Avoid type casting in sparse implementation

* Fix indentation

* Fix duplicated initialization code

* Follow PEP8

* Raise tmp precision to double

* Add 64 bit computer check

* Fix test

* Add constraint

* PEP 8

* Make saxpy have the same structure as daxpy

Hopefully this fixes the problems outlined in PR scikit-learn#6913

* Remove wrong hardware test

* Remove dsdot

* Remove redundant asarray

* Add test for fit_intercept

* Make _preprocess_data support other dtypes

* Add concrete value

* Workaround

* Fix error msg

* Move declaration

* Remove redundant comment

* Add tests

* Test normalize

* Delete warning

* Fix comment

* Add error msg

* Add error msg

* Add what's new

* Fix error msg