[MRG]: MAINT center_data for linear models by giorgiop · Pull Request #5357 · scikit-learn/scikit-learn · GitHub

[MRG]: MAINT center_data for linear models #5357


Closed

Conversation

giorgiop
Contributor
@giorgiop giorgiop commented Oct 7, 2015

Fixes #2601 no more.

TO DO

  • update the docstring to say that we do normalization, to reduce surprise
  • clean and factor center_data and sparse_center_data into a new private function, and deprecate them
  • add necessary deprecation, update docs
  • review all the # TODO and # XXX

AND NOT TO DO (see #2601 and discussion below):

(image)

@giorgiop giorgiop force-pushed the center-data-normalization branch 2 times, most recently from 23273f5 to 876c8c8 on October 7, 2015 14:39
@giorgiop
Contributor Author
giorgiop commented Oct 7, 2015

Regarding point (a). The snippet below is edited from this test.

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)

n_samples, n_features = (6, 5)
y = rng.randn(n_samples)
X = rng.randn(n_samples, n_features)
sample_weight = 1.0 + rng.rand(n_samples)

clf = LinearRegression()
clf.fit(X, y, sample_weight)
coefs1 = clf.coef_

# Sample weight can be implemented via a simple rescaling
# for the square loss.
scaled_y = y * np.sqrt(sample_weight)
scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
clf.fit(scaled_X, scaled_y)
coefs2 = clf.coef_

print(coefs1)
# [ 4.58791686 -4.2095038   0.39031788  3.2727146  -0.17386704]
print(coefs2)
# [ 3.69763237 -3.64824351  0.367363    2.97550307 -0.44881672]
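
For reference, the rescaling trick rests on a plain algebraic identity, valid for any fixed $\beta$ (my gloss; note it covers the intercept only if the constant column is rescaled by $\sqrt{w_i}$ as well, which is why the centering step must use weighted means):

$\sum_i w_i (y_i - x_i^\top \beta)^2 = \sum_i \left(\sqrt{w_i}\, y_i - (\sqrt{w_i}\, x_i)^\top \beta\right)^2$

So the two fits above should yield identical coef_; the mismatch printed here is the inconsistency under discussion.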

@giorgiop giorgiop force-pushed the center-data-normalization branch 3 times, most recently from 492c5cb to 73b30a5 on October 7, 2015 15:48
@MechCoder
Member

Were you able to figure out the bugs in omp?

@giorgiop
Contributor Author
giorgiop commented Oct 8, 2015

Which issue are you referring to?

@MechCoder
Member

In the Travis build. We were trying to fix this issue before, but since omp uses center_data, there were some failures caused by the different scaling of X.

@giorgiop giorgiop force-pushed the center-data-normalization branch 2 times, most recently from 682b219 to f1356ff on October 9, 2015 15:20
@giorgiop
Contributor Author

Question: can we just implement the center_data function with preprocessing.scale?
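
For context, here is a minimal sketch of what the dense path must return (my simplification, not the actual implementation; the real function also handles sparse input, copying, and sample weights). It shows why preprocessing.scale alone is not enough: the models need the offsets and scales back to recover the intercept after fitting.

import numpy as np

def center_data_sketch(X, y, fit_intercept=True, normalize=False):
    # Hypothetical simplification of the dense path.
    if fit_intercept:
        X_mean = X.mean(axis=0)
        y_mean = y.mean()
        X = X - X_mean
        y = y - y_mean
    else:
        X_mean = np.zeros(X.shape[1])
        y_mean = 0.0
    if normalize:
        # Column L2 norms of the *centered* data -- norms, not
        # standard deviations; that distinction is the heart of
        # this whole discussion.
        X_scale = np.sqrt((X ** 2).sum(axis=0))
        X_scale[X_scale == 0] = 1.0
        X = X / X_scale
    else:
        X_scale = np.ones(X.shape[1])
    return X, y, X_mean, y_mean, X_scale

# Intercept recovery, which preprocessing.scale cannot support
# because it returns only the transformed array:
#   coef_ = coef_on_scaled_data / X_scale
#   intercept_ = y_mean - X_mean.dot(coef_)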

@giorgiop giorgiop force-pushed the center-data-normalization branch from 05a9d0e to 6e8290b on October 13, 2015 11:38
@giorgiop
Contributor Author

Down to 81 errors.

@giorgiop
Contributor Author

@amueller:

In which cases does the default behavior give a deprecation warning? We are not changing the default behavior, right? Only renaming normalize=False to standardize=False, so there shouldn't be deprecation warnings with the default behavior, if I read it correctly.

I found Lars and LassoLars with normalize=True by default. Not sure what's best to do here.
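
For illustration, the rename pattern under discussion might look like the sketch below (hypothetical helper and names, not code from this PR; the rename was ultimately dropped, see further down). With this pattern the default path emits no warning, except in estimators like Lars where the old default is normalize=True and the behavior itself would change.

import warnings

def _resolve_standardize(normalize='deprecated', standardize=False):
    # Hypothetical helper: honor the old parameter if the user set
    # it, warning that it is deprecated; otherwise use the new one.
    if normalize != 'deprecated':
        warnings.warn("'normalize' is deprecated, use 'standardize'",
                      DeprecationWarning)
        return normalize
    return standardize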

@giorgiop giorgiop force-pushed the center-data-normalization branch 4 times, most recently from 7176457 to 610c21b on October 13, 2015 15:09
@giorgiop
Contributor Author

Down to 35.

@amueller
Member

If LassoLars has normalize=True, LassoLarsCV hopefully has normalize=True, too, right?
I don't have any better idea than biting the bullet and having the deprecation in Lars. @GaelVaroquaux @agramfort @jnothman any better ideas?

@giorgiop
Contributor Author

If LassoLars has normalize=True, LassoLarsCV hopefully has normalize=True, too, right?

Yes, all classes in least_angle.py.

@giorgiop giorgiop force-pushed the center-data-normalization branch from 610c21b to c183cf2 on October 14, 2015 12:16
@amueller
Member

@GaelVaroquaux as you are around, do you have a better idea than to warn by default in lars? I don't like it.

@GaelVaroquaux
Member

@GaelVaroquaux as you are around, do you have a better idea than to warn by default in lars? I don't like it.

What's the question? (I have a grant proposal to write before the sprint,
so I really shouldn't be spending time on scikit-learn. It's just more
fun than the grant proposal).

@amueller
Member

We are deprecating normalize in all linear models, changing it to standardize, which is more consistent. Unfortunately normalize=True in Lars, so that means Lars().fit(X, y) will give a deprecation warning, telling the user that the behavior will change because normalize is deprecated. I don't like that, but I don't see a better way.

@amueller
Member

@GaelVaroquaux Also good luck :)

@GaelVaroquaux
Member
GaelVaroquaux commented Oct 14, 2015 via email

@GaelVaroquaux
Member

We are deprecating normalize in all linear models, changing it to standardize, which is more consistent.

As I mentioned in my comment on that issue, this is really a bad idea. I spent an hour today with Giorgio and Olivier to convince myself of this.

The whole value of this option is that it makes linear models robust in their parameter setting, unlike chaining with a StandardScaler. If people want to standardize, they should use StandardScaler, but that will make their parameters fragile to the sample size.

We convinced ourselves that we shouldn't make this change, and should keep the current behavior of scikit-learn.

I am sorry, I said 👍 to this change a few days ago. But Giorgio hit snags while implementing it that forced us to rethink, and that was actually a good thing.
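
A tiny illustration of the point (my sketch, not part of the PR): normalize divides each column by its L2 norm, which grows like sqrt(n_samples) for a unit-variance column, whereas StandardScaler divides by the standard deviation, which does not grow with the sample size. The two conventions thus differ by a factor of about sqrt(n_samples), and so does the alpha that achieves a comparable amount of regularization.

import numpy as np

rng = np.random.RandomState(0)
for n_samples in (100, 10000):
    x = rng.randn(n_samples)            # one roughly centered feature
    unit_norm = x / np.linalg.norm(x)   # what normalize=True does
    unit_std = x / x.std()              # what StandardScaler does
    # Ratio between the two scalings: ~ sqrt(n_samples), the factor
    # by which alpha would have to be retuned when standardizing.
    print(n_samples, np.linalg.norm(unit_std) / np.linalg.norm(unit_norm))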

@agramfort
Member
agramfort commented Oct 14, 2015 via email

     sample_weight = np.sqrt(sample_weight)
-    sw_matrix = sparse.dia_matrix((sample_weight, 0),
-                                  shape=(n_samples, n_samples))
+    sw_matrix = np.diag(sample_weight)
Member

this will blow up the memory if you have a large number of samples.

Member

Indeed, what was the problem with the existing implementation in master?

Contributor Author

Sorry my bad, I don't remember what I was thinking at that time.
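
For reference, a memory-friendly version avoids materializing an (n_samples, n_samples) matrix altogether; a sketch (mine, not the PR's final code):

import numpy as np
from scipy import sparse

def rescale_rows(X, sample_weight):
    # Multiply row i of X by sqrt(w_i) without ever building an
    # (n_samples, n_samples) matrix, dense or sparse.
    sw = np.sqrt(sample_weight)
    if sparse.issparse(X):
        n = sw.shape[0]
        # dia_matrix stores only the diagonal: O(n_samples) memory
        return sparse.dia_matrix((sw, 0), shape=(n, n)).dot(X)
    return X * sw[:, np.newaxis]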

@agramfort
Member

that's it for me now. Sorry for the super slow reaction time...

ping me when I should take another look.

thanks @giorgiop !

@giorgiop giorgiop force-pushed the center-data-normalization branch from c29c00b to ea666c5 on February 8, 2016 23:58
@giorgiop
Contributor Author
giorgiop commented Feb 9, 2016

@agramfort I finally had time to resume this one :)
Do I need to worry about CI complaining about decreased coverage?

@agramfort
Member

looks ok at first sight.

maybe @ogrisel has some time to have a look?

@@ -360,6 +425,13 @@ class LinearRegression(LinearModel, RegressorMixin):

normalize : boolean, optional, default False
If True, the regressors X will be normalized before regression.
When the regressors are normalized, the fitted `coef_` are the same
Member

I don't quite understand this. Did you mean "the fitted coef_ are of the same scale"? How will they be the same?

Contributor Author

I agree this may not be the best wording. The discussion is in #5447 with @GaelVaroquaux. Let me know if you have a better, more concise way to explain this property here.

@giorgiop giorgiop force-pushed the center-data-normalization branch from effdf97 to c822947 on February 12, 2016 00:24
@@ -360,6 +426,14 @@ class LinearRegression(LinearModel, RegressorMixin):

normalize : boolean, optional, default False
If True, the regressors X will be normalized before regression.
This parameter is ignored when `fit_intercept` is set to `False`.
Member

OK. This is what I infer: when X is scaled down to unit variance, the squared loss term goes up by a certain factor (I'm not sure we can determine it exactly, because y is not normalized), hence for different numbers of samples, alpha has to be set differently to achieve the same effect. Even when X is scaled to unit length, there is this effect, but to a much lesser extent.

Should we just write something like "Setting normalize=True scales down the input regressors to unit length. Note that this makes the hyperparameters learnt more robust and almost independent of the number of samples. The same property contd ..."?
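
In symbols (my gloss on the inference above, taking ridge regression as a concrete case): the ridge solution is $\hat\beta = (X^\top X + \alpha I)^{-1} X^\top y$. With unit-variance columns, $\mathrm{diag}(X^\top X) \approx n_{\mathrm{samples}}$, so a fixed $\alpha$ shrinks relative to the data term as the sample size grows and must be retuned; with unit-L2-norm columns, $\mathrm{diag}(X^\top X) = 1$ for any sample size, so the same $\alpha$ keeps a comparable effect.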

Contributor Author

The last word here goes to @GaelVaroquaux!

Member

I am fine with the current formulation.

Member

It's just the "fitted coef_ are the same" wording that I am worried about.

Member

indeed. "same" is not correct.

+1 for

Note that this makes the hyperparameters learnt more robust and almost independent of the number of samples.

thx @MechCoder

Contributor Author

OK!

@MechCoder
Member

Just left 2 comments.

# inplace_csr_row_normalize_l2 must be changed such that it
# can also return the norms computed internally

# transform variance to norm in-place
Member

We can use csr_row_norms which already does this, but we can leave that for later.
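
For reference, the variance-to-norm conversion mentioned in the comment is cheap because, for (implicitly) centered data, $\|x_j\|_2 = \sqrt{n_{\mathrm{samples}} \cdot \mathrm{var}_j}$. A sketch using existing helpers (the example data is mine):

import numpy as np
from scipy import sparse
from sklearn.utils.sparsefuncs import mean_variance_axis

X = sparse.rand(20, 5, density=0.5, format='csr', random_state=0)
_, var = mean_variance_axis(X, axis=0)

# "transform variance to norm in-place": reuse the variance buffer.
norms = var
norms *= X.shape[0]
np.sqrt(norms, out=norms)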

@giorgiop giorgiop force-pushed the center-data-normalization branch 2 times, most recently from 43d0f2c to d40f1f1 on February 15, 2016 03:28
@MechCoder
Member

@giorgiop Can you fix the "fitted coef_ are the same" comment for all estimators? After that we can merge.

@giorgiop giorgiop force-pushed the center-data-normalization branch from d40f1f1 to ec0a97a on February 15, 2016 21:24
@giorgiop giorgiop force-pushed the center-data-normalization branch from ec0a97a to acdd1aa on February 15, 2016 21:27
@giorgiop
Contributor Author

Done!

@MechCoder
Member

Merged @giorgiop

@MechCoder MechCoder closed this Feb 17, 2016
@agramfort
Member

thanks heaps @giorgiop ! 🍻

@giorgiop
Contributor Author

Thanks to all reviewers for the combined effort!
