8000 Common tests · scikit-learn/scikit-learn@41c2596 · GitHub
[go: up one dir, main page]

Skip to content

Commit 41c2596

Browse files
committed
Common tests
1 parent 5c4ef14 commit 41c2596

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

sklearn/preprocessing/imputation.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,9 @@ class MissingIndicator(BaseEstimator, TransformerMixin):
405405
The features with missing values.
406406
Note that this is only stored if features == 'train
407407
408+
n_features_ : int
409+
The number of features during fit time.
410+
408411
Example
409412
-------
410413
>>> from sklearn.preprocessing import MissingIndicator
@@ -436,7 +439,7 @@ def __init__(self, missing_values="NaN", features="train", sparse="auto"):
436439
self.features = features
437440
self.sparse = sparse
438441

439-
def fit(self, X):
442+
def fit(self, X, y=None):
440443
"""Fit the transformer on X.
441444
442445
Parameters
@@ -463,8 +466,8 @@ def fit(self, X):
463466
raise ValueError("sparse can only use be boolean or 'auto'"
464467
" got {0}".format(self.sparse))
465468

466-
X = check_array(X, accept_sparse=('csc', 'csr'), dtype=np.float64,
467-
force_all_finite=False)
469+
X = check_array(X, accept_sparse=('csc', 'csr'), dtype=np.float64)
470+
self.n_features_ = X.shape[1]
468471

469472
if self.features == "train":
470473
_, self.feat_with_missing_ = self._get_missing_features_info(X)
@@ -488,14 +491,16 @@ def transform(self, X):
488491
if self.features == "train":
489492
check_is_fitted(self, "feat_with_missing_")
490493

491-
X = check_array(X, accept_sparse=('csc', 'csr'), dtype=np.float64,
492-
force_all_finite=False)
494+
X = check_array(X, accept_sparse=('csc', 'csr'), dtype=np.float64)
495+
if X.shape[1] != self.n_features_:
496+
raise ValueError("X has a different shape than during fitting.")
497+
493498
imputer_mask, feat_with_missing = self._get_missing_features_info(X)
494499

495500
if self.features == "train":
496501
features = np.setdiff1d(feat_with_missing,
497502
self.feat_with_missing_)
498-
if features.size:
503+
if features.size > 0:
499504
warnings.warn("The features %s have missing values "
500505
"in transform but have no missing values "
501506
"in fit " % features, RuntimeWarning,

0 commit comments

Comments
 (0)
0