refactoring and deprecation of center_data · scikit-learn/scikit-learn@c29c00b · GitHub
[go: up one dir, main page]

Skip to content

Commit c29c00b

Browse files
author
giorgiop
committed
refactoring and deprecation of center_data
1 parent f459c99 commit c29c00b

File tree

11 files changed

+438
-181
lines changed

11 files changed

+438
-181
lines changed

doc/datasets/rcv1.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Each sample can be identified by its ID, ranging (with gaps) from 2286 to 810596
3535
array([2286, 2287, 2288], dtype=int32)
3636

3737
``target_names``:
38-
The target values are the topics of each sample. Each sample belongs to at least one topic, and to up to 17 topics.
38+
The target values are the topics of each sample. Each sample belongs to at least one topic, and to up to 17 topics.
3939
There are 103 topics, each represented by a string. Their corpus frequencies span five orders of magnitude, from 5 occurrences for 'GMIL', to 381327 for 'CCAT'::
4040

4141
>>> rcv1.target_names[:3].tolist() # doctest: +SKIP

sklearn/linear_model/base.py

+109-42
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# Mathieu Blondel <mathieu@mblondel.org>
1111
# Lars Buitinck <L.J.Buitinck@uva.nl>
1212
# Maryan Morel <maryan.morel@polytechnique.edu>
13-
#
13+
# Giorgio Patrini <giorgio.patrini@anu.edu.au>
1414
# License: BSD 3 clause
1515

1616
from __future__ import division
@@ -26,19 +26,16 @@
2626
from ..externals import six
2727
from ..externals.joblib import Parallel, delayed
2828
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
29-
from ..utils import as_float_array, check_array, check_X_y, deprecated
29+
from ..utils import check_array, check_X_y, deprecated, as_float_array
30+
from ..utils.validation import FLOAT_DTYPES
3031
from ..utils import check_random_state
3132
from ..utils.extmath import safe_sparse_dot
3233
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
3334
from ..utils.fixes import sparse_lsqr
3435
from ..utils.seq_dataset import ArrayDataset, CSRDataset
3536
from ..utils.validation import check_is_fitted
3637
from ..exceptions import NotFittedError
37-
38-
#
39-
# TODO: intercept for all models
40-
# We should define a common function to center data instead of
41-
# repeating the same code inside each fit method.
38+
from ..preprocessing.data import normalize as f_normalize
4239

4340
# TODO: bayesian_ridge_regression and bayesian_regression_ard
4441
# should be squashed into its respective objects.
@@ -70,49 +67,51 @@ def make_dataset(X, y, sample_weight, random_state=None):
7067
return dataset, intercept_decay
7168

7269

70+
@deprecated("sparse_center_data will be removed in "
71+
"0.20. Use utilities in preprocessing.data instead")
7372
def sparse_center_data(X, y, fit_intercept, normalize=False):
7473
"""
7574
Compute information needed to center data to have mean zero along
7675
axis 0. Be aware that X will not be centered since it would break
7776
the sparsity, but will be normalized if asked so.
7877
"""
79-
# We might require not to change the csr matrix sometimes
80-
# Store a copy if normalize is True.
81-
# Change dtype to float64 since mean_variance_axis accepts
82-
# it that way.
8378
if fit_intercept:
79+
# we might require not to change the csr matrix sometimes
80+
# store a copy if normalize is True.
81+
# Change dtype to float64 since mean_variance_axis accepts
82+
# it that way.
8483
if sp.isspmatrix(X) and X.getformat() == 'csr':
8584
X = sp.csr_matrix(X, copy=normalize, dtype=np.float64)
8685
else:
8786
X = sp.csc_matrix(X, copy=normalize, dtype=np.float64)
8887

