MAINT depr of center_data, normalize in linear_model · scikit-learn/scikit-learn@acdd1aa · GitHub
[go: up one dir, main page]

Skip to content

Commit acdd1aa

Browse files
author
giorgiop
committed
MAINT depr of center_data, normalize in linear_model
1 parent 2270af2 commit acdd1aa

File tree

12 files changed

+605
-270
lines changed

12 files changed

+605
-270
lines changed

doc/datasets/rcv1.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596
3535
array([2286, 2287, 2288], dtype=int32)
3636

3737
``target_names``:
38-
The target values are the topics of each sample. Each sample belongs to at least one topic, and to up to 17 topics.
38+
The target values are the topics of each sample. Each sample belongs to at least one topic, and to up to 17 topics.
3939
There are 103 topics, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::
4040

4141
>>> rcv1.target_names[:3].tolist() # doctest: +SKIP

sklearn/linear_model/base.py

Lines changed: 118 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Mathieu Blondel <mathieu@mblondel.org>
1111
# Lars Buitinck <L.J.Buitinck@uva.nl>
1212
# Maryan Morel <maryan.morel@polytechnique.edu>
13-
#
13+
# Giorgio Patrini <giorgio.patrini@anu.edu.au>
1414
# License: BSD 3 clause
1515

1616
from __future__ import division
@@ -26,20 +26,16 @@
2626
from ..externals import six
2727
from ..externals.joblib import Parallel, delayed
2828
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
29-
from ..utils import as_float_array, check_array, check_X_y, deprecated
30-
from ..utils import check_random_state, column_or_1d
29+
from ..utils import check_array, check_X_y, deprecated, as_float_array
30+
from ..utils.validation import FLOAT_DTYPES
31+
from ..utils import check_random_state
3132
from ..utils.extmath import safe_sparse_dot
3233
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
3334
from ..utils.fixes import sparse_lsqr
3435
from ..utils.seq_dataset import ArrayDataset, CSRDataset
3536
from ..utils.validation import check_is_fitted
3637
from ..exceptions import NotFittedError
37-
38-
39-
#
40-
# TODO: intercept for all models
41-
# We should define a common function to center data instead of
42-
# repeating the same code inside each fit method.
38+
from ..preprocessing.data import normalize as f_normalize
4339

4440
# TODO: bayesian_ridge_regression and bayesian_regression_ard
4541
# should be squashed into its respective objects.
@@ -71,6 +67,8 @@ def make_dataset(X, y, sample_weight, random_state=None):
7167
return dataset, intercept_decay
7268

7369

