From 6eb99056bf1fcb08a161f2aca1c6660a3d62ed72 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Jul 2020 23:14:56 -0400 Subject: [PATCH 1/5] CLN Checks n_features_in only if it exists --- sklearn/base.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 46398baabfd3a..1900edab2b4f0 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -363,18 +363,18 @@ def _check_n_features(self, X, reset): if reset: self.n_features_in_ = n_features - else: - if not hasattr(self, 'n_features_in_'): - raise RuntimeError( - "The reset parameter is False but there is no " - "n_features_in_ attribute. Is this estimator fitted?" - ) - if n_features != self.n_features_in_: - raise ValueError( - 'X has {} features, but this {} is expecting {} features ' - 'as input.'.format(n_features, self.__class__.__name__, - self.n_features_in_) - ) + return + + fitted_n_features_in = getattr(self, 'n_features_in_', None) + if fitted_n_features_in is None: + return + + if n_features != fitted_n_features_in: + raise ValueError( + 'X has {} features, but this {} is expecting {} features ' + 'as input.'.format(n_features, self.__class__.__name__, + fitted_n_features_in) + ) def _validate_data(self, X, y=None, reset=True, validate_separately=False, **check_params): From d71201f08de903099e90d96d45201eccf3c144db Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Mon, 12 Oct 2020 17:08:20 -0400 Subject: [PATCH 2/5] Update sklearn/base.py Co-authored-by: Olivier Grisel --- sklearn/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index d96fa35d3b0af..32cd874fe7598 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -372,8 +372,10 @@ def _check_n_features(self, X, reset): self.n_features_in_ = n_features return - fitted_n_features_in = getattr(self, 'n_features_in_', None) - if fitted_n_features_in is None: + if not hasattr(self, "n_features_in_"): + # Skip this check if the expected number of input features + # was not recorded by calling fit first. This is typically the case + # for stateless transformers. return if n_features != self.n_features_in_: From edb6029dc96c0f63a2f229a2fcdeda3b5b2d933e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 12 Oct 2020 17:12:00 -0400 Subject: [PATCH 3/5] DOC Update docstring --- sklearn/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 32cd874fe7598..f8ce624801e4c 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -359,8 +359,9 @@ def _check_n_features(self, X, reset): The input samples. reset : bool If True, the `n_features_in_` attribute is set to `X.shape[1]`. - Else, the attribute must already exist and the function checks - that it is equal to `X.shape[1]`. + If False and the attribute exists, then check that it is equal to + `X.shape[1]`. If False and the attribute does *not* exists, then + the check is skipped. .. note:: It is recommended to call reset=True in `fit` and in the first call to `partial_fit`. All other methods that validate `X` From 833f1856535a7add5b186b77f81d493873333cf2 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Mon, 12 Oct 2020 17:37:59 -0400 Subject: [PATCH 4/5] DOC Grammer --- sklearn/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index f8ce624801e4c..fe7d61740cfd4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -359,8 +359,8 @@ def _check_n_features(self, X, reset): The input samples. reset : bool If True, the `n_features_in_` attribute is set to `X.shape[1]`. - If False and the attribute exists, then check that it is equal to - `X.shape[1]`. If False and the attribute does *not* exists, then + If False and the attribute exist, then check that it is equal to + `X.shape[1]`. If False and the attribute does *not* exist, then the check is skipped. .. note:: It is recommended to call reset=True in `fit` and in the first From 2e0908955e010d3f240b659442b645b992c4e9c2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 13 Oct 2020 11:06:41 +0200 Subject: [PATCH 5/5] Grammar [ci skip] --- sklearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index fe7d61740cfd4..bb2e3c67d7bbc 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -359,7 +359,7 @@ def _check_n_features(self, X, reset): The input samples. reset : bool If True, the `n_features_in_` attribute is set to `X.shape[1]`. - If False and the attribute exist, then check that it is equal to + If False and the attribute exists, then check that it is equal to `X.shape[1]`. If False and the attribute does *not* exist, then the check is skipped. .. note::