8988
X_mean, X_var = mean_variance_axis(X, axis=0)
9089
if normalize:
91-
# transform variance to norm in-place
92-
# XXX: currently scaled to variance=n_samples to match center_data
90+
# transform variance to std in-place
9391
X_var *= X.shape[0]
94-
X_norm = np.sqrt(X_var, X_var)
92+
X_std = np.sqrt(X_var, X_var)
9593
del X_var
96-
X_norm[X_norm == 0] = 1
97-
inplace_column_scale(X, 1. / X_norm)
94+
X_std[X_std == 0] = 1
95+
inplace_column_scale(X, 1. / X_std)
9896
else:
99-
X_norm = np.ones(X.shape[1])
97+
X_std = np.ones(X.shape[1])
10098
y_mean = y.mean(axis=0)
10199
y = y - y_mean
102100
else:
103101
X_mean = np.zeros(X.shape[1])
104-
X_norm = np.ones(X.shape[1])
102+
X_std = np.ones(X.shape[1])
105103
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
106104

107-
return X, y, X_mean, y_mean, X_norm
105+
return X, y, X_mean, y_mean, X_std
108106

109107

108+
@deprecated("center_data will be removed in "
109+
"0.20. Use utilities in preprocessing.data instead")
110110
def center_data(X, y, fit_intercept, normalize=False, copy=True,
111111
sample_weight=None):
112112
"""
113113
Centers data to have mean zero along axis 0. This is here because
114114
nearly all linear models will want their data to be centered.
115-
116115
If sample_weight is not None, then the weighted mean of X and y
117116
is zero, and not the mean itself
118117
"""
@@ -122,33 +121,97 @@ def center_data(X, y, fit_intercept, normalize=False, copy=True,
122121
sample_weight = None
123122
if sp.issparse(X):
124123
X_mean = np.zeros(X.shape[1])
125-
X_norm = np.ones(X.shape[1])
124+
X_std = np.ones(X.shape[1])
126125
else:
127126
X_mean = np.average(X, axis=0, weights=sample_weight)
128127
X -= X_mean
129128
if normalize:
130-
# XXX: currently scaled to variance=n_samples
131-
X_norm = np.sqrt(np.sum(X ** 2, axis=0))
132-
X_norm[X_norm == 0] = 1
133-
X /= X_norm
129+
X_std = np.sqrt(np.sum(X ** 2, axis=0))
130+
X_std[X_std == 0] = 1
131+
X /= X_std
134132
else:
135-
X_norm = np.ones(X.shape[1])
133+
X_std = np.ones(X.shape[1])
136134
y_mean = np.average(y, axis=0, weights=sample_weight)
137135
y = y - y_mean
138136
else:
139137
X_mean = np.zeros(X.shape[1])
140-
X_norm = np.ones(X.shape[1])
138+
X_std = np.ones(X.shape[1])
141139
y_mean = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
140+
return X, y, X_mean, y_mean, X_std
141+
142+
143+
def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
144+
sample_weight=None, return_mean=False):
145+
"""
146+
Centers data to have mean zero along axis 0. If fit_intercept=False or if
147+
X is a sparse matrix, no centering is done, but normalization can still
148+
be applied. The function returns the statistics necessary to reconstruct
149+
the input data, which are X_offset, y_offset, X_scale, such that the output
150+
151+
X = (X - X_offset) / X_scale
152+
153+
If sample_weight is not None, then the weighted mean of X and y
154+
is zero, and not the mean itself. If return_mean=True, the mean, possibly
155+
weighted, is returned, independently of whether X was centered (option used
156+
for optimization with sparse data in coordinate_descent).
157+
158+
This is here because nearly all linear models will want their data to be
159+
centered.
160+
"""
142161

143-
return X, y, X_mean, y_mean, X_norm
162+
if isinstance(sample_weight, numbers.Number):
163+
sample_weight = None
144164