70+
@deprecated("sparse_center_data will be removed in "
71+
"0.20. Use utilities in preprocessing.data instead")
7472
def sparse_center_data(X, y, fit_intercept, normalize=False):
7573
"""
7674
Compute information needed to center data to have mean zero along
@@ -87,33 +85,33 @@ def sparse_center_data(X, y, fit_intercept, normalize=False):
8785
else:
8886
X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)
8987

90-
X_mean, X_var = mean_variance_axis(X, axis=0)
88+
X_offset, X_var = mean_variance_axis(X, axis=0)
9189
if normalize:
9290
# transform variance to std in-place
93-
# XXX: currently scaled to variance=n_samples to match center_data
9491
X_var *= X.shape[0]
9592
X_std = np.sqrt(X_var, X_var)
9693
del X_var
9794
X_std[X_std == 0] = 1
9895
inplace_column_scale(X, 1. / X_std)
9996
else:
10097
X_std = np.ones(X.shape[1])
101-
y_mean = y.mean(axis=0)
102-
y = y - y_mean
98+
y_offset = y.mean(axis=0)
99+
y = y - y_offset
103100
else:
104-
X_mean = np.zeros(X.shape[1])
101+
X_offset = np.zeros(X.shape[1])
105102
X_std = np.ones(X.shape[1])
106-
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
103+
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
107104

108-
return X, y, X_mean, y_mean, X_std
105+
return X, y, X_offset, y_offset, X_std
109106

110107

108+
@deprecated("center_data will be removed in "
109+
"0.20. Use utilities in preprocessing.data instead")
111110
def center_data(X, y, fit_intercept, normalize=False, copy=True,
112111
sample_weight=None):
113112
"""
114113
Centers data to have mean zero along axis 0. This is here because
115114
nearly all linear models will want their data to be centered.
116-
117115
If sample_weight is not None, then the weighted mean of X and y
118116
is zero, and not the mean itself
119117
"""
@@ -122,26 +120,95 @@ def center_data(X, y, fit_intercept, normalize=False, copy=True,
122120
if isinstance(sample_weight, numbers.Number):
123121
sample_weight = None
124122
if sp.issparse(X):
125-
X_mean = np.zeros(X.shape[1])
123+
X_offset = np.zeros(X.shape[1])
126124
X_std = np.ones(X.shape[1])
127125
else:
128-
X_mean = np.average(X, axis=0, weights=sample_weight)
129-
X -= X_mean
126+
X_offset = np.average(X, axis=0, weights=sample_weight)
127+
X -= X_offset
128+
# XXX: currently scaled to variance=n_samples
130129
if normalize:
131-
# XXX: currently scaled to variance=n_samples
132130
X_std = np.sqrt(np.sum(X ** 2, axis=0))
133131
X_std[X_std == 0] = 1
134132
X /= X_std
135133
else:
136134
X_std = np.ones(X.shape[1])
137-
y_mean = np.average(y, axis=0, weights=sample_weight)
138-
y = y - y_mean
135+
y_offset = np.average(y, axis=0, weights=sample_weight)
136+
y = y - y_offset
139137
else:
140-
X_mean = np.zeros(X.shape[1])
138+
X_offset = np.zeros(X.shape[1])
141139
X_std = np.ones(X.shape[1])
142-
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
143-
return X, y, X_mean, y_mean, X_std
140+
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
141+
return X, y, X_offset, y_offset, X_std
142+
143+
144+
def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
145+
sample_weight=None, return_mean=False):
146+
"""
147+
Centers data to have mean zero along axis 0. If fit_intercept=False or if
148+
X is a sparse matrix, no centering is done, but normalization can still
149+
be applied. The function returns the statistics necessary to reconstruct
150+
the input data, which are X_offset, y_offset, X_scale, such that the output
151+
152+
X = (X - X_offset) / X_scale
153+
154+
X_scale is the L2 norm of X - X_offset. If sample_weight is not None,
155+
then the weighted mean of X and y is zero, and not the mean itself. If
156+
return_mean=True, the mean, possibly weighted, is returned, independently
157+
of whether X was centered (option used for optimization with sparse data in
158+
coordinate_descent).
159+
160+
This is here because nearly all linear models will want their data to be
161+
centered.
162+
"""
163+
164+
if isinstance(sample_weight, numbers.Number):
165+
sample_weight = None
166+
167+
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
168+
dtype=FLOAT_DTYPES)
169+
170+
if fit_intercept:
171+
if sp.issparse(X):
172+
X_offset, X_var = mean_variance_axis(X, axis=0)
173+
if not return_mean:
174+
X_offset = np.zeros(X.shape[1])
175+
176+
if normalize:
177+
178+
# TODO: f_normalize could be used here as well but the function
179+
# inplace_csr_row_normalize_l2 must be changed such that it
180+
# can return also the norms computed internally
181+
182+
# transform variance to norm in-place
183+
X_var *= X.shape[0]
184+
X_scale = np.sqrt(X_var, X_var)
185+
del X_var
186+
X_scale[X_scale == 0] = 1
187+
inplace_column_scale(X, 1. / X_scale)
188+
else:
189+
X_scale = np.ones(X.shape[1])
190+
191+
else:
192+
X_offset = np.average(X, axis=0, weights=sample_weight)
193+
X -= X_offset
194+
if normalize:
195+
X, X_scale = f_normalize(X, axis=0, copy=False,
196+
return_norm=True)
197+
else:
198+
X_scale = np.ones(X.shape[1])
199+
y_offset = np.average(y, axis=0, weights=sample_weight)
200+
y = y - y_offset
201+
else:
202+
X_offset = np.zeros(X.shape[1])
203+
X_scale = np.ones(X.shape[1])
204+
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
205+
206+
return X, y, X_offset, y_offset, X_scale
207+
144208

209+
# TODO: _rescale_data should be factored into _preprocess_data.
210+
# Currently, the fact that sag implements its own way to deal with
211+
# sample_weight makes the refactoring tricky.
145212

146213
def _rescale_data(X, y, sample_weight):
147214
"""Rescale data so as to support sample_weight"""
@@ -200,14 +267,14 @@ def predict(self, X):
200267
"""
201268
return self._decision_function(X)
202269

