8000 CLN Only check for n_features_in_ only when it exists (#18011) · scikit-learn/scikit-learn@e8ffa31 · GitHub
[go: up one dir, main page]

Skip to content

Commit e8ffa31

Browse files
thomasjpfanogrisel
andauthored
CLN Only check for n_features_in_ only when it exists (#18011)
* CLN Checks n_features_in only if it exists * Update sklearn/base.py Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com> * DOC Update docstring * DOC Grammer * Grammar [ci skip] Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
1 parent fdb9233 commit e8ffa31

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

sklearn/base.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,9 @@ def _check_n_features(self, X, reset):
359359
The input samples.
360360
reset : bool
361361
If True, the `n_features_in_` attribute is set to `X.shape[1]`.
362-
Else, the attribute must already exist and the function checks
363-
that it is equal to `X.shape[1]`.
362+
If False and the attribute exists, then check that it is equal to
363+
`X.shape[1]`. If False and the attribute does *not* exist, then
364+
the check is skipped.
364365
.. note::
365366
It is recommended to call reset=True in `fit` and in the first
366367
call to `partial_fit`. All other methods that validate `X`
@@ -370,18 +371,18 @@ def _check_n_features(self, X, reset):
370371

371372
if reset:
372373
self.n_features_in_ = n_features
373-
else:
374-
if not hasattr(self, 'n_features_in_'):
375-
raise RuntimeError(
376-
"The reset parameter is False but there is no "
377-
"n_features_in_ attribute. Is this estimator fitted?"
378-
)
379-
if n_features != self.n_features_in_:
380-
raise ValueError(
381-
'X has {} features, but {} is expecting {} features '
382-
'as input.'.format(n_features, self.__class__.__name__,
383-
self.n_features_in_)
384-
)
374+
return
375+
376+
if not hasattr(self, "n_features_in_"):
377+
# Skip this check if the expected number of expected input features
378+
# was not recorded by calling fit first. This is typically the case
379+
# for stateless transformers.
380+
return
381+
382+
if n_features != self.n_features_in_:
383+
raise ValueError(
384+
f"X has {n_features} features, but {self.__class__.__name__} "
385+
f"is expecting {self.n_features_in_} features as input.")
385386

386387
def _validate_data(self, X, y='no_validation', reset=True,
387388
validate_separately=False, **check_params):

0 commit comments

Comments
 (0)
0