165+
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
166+
dtype=FLOAT_DTYPES)
167+
168+
if fit_intercept:
169+
if sp.issparse(X):
170+
X_offset, X_var = mean_variance_axis(X, axis=0)
171+
if not return_mean:
172+
X_offset = np.zeros(X.shape[1])
173+
174+
if normalize:
175+
176+
# TODO: f_normalize could be used here as well but the function
177+
# inplace_csr_row_normalize_l2 must be changed such that it
178+
# can return also the norms computed internally
179+
180+
# transform variance to norm in-place
181+
X_var *= X.shape[0]
182+
X_scale = np.sqrt(X_var, X_var)
183+
del X_var
184+
X_scale[X_scale == 0] = 1
185+
inplace_column_scale(X, 1. / X_scale)
186+
else:
187+
X_scale = np.ones(X.shape[1])
188+
189+
else:
190+
X_offset = np.average(X, axis=0, weights=sample_weight)
191+
X -= X_offset
192+
if normalize:
193+
X, X_scale = f_normalize(X, axis=0, copy=False,
194+
return_norm=True)
195+
else:
196+
X_scale = np.ones(X.shape[1])
197+
y_offset = np.average(y, axis=0, weights=sample_weight)
198+
y = y - y_offset
199+
else:
200+
X_offset = np.zeros(X.shape[1])
201+
X_scale = np.ones(X.shape[1])
202+
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
203+
204+
return X, y, X_offset, y_offset, X_scale
205+
206+
207+
# TODO: _rescale_data should be factored into _preprocess_data.
208+
# Currently, the fact that sag implements its own way to deal with
209+
# sample_weight makes the refactoring tricky.
145210