203-
_center_data = staticmethod(center_data)
270+
_preprocess_data = staticmethod(_preprocess_data)
204271

205-
def _set_intercept(self, X_mean, y_mean, X_std):
272+
def _set_intercept(self, X_offset, y_offset, X_scale):
206273
"""Set the intercept_
207274
"""
208275
if self.fit_intercept:
209-
self.coef_ = self.coef_ / X_std
210-
self.intercept_ = y_mean - np.dot(X_mean, self.coef_.T)
276+
self.coef_ = self.coef_ / X_scale
277+
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
211278
else:
212279
self.intercept_ = 0.
213280

@@ -360,6 +427,13 @@ class LinearRegression(LinearModel, RegressorMixin):
360427
361428
normalize : boolean, optional, default False
362429
If True, the regressors X will be normalized before regression.
430+
This parameter is ignored when `fit_intercept` is set to False.
431+
When the regressors are normalized, note that this makes the
432+
hyperparameters learnt more robust and almost independent of the number
433+
of samples. The same property is not valid for standardized data.
434+
However, if you wish to standardize, please use
435+
`preprocessing.StandardScaler` before calling `fit` on an estimator
436+
with `normalize=False`.
363437
364438
copy_X : boolean, optional, default True
365439
If True, X will be copied; else, it may be overwritten.
@@ -435,13 +509,12 @@ def fit(self, X, y, sample_weight=None):
435509
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
436510
y_numeric=True, multi_output=True)
437511

438-
if ((sample_weight is not None) and np.atleast_1d(
439-
sample_weight).ndim > 1):
440-
sample_weight = column_or_1d(sample_weight, warn=True)
512+
if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1:
513+
raise ValueError("Sample weights must be 1D array or scalar")
441514

442-
X, y, X_mean, y_mean, X_std = self._center_data(
443-
X, y, self.fit_intercept, self.normalize, self.copy_X,
444-
sample_weight=sample_weight)
515+
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
516+
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
517+
copy=self.copy_X, sample_weight=sample_weight)
445518

446519
if sample_weight is not None:
447520
# Sample weight can be implemented via a simple rescaling.
@@ -466,7 +539,7 @@ def fit(self, X, y, sample_weight=None):
466539

467540
if y.ndim == 1:
468541
self.coef_ = np.ravel(self.coef_)
469-
self._set_intercept(X_mean, y_mean, X_std)
542+
self._set_intercept(X_offset, y_offset, X_scale)
470543
return self
471544

472545

@@ -476,15 +549,16 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
476549

477550
if sparse.isspmatrix(X):
478551
precompute = False
479-
X, y, X_mean, y_mean, X_std = sparse_center_data(
480-
X, y, fit_intercept, normalize)
552+
X, y, X_offset, y_offset, X_scale = _preprocess_data(
553+
X, y, fit_intercept=fit_intercept, normalize=normalize,
554+
return_mean=True)
481555
else:
482556
# copy was done in fit if necessary
483-
X, y, X_mean, y_mean, X_std = center_data(
484-
X, y, fit_intercept, normalize, copy=copy)
557+
X, y, X_offset, y_offset, X_scale = _preprocess_data(
558+
X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy)
485559
if hasattr(precompute, '__array__') and (
486-
fit_intercept and not np.allclose(X_mean, np.zeros(n_features))
487-
or normalize and not np.allclose(X_std, np.ones(n_features))):
560+
fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
561+
normalize and not np.allclose(X_scale, np.ones(n_features))):
488562
warnings.warn("Gram matrix was provided but X was centered"
489563
" to fit intercept, "
490564
"or X was normalized : recomputing Gram matrix.",
@@ -521,4 +595,4 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
521595
order='F')
522596
np.dot(y.T, X, out=Xy.T)
523597

524-
return X, y, X_mean, y_mean, X_std, precompute, Xy
598+
return X, y, X_offset, y_offset, X_scale, precompute, Xy

0 commit comments

Comments
 (0)
0