[MRG]: MAINT center_data for linear models by giorgiop · Pull Request #5357 · scikit-learn/scikit-learn

Closed

2 changes: 1 addition & 1 deletion doc/datasets/rcv1.rst
@@ -35,7 +35,7 @@ Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596
array([2286, 2287, 2288], dtype=int32)

``target_names``:
The target values are the topics of each sample. Each sample belongs to at least one topic, and up to 17 topics.
There are 103 topics, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::

>>> rcv1.target_names[:3].tolist() # doctest: +SKIP
162 changes: 118 additions & 44 deletions sklearn/linear_model/base.py
@@ -10,7 +10,7 @@
# Mathieu Blondel <mathieu@mblondel.org>
# Lars Buitinck <L.J.Buitinck@uva.nl>
# Maryan Morel <maryan.morel@polytechnique.edu>
#
# Giorgio Patrini <giorgio.patrini@anu.edu.au>
# License: BSD 3 clause

from __future__ import division
@@ -26,20 +26,16 @@
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..utils import as_float_array, check_array, check_X_y, deprecated
from ..utils import check_random_state, column_or_1d
from ..utils import check_array, check_X_y, deprecated, as_float_array
from ..utils.validation import FLOAT_DTYPES
from ..utils import check_random_state
from ..utils.extmath import safe_sparse_dot
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
from ..utils.fixes import sparse_lsqr
from ..utils.seq_dataset import ArrayDataset, CSRDataset
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError


#
# TODO: intercept for all models
# We should define a common function to center data instead of
# repeating the same code inside each fit method.
from ..preprocessing.data import normalize as f_normalize

# TODO: bayesian_ridge_regression and bayesian_regression_ard
# should be squashed into its respective objects.
@@ -71,6 +67,8 @@ def make_dataset(X, y, sample_weight, random_state=None):
return dataset, intercept_decay


@deprecated("sparse_center_data will be removed in "
"0.20. Use utilities in preprocessing.data instead")
def sparse_center_data(X, y, fit_intercept, normalize=False):
"""
Compute information needed to center data to have mean zero along
@@ -87,33 +85,33 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
else:
X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)

X_mean, X_var = mean_variance_axis(X, axis=0)
X_offset, X_var = mean_variance_axis(X, axis=0)
if normalize:
# transform variance to std in-place
# XXX: currently scaled to variance=n_samples to match center_data
X_var *= X.shape[0]
X_std = np.sqrt(X_var, X_var)
del X_var
X_std[X_std == 0] = 1
inplace_column_scale(X, 1. / X_std)
else:
X_std = np.ones(X.shape[1])
y_mean = y.mean(axis=0)
y = y - y_mean
y_offset = y.mean(axis=0)
y = y - y_offset
else:
X_mean = np.zeros(X.shape[1])
X_offset = np.zeros(X.shape[1])
X_std = np.ones(X.shape[1])
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)

return X, y, X_mean, y_mean, X_std
return X, y, X_offset, y_offset, X_std
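
For reference, a minimal usage sketch of the deprecated helper (not part of this diff; it assumes the 0.17-era import path `sklearn.linear_model.base`): with sparse input the matrix itself is left uncentered, only y is shifted, and the offsets come back for the later intercept computation.

```python
import numpy as np
import scipy.sparse as sp
from sklearn.linear_model.base import sparse_center_data  # deprecated path

X = sp.csc_matrix(np.array([[1., 0.], [3., 2.], [5., 4.]]))
y = np.array([1., 2., 3.])
Xs, ys, X_offset, y_offset, X_std = sparse_center_data(X, y,
                                                       fit_intercept=True)
# X stays sparse and uncentered; the column means come back as X_offset
print(X_offset)  # [ 3.  2.]
print(y_offset)  # 2.0
print(ys)        # [-1.  0.  1.]  -- only y is actually centered
```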


@deprecated("center_data will be removed in "
"0.20. Use utilities in preprocessing.data instead")
def center_data(X, y, fit_intercept, normalize=False, copy=True,
sample_weight=None):
"""
Centers data to have mean zero along axis 0. This is here because
nearly all linear models will want their data to be centered.

If sample_weight is not None, then the weighted mean of X and y
is set to zero, rather than the plain mean.
"""
@@ -122,26 +120,95 @@ def center_data(X, y, fit_intercept, normalize=False, copy=True,
if isinstance(sample_weight, numbers.Number):
sample_weight = None
if sp.issparse(X):
X_mean = np.zeros(X.shape[1])
X_offset = np.zeros(X.shape[1])
X_std = np.ones(X.shape[1])
else:
X_mean = np.average(X, axis=0, weights=sample_weight)
X -= X_mean
X_offset = np.average(X, axis=0, weights=sample_weight)
X -= X_offset
# XXX: currently scaled to variance=n_samples
if normalize:
# XXX: currently scaled to variance=n_samples
X_std = np.sqrt(np.sum(X ** 2, axis=0))
X_std[X_std == 0] = 1
Member: I would keep the XXX. It can die along with the function in 0.20 ;)

Member: Oh, I would just add a comment saying that we are scaling to unit length.

Author: Both the comments?

X /= X_std
else:
X_std = np.ones(X.shape[1])
y_mean = np.average(y, axis=0, weights=sample_weight)
y = y - y_mean
y_offset = np.average(y, axis=0, weights=sample_weight)
y = y - y_offset
else:
X_mean = np.zeros(X.shape[1])
X_offset = np.zeros(X.shape[1])
X_std = np.ones(X.shape[1])
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
return X, y, X_mean, y_mean, X_std
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
return X, y, X_offset, y_offset, X_std


def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
sample_weight=None, return_mean=False):
"""
Centers data to have mean zero along axis 0. If fit_intercept=False or if
X is a sparse matrix, no centering is done, but normalization can still be
applied. The function returns the statistics necessary to reconstruct
the input data, which are X_offset, y_offset, X_scale, such that the output

X = (X - X_offset) / X_scale
Member: You should mention that X_scale here is the L2 norm of X - X_offset.


X_scale is the L2 norm of X - X_offset. If sample_weight is not None,
then the weighted mean of X and y is set to zero, rather than the plain
mean. If return_mean=True, the mean (possibly weighted) is returned,
independently of whether X was centered (an option used for optimization
with sparse data in coordinate_descent).

This is here because nearly all linear models will want their data to be
centered.
"""

if isinstance(sample_weight, numbers.Number):
sample_weight = None

X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
dtype=FLOAT_DTYPES)

if fit_intercept:
if sp.issparse(X):
X_offset, X_var = mean_variance_axis(X, axis=0)
if not return_mean:
X_offset = np.zeros(X.shape[1])

if normalize:

# TODO: f_normalize could be used here as well but the function
# inplace_csr_row_normalize_l2 must be changed such that it
# can return also the norms computed internally

# transform variance to norm in-place
Member: We can use csr_row_norms, which already does this, but we can leave that for later.

X_var *= X.shape[0]
X_scale = np.sqrt(X_var, X_var)
del X_var
X_scale[X_scale == 0] = 1
inplace_column_scale(X, 1. / X_scale)
else:
X_scale = np.ones(X.shape[1])
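
The in-place variance-to-norm conversion above works because n * Var(x) = ||x - mean(x)||^2, so the norm of the (virtually) centered column is obtained without ever densifying X. A quick numpy check of that identity (a sketch, not code from this PR):

```python
import numpy as np

col = np.array([1., 0., 3., 2.])
n = col.shape[0]
# what the diff computes: sqrt(n_samples * column variance)
lhs = np.sqrt(n * col.var())
# the L2 norm that explicit centering would have produced
rhs = np.linalg.norm(col - col.mean())
assert np.isclose(lhs, rhs)
```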

else:
X_offset = np.average(X, axis=0, weights=sample_weight)
X -= X_offset
if normalize:
X, X_scale = f_normalize(X, axis=0, copy=False,
return_norm=True)
else:
X_scale = np.ones(X.shape[1])
y_offset = np.average(y, axis=0, weights=sample_weight)
y = y - y_offset
else:
X_offset = np.zeros(X.shape[1])
X_scale = np.ones(X.shape[1])
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)

return X, y, X_offset, y_offset, X_scale
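
As a sanity check on the contract in the docstring, a small dense example (a sketch; it assumes `_preprocess_data` exactly as defined above): the returned statistics invert the preprocessing via X_orig = X * X_scale + X_offset, and with normalize=True X_scale is the L2 norm of the centered columns.

```python
import numpy as np
from sklearn.linear_model.base import _preprocess_data  # private helper

X = np.array([[1., 2.], [3., 4.], [5., 6.]])
y = np.array([1., 2., 3.])
Xp, yp, X_offset, y_offset, X_scale = _preprocess_data(
    X, y, fit_intercept=True, normalize=True)

assert np.allclose(Xp * X_scale + X_offset, X)  # stats reconstruct the input
assert np.allclose(yp + y_offset, y)
assert np.allclose(X_scale, np.linalg.norm(X - X_offset, axis=0))
```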


# TODO: _rescale_data should be factored into _preprocess_data.
# Currently, the fact that sag implements its own way to deal with
# sample_weight makes the refactoring tricky.

def _rescale_data(X, y, sample_weight):
"""Rescale data so as to support sample_weight"""
@@ -200,14 +267,14 @@ def predict(self, X):
"""
return self._decision_function(X)

_center_data = staticmethod(center_data)
_preprocess_data = staticmethod(_preprocess_data)

def _set_intercept(self, X_mean, y_mean, X_std):
def _set_intercept(self, X_offset, y_offset, X_scale):
"""Set the intercept_
"""
if self.fit_intercept:
self.coef_ = self.coef_ / X_std
self.intercept_ = y_mean - np.dot(X_mean, self.coef_.T)
self.coef_ = self.coef_ / X_scale
Member: Can you remind me why we do this? I have a thought in mind, but I do not want the answer to be biased by my thoughts :-)

Member: You fit the intercept by centering the data. Then you undo the centering to get the intercept of the original problem.

Member: Not the intercept; I meant the division by X_scale. Aren't we already scaling X by X_scale?

Member: Yes, but here you scale the coef_, something like this:

X * scaling * (1/scaling) * coef_

self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
else:
self.intercept_ = 0.

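A numeric illustration of the exchange above (a plain-numpy sketch, not PR code): the coefficients are learnt on (X - X_offset) / X_scale, so dividing coef_ by X_scale and setting intercept_ = y_offset - X_offset @ coef_ recovers a model that predicts correctly on the original, unscaled X.

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(20, 3)
y = X @ np.array([1., -2., 3.]) + 0.5

X_offset = X.mean(axis=0)
X_scale = np.linalg.norm(X - X_offset, axis=0)
# coefficients learnt on the centered, normalized problem
w = np.linalg.pinv((X - X_offset) / X_scale) @ (y - y.mean())

coef = w / X_scale                       # what _set_intercept does
intercept = y.mean() - X_offset @ coef
assert np.allclose(X @ coef + intercept, y)
```
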
@@ -360,6 +427,13 @@ class LinearRegression(LinearModel, RegressorMixin):

normalize : boolean, optional, default False
If True, the regressors X will be normalized before regression.
This parameter is ignored when `fit_intercept` is set to False.
Note that normalization makes the learnt hyperparameters more robust
and almost independent of the number of samples; the same property does
not hold for standardized data. If you wish to standardize instead, use
`preprocessing.StandardScaler` before calling `fit` on an estimator with
`normalize=False` (see the sketch below).
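
A sketch of the migration path the docstring recommends (standard public API; the particular pipeline shown is my assumption about typical usage, not code from this PR):

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X = np.array([[1., 100.], [2., 180.], [3., 150.]])
y = np.array([1., 2., 3.])

# standardize explicitly instead of relying on normalize=True
model = make_pipeline(StandardScaler(), LinearRegression())
model.fit(X, y)
```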

copy_X : boolean, optional, default True
If True, X will be copied; else, it may be overwritten.
@@ -435,13 +509,12 @@ def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
y_numeric=True, multi_output=True)

if ((sample_weight is not None) and np.atleast_1d(
sample_weight).ndim > 1):
sample_weight = column_or_1d(sample_weight, warn=True)
if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1:
raise ValueError("Sample weights must be 1D array or scalar")

X, y, X_mean, y_mean, X_std = self._center_data(
X, y, self.fit_intercept, self.normalize, self.copy_X,
sample_weight=sample_weight)
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
copy=self.copy_X, sample_weight=sample_weight)

if sample_weight is not None:
# Sample weight can be implemented via a simple rescaling.
@@ -466,7 +539,7 @@ def fit(self, X, y, sample_weight=None):

if y.ndim == 1:
self.coef_ = np.ravel(self.coef_)
self._set_intercept(X_mean, y_mean, X_std)
self._set_intercept(X_offset, y_offset, X_scale)
return self


@@ -476,15 +549,16 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):

if sparse.isspmatrix(X):
precompute = False
X, y, X_mean, y_mean, X_std = sparse_center_data(
X, y, fit_intercept, normalize)
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, fit_intercept=fit_intercept, normalize=normalize,
return_mean=True)
else:
# copy was done in fit if necessary
X, y, X_mean, y_mean, X_std = center_data(
X, y, fit_intercept, normalize, copy=copy)
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy)
if hasattr(precompute, '__array__') and (
fit_intercept and not np.allclose(X_mean, np.zeros(n_features))
or normalize and not np.allclose(X_std, np.ones(n_features))):
fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
normalize and not np.allclose(X_scale, np.ones(n_features))):
warnings.warn("Gram matrix was provided but X was centered"
" to fit intercept, "
"or X was normalized : recomputing Gram matrix.",
@@ -521,4 +595,4 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
order='F')
np.dot(y.T, X, out=Xy.T)

return X, y, X_mean, y_mean, X_std, precompute, Xy
return X, y, X_offset, y_offset, X_scale, precompute, Xy