146211
def _rescale_data(X, y, sample_weight):
147212
"""Rescale data so as to support sample_weight"""
213+
sample_weight = sample_weight * np.ones(y.shape[0])
148214
sample_weight = np.sqrt(sample_weight)
149-
if not isinstance(sample_weight, np.ndarray): # scalar case
150-
sample_weight = sample_weight * np.ones(y.shape[0])
151-
152215
sw_matrix = np.diag(sample_weight)
153216
if sp.issparse(X) or sp.issparse(y):
154217
sw_matrix = sparse.dia_matrix(sw_matrix)
@@ -202,13 +265,13 @@ def predict(self, X):
202265
"""
203266
return self._decision_function(X)
204267

205-
_center_data = staticmethod(center_data)
268+
_preprocess_data = staticmethod(_preprocess_data)
206269

207270
def _set_intercept(self, X_mean, y_mean, X_norm):
208271
"""Set the intercept_
209272
"""
210-
self.coef_ = self.coef_ / X_norm
211273
if self.fit_intercept:
274+
self.coef_ = self.coef_ / X_norm
212275
self.intercept_ = y_mean - np.dot(X_mean, self.coef_.T)
213276
else:
214277
self.intercept_ = 0.
@@ -362,9 +425,13 @@ class LinearRegression(LinearModel, RegressorMixin):
362425
363426
normalize : boolean, optional, default False
364427
If True, the regressors X will be normalized before regression.
365-
Normalization makes the `coef_` independent from the number of training
366-
samples. If you wish to standardize instead, please use
367-
`preprocessing.StandardScaler` before calling `fit`.
428+
When the regressors are normalized, the fitted `coef_` are the same
429+
independently of the number of training samples; hence, hyperparameters
430+
learnt by cross-validation will be compatible among different training
431+
and validation sets. The same property is not valid for standardized
432+
data. However, if you wish to standardize, please use
433+
`preprocessing.StandardScaler` before calling `fit` on an estimator
434+
with `normalize=False`.
368435
369436
copy_X : boolean, optional, default True
370437
If True, X will be copied; else, it may be overwritten.
@@ -440,11 +507,10 @@ def fit(self, X, y, sample_weight=None):
440507
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
441508
y_numeric=True, multi_output=True)
442509

443-
if ((sample_weight is not None) and
444-
np.atleast_1d(sample_weight).ndim > 1):
510+
if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1:
445511
raise ValueError("Sample weights must be 1D array or scalar")
446512

447-
X, y, X_mean, y_mean, X_norm = self._center_data(
513+
X, y, X_mean, y_mean, X_norm = self._preprocess_data(
448514
X, y, fit_intercept=self.fit_intercept, normalize=self.normalize,
449515
copy=self.copy_X, sample_weight=sample_weight)
450516

@@ -481,11 +547,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
481547

482548
if sparse.isspmatrix(X):
483549
precompute = False
484-
X, y, X_mean, y_mean, X_norm = sparse_center_data(
485-
X, y, fit_intercept=fit_intercept, normalize=normalize)
550+
X, y, X_mean, y_mean, X_norm = _preprocess_data(
551+
X, y, fit_intercept=fit_intercept, normalize=normalize,
552+
return_mean=True)
486553
else:
487554
# copy was done in fit if necessary
488-
X, y, X_mean, y_mean, X_norm = center_data(
555+
X, y, X_mean, y_mean, X_norm = _preprocess_data(
489556
X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy)
490557
if hasattr(precompute, '__array__') and (
491558
fit_intercept and not np.allclose(X_mean, np.zeros(n_features)) or

sklearn/linear_model/bayes.py

+16-8
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,13 @@ class BayesianRidge(LinearModel, RegressorMixin):
6565
6666
normalize : boolean, optional, default False
6767
If True, the regressors X will be normalized before regression.
68-
Normalization makes the `coef_` independent from the number of training
69-
samples. If you wish to standardize instead, please use
70-
`preprocessing.StandardScaler` before calling `fit`.
68+
When the regressors are normalized, the fitted `coef_` are the same
69+
independently of the number of training samples; hence, hyperparameters
70+
learnt by cross-validation will be compatible among different training
71+
and validation sets. The same property is not valid for standardized
72+
data. However, if you wish to standardize, please use
73+
`preprocessing.StandardScaler` before calling `fit` on an estimator
74+
with `normalize=False`.
7175
7276
copy_X : boolean, optional, default True
7377
If True, X will be copied; else, it may be overwritten.
@@ -138,7 +142,7 @@ def fit(self, X, y):
138142
self : returns an instance of self.
139143
"""
140144
X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True)
141-
X, y, X_mean, y_mean, X_std = self._center_data(
145+
X, y, X_mean, y_mean, X_std = self._preprocess_data(
142146
X, y, self.fit_intercept, self.normalize, self.copy_X)
143147
n_samples, n_features = X.shape
144148

@@ -272,9 +276,13 @@ class ARDRegression(LinearModel, RegressorMixin):
272276
273277
normalize : boolean, optional, default False
274278
If True, the regressors X will be normalized before regression.
275-
Normalization makes the `coef_` independent from the number of training
276-
samples. If you wish to standardize instead, please use
277-
`preprocessing.StandardScaler` before calling `fit`.
279+
When the regressors are normalized, the fitted `coef_` are the same
280+
independently of the number of training samples; hence, hyperparameters
281+
learnt by cross-validation will be compatible among different training
282+
and validation sets. The same property is not valid for standardized
283+
data. However, if you wish to standardize, please use
284+
`preprocessing.StandardScaler` before calling `fit` on an estimator
285+
with `normalize=False`.
278286
279287
copy_X : boolean, optional, default True.
280288
If True, X will be copied; else, it may be overwritten.
@@ -357,7 +365,7 @@ def fit(self, X, y):
357365
n_samples, n_features = X.shape
358366
coef_ = np.zeros(n_features)
359367

360-
X, y, X_mean, y_mean, X_std = self._center_data(
368+
X, y, X_mean, y_mean, X_std = self._preprocess_data(
361369
X, y, self.fit_intercept, self.normalize, self.copy_X)
362370

363371
# Launch the convergence loop

0 commit comments

Comments
 (0)
0