[MRG+1] _preprocess_data consistent with fused types (#9093) · dmohns/scikit-learn@71ca4b1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 71ca4b1

Browse files
Henley13 authored and dmohns committed
[MRG+1] _preprocess_data consistent with fused types (scikit-learn#9093)
* add test for _preprocess_data and make it consistent * fix pep8 * add doc, cast systematically y in X.dtype and update test_coordinate_descent.py * test if input values don't change with copy=True * test if input values don't change with copy=True scikit-learn#2 * fix doc * fix doc scikit-learn#2 * fix doc scikit-learn#3
1 parent c7b1875 commit 71ca4b1

File tree

9 files changed

+89
-20
lines changed

9 files changed

+89
-20
lines changed

sklearn/linear_model/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,20 +158,21 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
158158
coordinate_descend).
159159
160160
This is here because nearly all linear models will want their data to be
161-
centered.
161+
centered. This function also systematically makes y consistent with X.dtype
162162
"""
163163

164164
if isinstance(sample_weight, numbers.Number):
165165
sample_weight = None
166166

167167
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
168168
dtype=FLOAT_DTYPES)
169+
y = np.asarray(y, dtype=X.dtype)
169170

170171
if fit_intercept:
171172
if sp.issparse(X):
172173
X_offset, X_var = mean_variance_axis(X, axis=0)
173174
if not return_mean:
174-
X_offset[:] = 0
175+
X_offset[:] = X.dtype.type(0)
175176

176177
if normalize:
177178

@@ -201,7 +202,10 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
201202
else:
202203
X_offset = np.zeros(X.shape[1], dtype=X.dtype)
203204
X_scale = np.ones(X.shape[1], dtype=X.dtype)
204-
y_offset = 0. if y.ndim == 1 else np.zeros(y.shape[1], dtype=X.dtype)
205+
if y.ndim == 1:
206+
y_offset = X.dtype.type(0)
207+
else:
208+
y_offset = np.zeros(y.shape[1], dtype=X.dtype)
205209

206210
return X, y, X_offset, y_offset, X_scale
207211

@@ -460,7 +464,7 @@ def fit(self, X, y, sample_weight=None):
460464
Training data
461465
462466
y : numpy array of shape [n_samples, n_targets]
463-
Target values
467+
Target values. Will be cast to X's dtype if necessary
464468
465469
sample_weight : numpy array of shape [n_samples]
466470
Individual weights for each sample

sklearn/linear_model/bayes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def fit(self, X, y):
148148
X : numpy array of shape [n_samples,n_features]
149149
Training data
150150
y : numpy array of shape [n_samples]
151-
Target values
151+
Target values. Will be cast to X's dtype if necessary
152152
153153
Returns
154154
-------
@@ -420,7 +420,7 @@ def fit(self, X, y):
420420
Training vector, where n_samples in the number of samples and
421421
n_features is the number of features.
422422
y : array, shape = [n_samples]
423-
Target values (integers)
423+
Target values (integers). Will be cast to X's dtype if necessary
424424
425425
Returns
426426
-------

sklearn/linear_model/coordinate_descent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def fit(self, X, y, check_input=True):
653653
Data
654654
655655
y : ndarray, shape (n_samples,) or (n_samples, n_targets)
656-
Target
656+
Target. Will be cast to X's dtype if necessary
657657
658658
check_input : boolean, (default=True)
659659
Allow to bypass several input checking.
@@ -1680,7 +1680,7 @@ def fit(self, X, y):
16801680
X : ndarray, shape (n_samples, n_features)
16811681
Data
16821682
y : ndarray, shape (n_samples, n_tasks)
1683-
Target
1683+
Target. Will be cast to X's dtype if necessary
16841684
16851685
Notes
16861686
-----

sklearn/linear_model/least_angle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1455,7 +1455,7 @@ def fit(self, X, y, copy_X=True):
14551455
training data.
14561456
14571457
y : array-like, shape (n_samples,)
1458-
target values.
1458+
target values. Will be cast to X's dtype if necessary
14591459
14601460
copy_X : boolean, optional, default True
14611461
If ``True``, X will be copied; else, it may be overwritten.

sklearn/linear_model/omp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def fit(self, X, y):
617617
Training data.
618618
619619
y : array-like, shape (n_samples,) or (n_samples, n_targets)
620-
Target values.
620+
Target values. Will be cast to X's dtype if necessary
621621
622622
623623
Returns
@@ -835,7 +835,7 @@ def fit(self, X, y):
835835
Training data.
836836
837837
y : array-like, shape [n_samples]
838-
Target values.
838+
Target values. Will be cast to X's dtype if necessary
839839
840840
Returns
841841
-------

sklearn/linear_model/randomized_l1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def fit(self, X, y):
8282
Training data.
8383
8484
y : array-like, shape = [n_samples]
85-
Target values.
85+
Target values. Will be cast to X's dtype if necessary
8686
8787
Returns
8888
-------

sklearn/linear_model/ridge.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,7 @@ def fit(self, X, y, sample_weight=None):
975975
Training data
976976
977977
y : array-like, shape = [n_samples] or [n_samples, n_targets]
978-
Target values
978+
Target values. Will be cast to X's dtype if necessary
979979
980980
sample_weight : float or array-like of shape [n_samples]
981981
Sample weight
@@ -1094,7 +1094,7 @@ def fit(self, X, y, sample_weight=None):
10941094
Training data
10951095
10961096
y : array-like, shape = [n_samples] or [n_samples, n_targets]
1097-
Target values
1097+
Target values. Will be cast to X's dtype if necessary
10981098
10991099
sample_weight : float or array-like of shape [n_samples]
11001100
Sample weight
@@ -1336,7 +1336,7 @@ def fit(self, X, y, sample_weight=None):
13361336
and n_features is the number of features.
13371337
13381338
y : array-like, shape (n_samples,)
1339-
Target values.
1339+
Target values. Will be cast to X's dtype if necessary
13401340
13411341
sample_weight : float or numpy array of shape (n_samples,)
13421342
Sample weight.

sklearn/linear_model/tests/test_base.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,72 @@ def test_csr_preprocess_data():
324324
assert_equal(csr_.getformat(), 'csr')
325325

326326

327+
def test_dtype_preprocess_data():
    """Check that every output of _preprocess_data carries X's dtype.

    The same random data is fed in as float32 and float64, in matching and
    mixed (X, y) combinations, for each setting of ``fit_intercept`` and
    ``normalize``. All five returned values must have the dtype of the X
    that was passed in, the input arrays must keep their own dtype, and
    the pure-float32 and pure-float64 results must agree numerically.
    """
    n_samples = 200
    n_features = 2
    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples)

    X_32 = np.asarray(X, dtype=np.float32)
    y_32 = np.asarray(y, dtype=np.float32)
    X_64 = np.asarray(X, dtype=np.float64)
    y_64 = np.asarray(y, dtype=np.float64)

    # (X input, y input, dtype expected on every output). The expected
    # dtype always follows X, including in the two mixed-precision cases.
    cases = [
        (X_32, y_32, np.float32),
        (X_64, y_64, np.float64),
        (X_32, y_64, np.float32),
        (X_64, y_32, np.float64),
    ]

    for fit_intercept in [True, False]:
        for normalize in [True, False]:
            # Run all four calls before asserting anything. Each entry is
            # the (Xt, yt, X_mean, y_mean, X_norm) tuple for one case.
            outputs = [
                _preprocess_data(X_in, y_in, fit_intercept=fit_intercept,
                                 normalize=normalize, return_mean=True)
                for X_in, y_in, _ in cases
            ]

            for (_, _, expected_dtype), result in zip(cases, outputs):
                for returned in result:
                    assert_equal(returned.dtype, expected_dtype)

            # The caller's arrays must still have their original dtype.
            assert_equal(X_32.dtype, np.float32)
            assert_equal(y_32.dtype, np.float32)
            assert_equal(X_64.dtype, np.float64)
            assert_equal(y_64.dtype, np.float64)

            # Pure float32 and pure float64 runs must agree in value.
            for out_32, out_64 in zip(outputs[0], outputs[1]):
                assert_array_almost_equal(out_32, out_64)
391+
392+
327393
def test_rescale_data():
328394
n_samples = 200
329395
n_features = 2

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -661,12 +661,11 @@ def test_check_input_false():
661661
clf = ElasticNet(selection='cyclic', tol=1e-8)
662662
# Check that no error is raised if data is provided in the right format
663663
clf.fit(X, y, check_input=False)
664+
# With check_input=False, an exhaustive check is not made on y but its
665+
# dtype is still cast in _preprocess_data to X's dtype. So the test should
666+
# pass anyway
664667
X = check_array(X, order='F', dtype='float32')
665-
clf.fit(X, y, check_input=True)
666-
# Check that an error is raised if data is provided in the wrong dtype,
667-
# because of check bypassing
668-
assert_raises(ValueError, clf.fit, X, y, check_input=False)
669-
668+
clf.fit(X, y, check_input=False)
670669
# With no input checking, providing X in C order should result in false
671670
# computation
672671
X = check_array(X, order='C', dtype='float64')

0 commit comments

Comments
 (0)
0