8000 ENH Checks n_features_in_ in preprocessing module (#18577) · scikit-learn/scikit-learn@d933c20 · GitHub
Commit d933c20

thomasjpfan, ogrisel, jnothman, lorentzenchr, NicolasHug authored
ENH Checks n_features_in_ in preprocessing module (#18577)
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: Joel Nothman <joel.nothman@gmail.com>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 90b9b5d commit d933c20
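
In effect, the preprocessing transformers now delegate the feature-count check at transform time to the shared `_validate_data` / `n_features_in_` machinery instead of ad hoc shape checks. A minimal sketch of the resulting behaviour, assuming a scikit-learn build that includes this commit (the error text below is the one the updated tests match):

    import numpy as np
    from sklearn.preprocessing import QuantileTransformer

    # Fitting on 3 columns records n_features_in_ = 3 on the estimator.
    rng = np.random.RandomState(0)
    X_fit = rng.uniform(size=(10, 3))
    qt = QuantileTransformer(n_quantiles=10).fit(X_fit)

    # Passing a different number of columns to transform now raises the
    # generic message the updated tests expect, e.g.
    # "X has 2 features, but QuantileTransformer is expecting 3 features as input."
    try:
        qt.transform(X_fit[:, :2])
    except ValueError as exc:
        print(exc)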

File tree

5 files changed: +35 additions, -59 deletions

sklearn/preprocessing/_data.py

Lines changed: 30 additions & 37 deletions

@@ -432,8 +432,8 @@ def transform(self, X):
         """
         check_is_fitted(self)

-        X = check_array(X, copy=self.copy, dtype=FLOAT_DTYPES,
-                        force_all_finite="allow-nan")
+        X = self._validate_data(X, copy=self.copy, dtype=FLOAT_DTYPES,
+                                force_all_finite="allow-nan", reset=False)

         X *= self.scale_
         X += self.min_
@@ -760,9 +760,10 @@ def partial_fit(self, X, y=None, sample_weight=None):
         self : object
             Fitted scaler.
         """
+        first_call = not hasattr(self, "n_samples_seen_")
         X = self._validate_data(X, accept_sparse=('csr', 'csc'),
                                 estimator=self, dtype=FLOAT_DTYPES,
-                                force_all_finite='allow-nan')
+                                force_all_finite='allow-nan', reset=first_call)

         if sample_weight is not None:
             sample_weight = _check_sample_weight(sample_weight, X,
@@ -1097,9 +1098,10 @@ def transform(self, X):
             Transformed array.
         """
         check_is_fitted(self)
-        X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
-                        estimator=self, dtype=FLOAT_DTYPES,
-                        force_all_finite='allow-nan')
+        X = self._validate_data(X, accept_sparse=('csr', 'csc'),
+                                copy=self.copy, reset=False,
+                                estimator=self, dtype=FLOAT_DTYPES,
+                                force_all_finite='allow-nan')

         if sparse.issparse(X):
             inplace_column_scale(X, 1.0 / self.scale_)
@@ -1398,9 +1400,10 @@ def transform(self, X):
             Transformed array.
         """
         check_is_fitted(self)
-        X = check_array(X, accept_sparse=('csr', 'csc'), copy=self.copy,
-                        estimator=self, dtype=FLOAT_DTYPES,
-                        force_all_finite='allow-nan')
+        X = self._validate_data(X, accept_sparse=('csr', 'csc'),
+                                copy=self.copy, estimator=self,
+                                dtype=FLOAT_DTYPES, reset=False,
+                                force_all_finite='allow-nan')

         if sparse.issparse(X):
             if self.with_scaling:
@@ -1735,8 +1738,8 @@ def transform(self, X):
         """
         check_is_fitted(self)

-        X = check_array(X, order='F', dtype=FLOAT_DTYPES,
-                        accept_sparse=('csr', 'csc'))
+        X = self._validate_data(X, order='F', dtype=FLOAT_DTYPES, reset=False,
+                                accept_sparse=('csr', 'csc'))

         n_samples, n_features = X.shape

@@ -2038,7 +2041,7 @@ def transform(self, X, copy=None):
             Transformed array.
         """
         copy = copy if copy is not None else self.copy
-        X = check_array(X, accept_sparse='csr')
+        X = self._validate_data(X, accept_sparse='csr', reset=False)
         return normalize(X, norm=self.norm, axis=1, copy=copy)

     def _more_tags(self):
@@ -2195,7 +2198,11 @@ def transform(self, X, copy=None):
             Transformed array.
         """
         copy = copy if copy is not None else self.copy
-        return binarize(X, threshold=self.threshold, copy=copy)
+        # TODO: This should be refactored because binarize also calls
+        # check_array
+        X = self._validate_data(X, accept_sparse=['csr', 'csc'], copy=copy,
+                                reset=False)
+        return binarize(X, threshold=self.threshold, copy=False)

     def _more_tags(self):
         return {'stateless': True}
@@ -2291,7 +2298,7 @@ def transform(self, K, copy=True):
         """
         check_is_fitted(self)

-        K = check_array(K, copy=copy, dtype=FLOAT_DTYPES)
+        K = self._validate_data(K, copy=copy, dtype=FLOAT_DTYPES, reset=False)

         K_pred_cols = (np.sum(K, axis=1) /
                        self.K_fit_rows_.shape[0])[:, np.newaxis]
@@ -2689,16 +2696,7 @@ def _transform_col(self, X_col, quantiles, inverse):
     def _check_inputs(self, X, in_fit, accept_sparse_negative=False,
                       copy=False):
         """Check inputs before fit and transform."""
-        # In theory reset should be equal to `in_fit`, but there are tests
-        # checking the input number of feature and they expect a specific
-        # string, which is not the same one raised by check_n_features. So we
-        # don't check n_features_in_ here for now (it's done with adhoc code in
-        # the estimator anyway).
-        # TODO: set reset=in_fit when addressing reset in
-        # predict/transform/etc.
-        reset = True
-
-        X = self._validate_data(X, reset=reset,
+        X = self._validate_data(X, reset=in_fit,
                                 accept_sparse='csc', copy=copy,
                                 dtype=FLOAT_DTYPES,
                                 force_all_finite='allow-nan')
@@ -2718,16 +2716,6 @@ def _check_inputs(self, X, in_fit, accept_sparse_negative=False,

         return X

-    def _check_is_fitted(self, X):
-        """Check the inputs before transforming."""
-        check_is_fitted(self)
-        # check that the dimension of X are adequate with the fitted data
-        if X.shape[1] != self.quantiles_.shape[1]:
-            raise ValueError('X does not have the same number of features as'
-                             ' the previously fitted data. Got {} instead of'
-                             ' {}.'.format(X.shape[1],
-                                           self.quantiles_.shape[1]))
-
     def _transform(self, X, inverse=False):
         """Forward and inverse transform.

@@ -2777,8 +2765,8 @@ def transform(self, X):
         Xt : {ndarray, sparse matrix} of shape (n_samples, n_features)
             The projected data.
         """
+        check_is_fitted(self)
         X = self._check_inputs(X, in_fit=False, copy=self.copy)
-        self._check_is_fitted(X)

         return self._transform(X, inverse=False)

@@ -2798,9 +2786,9 @@ def inverse_transform(self, X):
         Xt : {ndarray, sparse matrix} of (n_samples, n_features)
             The projected data.
         """
+        check_is_fitted(self)
         X = self._check_inputs(X, in_fit=False, accept_sparse_negative=True,
                                copy=self.copy)
-        self._check_is_fitted(X)

         return self._transform(X, inverse=True)

@@ -3262,6 +3250,10 @@ def _check_input(self, X, in_fit, check_positive=False, check_shape=False,
         ----------
         X : array-like of shape (n_samples, n_features)

+        in_fit : bool
+            Whether or not `_check_input` is called from `fit` or other
+            methods, e.g. `predict`, `transform`, etc.
+
         check_positive : bool, default=False
             If True, check that all data is positive and non-zero (only if
             ``self.method=='box-cox'``).
@@ -3273,7 +3265,8 @@ def _check_input(self, X, in_fit, check_positive=False, check_shape=False,
             If True, check that the transformation method is valid.
         """
         X = self._validate_data(X, ensure_2d=True, dtype=FLOAT_DTYPES,
-                                copy=self.copy, force_all_finite='allow-nan')
+                                copy=self.copy, force_all_finite='allow-nan',
+                                reset=in_fit)

         with np.warnings.catch_warnings():
             np.warnings.filterwarnings(
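
The hunks above converge on one convention: `reset=False` (or `reset=in_fit`) in transform-like methods so that `n_features_in_` is checked, and `reset=first_call` in `partial_fit` so that only the first call defines it. Below is a minimal sketch of that convention in a hypothetical third-party transformer, assuming a scikit-learn version contemporary with this commit where `BaseEstimator._validate_data` is available (`ShiftScaler` is illustrative, not part of the change):

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.utils.validation import check_is_fitted


    class ShiftScaler(TransformerMixin, BaseEstimator):
        """Hypothetical transformer following the reset convention shown above."""

        def fit(self, X, y=None):
            # reset=True (the default) stores n_features_in_ on the estimator.
            X = self._validate_data(X, dtype=np.float64)
            self.mean_ = X.mean(axis=0)
            return self

        def partial_fit(self, X, y=None):
            # Only the first call defines n_features_in_; later calls validate
            # against it, mirroring the `first_call = not hasattr(...)` pattern.
            first_call = not hasattr(self, "mean_")
            X = self._validate_data(X, dtype=np.float64, reset=first_call)
            self.mean_ = X.mean(axis=0)  # sketch only; a real scaler would accumulate
            return self

        def transform(self, X):
            check_is_fitted(self)
            # reset=False compares X.shape[1] with the stored n_features_in_
            # and raises a ValueError on mismatch, on top of the usual checks.
            X = self._validate_data(X, dtype=np.float64, reset=False)
            return X - self.mean_

Compared with the `_check_is_fitted` helper deleted from `QuantileTransformer` above, this keeps both the shape bookkeeping and the error message in one shared place.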

sklearn/preprocessing/_discretization.py

Lines changed: 1 addition & 6 deletions

@@ -289,12 +289,7 @@ def transform(self, X):

         # check input and attribute dtypes
         dtype = (np.float64, np.float32) if self.dtype is None else self.dtype
-        Xt = check_array(X, copy=True, dtype=dtype)
-
-        n_features = self.n_bins_.shape[0]
-        if Xt.shape[1] != n_features:
-            raise ValueError("Incorrect number of features. Expecting {}, "
-                             "received {}.".format(n_features, Xt.shape[1]))
+        Xt = self._validate_data(X, copy=True, dtype=dtype, reset=False)

         bin_edges = self.bin_edges_
         for jj in range(Xt.shape[1]):
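
The manual shape check against `n_bins_` deleted here is now covered by `reset=False`: `_validate_data` compares the incoming column count with the stored `n_features_in_`. Roughly, and paraphrasing rather than quoting the scikit-learn source (the real check lives in `BaseEstimator._check_n_features`), the added validation amounts to:

    def _approx_check_n_features(estimator, X):
        # Sketch of the check that reset=False triggers after fitting.
        n_features = X.shape[1]
        if n_features != estimator.n_features_in_:
            raise ValueError(
                f"X has {n_features} features, but {type(estimator).__name__} "
                f"is expecting {estimator.n_features_in_} features as input."
            )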

sklearn/preprocessing/tests/test_data.py

Lines changed: 4 additions & 7 deletions

@@ -1310,12 +1310,8 @@ def test_quantile_transform_check_error():

     X_bad_feat = np.transpose([[0, 25, 50, 0, 0, 0, 75, 0, 0, 100],
                                [0, 0, 2.6, 4.1, 0, 0, 2.3, 0, 9.5, 0.1]])
-    err_msg = ("X does not have the same number of features as the previously"
-               " fitted " "data. Got 2 instead of 3.")
-    with pytest.raises(ValueError, match=err_msg):
-        transformer.transform(X_bad_feat)
-    err_msg = ("X does not have the same number of features "
-               "as the previously fitted data. Got 2 instead of 3.")
+    err_msg = ("X has 2 features, but QuantileTransformer is expecting "
+               "3 features as input.")
     with pytest.raises(ValueError, match=err_msg):
         transformer.inverse_transform(X_bad_feat)

@@ -2434,7 +2430,8 @@ def test_power_transformer_shape_exception(method):

     # Exceptions should be raised for arrays with different num_columns
     # than during fitting
-    wrong_shape_message = 'Input data has a different number of features'
+    wrong_shape_message = (r"X has \d+ features, but PowerTransformer is "
+                           r"expecting \d+ features")

     with pytest.raises(ValueError, match=wrong_shape_message):
         pt.transform(X[:, 0:1])

sklearn/preprocessing/tests/test_discretization.py

Lines changed: 0 additions & 8 deletions

@@ -101,14 +101,6 @@ def test_fit_transform_n_bins_array(strategy, expected):
         assert bin_edges.shape == (n_bins + 1, )


-def test_invalid_n_features():
-    est = KBinsDiscretizer(n_bins=3).fit(X)
-    bad_X = np.arange(25).reshape(5, -1)
-    err_msg = "Incorrect number of features. Expecting 4, received 5"
-    with pytest.raises(ValueError, match=err_msg):
-        est.transform(bad_X)
-
-
 @pytest.mark.parametrize('strategy', ['uniform', 'kmeans', 'quantile'])
 def test_same_min_max(strategy):
     warnings.simplefilter("always")

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion

@@ -358,7 +358,6 @@ def test_search_cv(estimator, check, request):
     'naive_bayes',
     'neighbors',
     'pipeline',
-    'preprocessing',
     'random_projection',
     'semi_supervised',
     'svm',
