FIX Remove warnings when fitting a dataframe (#21578) · scikit-learn/scikit-learn@071f98f · GitHub

Commit 071f98f

thomasjpfan and ogrisel authored
FIX Remove warnings when fitting a dataframe (#21578)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 299364c commit 071f98f

File tree

8 files changed (+67, -11 lines)

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 0 deletions

@@ -9,6 +9,13 @@ Version 1.0.2

 **In Development**

+- |Fix| :class:`cluster.Birch`,
+  :class:`feature_selection.RFECV`, :class:`ensemble.RandomForestRegressor`,
+  :class:`ensemble.RandomForestClassifier`,
+  :class:`ensemble.GradientBoostingRegressor`, and
+  :class:`ensemble.GradientBoostingClassifier` no longer raise a warning when
+  fitted on a pandas DataFrame. :pr:`21578` by `Thomas Fan`_.
+
 Changelog
 ---------

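As a concrete check of what this entry promises, the snippet below fits one of the listed estimators on a DataFrame with the feature-name warning escalated to an error (the estimator choice, data, and column names are arbitrary). Before this fix, fit() with oob_score=True re-validated the internally converted ndarray and warned that X does not have valid feature names even though the user had passed a DataFrame.

import warnings

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=50, n_features=4, random_state=0)
X_df = pd.DataFrame(X, columns=["a", "b", "c", "d"])

# oob_score=True makes fit() predict on out-of-bag samples internally,
# the exact code path that used to emit the spurious warning.
with warnings.catch_warnings():
    warnings.filterwarnings("error", message="X does not have valid feature names")
    RandomForestClassifier(oob_score=True, random_state=0).fit(X_df, y)
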
sklearn/cluster/_birch.py

Lines changed: 5 additions & 1 deletion

@@ -676,6 +676,10 @@ def predict(self, X):
         """
         check_is_fitted(self)
         X = self._validate_data(X, accept_sparse="csr", reset=False)
+        return self._predict(X)
+
+    def _predict(self, X):
+        """Predict data using the ``centroids_`` of subclusters."""
         kwargs = {"Y_norm_squared": self._subcluster_norms}

         with config_context(assume_finite=True):
@@ -745,4 +749,4 @@ def _global_clustering(self, X=None):
         self.subcluster_labels_ = clusterer.fit_predict(self.subcluster_centers_)

         if compute_labels:
-            self.labels_ = self.predict(X)
+            self.labels_ = self._predict(X)

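The fix splits the public method in two: predict keeps the user-facing validation, while the new private _predict does the actual work, and _global_clustering (called from fit) now uses _predict so the array fit already validated is never re-checked against feature_names_in_. The same transform/_transform split reappears in sklearn/feature_selection/_base.py below. A self-contained sketch of this validate-once pattern, using an invented toy estimator rather than scikit-learn code:

import numpy as np

class CentroidModel:
    """Toy estimator illustrating the public/private predict split."""

    def fit(self, X, y):
        X = self._check(X)  # validate exactly once, up front
        y = np.asarray(y)
        self.centroids_ = np.stack(
            [X[y == k].mean(axis=0) for k in np.unique(y)]
        )
        # Reuse the validated array via the private helper: no second
        # validation pass, hence no feature-name warning during fit.
        self.labels_ = self._predict(X)
        return self

    def predict(self, X):
        # Public entry point: validate, then delegate.
        return self._predict(self._check(X))

    def _predict(self, X):
        # Private worker: assumes X has already been validated.
        dists = ((X[:, None, :] - self.centroids_[None]) ** 2).sum(axis=2)
        return dists.argmin(axis=1)

    @staticmethod
    def _check(X):
        # Stand-in for sklearn's `self._validate_data(X, reset=False)`.
        return np.asarray(X, dtype=float)
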
sklearn/ensemble/_forest.py

Lines changed: 4 additions & 2 deletions

@@ -519,8 +519,10 @@ def _compute_oob_predictions(self, X, y):
         oob_pred : ndarray of shape (n_samples, n_classes, n_outputs) or \
             (n_samples, 1, n_outputs)
             The OOB predictions.
-        """
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        """
+        # Prediction requires X to be in CSR format
+        if issparse(X):
+            X = X.tocsr()

         n_samples = y.shape[0]
         n_outputs = self.n_outputs_

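There is no private twin here because _compute_oob_predictions is only reachable from fit, which has already validated X. The full _validate_data call (the source of the warning on the converted ndarray) is therefore replaced by the one property prediction still needs: CSR layout for sparse input. The guard in isolation, with ensure_csr as an invented name:

import numpy as np
from scipy.sparse import csc_matrix, issparse

def ensure_csr(X):
    # A cheap format fix instead of a full re-validation: sparse input
    # is converted to CSR; dense arrays pass through untouched.
    return X.tocsr() if issparse(X) else X

print(ensure_csr(csc_matrix(np.eye(3))).format)  # csr
print(ensure_csr(np.eye(3)).shape)               # (3, 3)
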
sklearn/ensemble/_gb.py

Lines changed: 9 additions & 5 deletions

@@ -643,7 +643,7 @@ def _fit_stages(
         loss_history = np.full(self.n_iter_no_change, np.inf)
         # We create a generator to get the predictions for X_val after
         # the addition of each successive stage
-        y_val_pred_iter = self._staged_raw_predict(X_val)
+        y_val_pred_iter = self._staged_raw_predict(X_val, check_input=False)

         # perform boosting iterations
         i = begin_at_stage
@@ -736,7 +736,7 @@ def _raw_predict(self, X):
         predict_stages(self.estimators_, X, self.learning_rate, raw_predictions)
         return raw_predictions

-    def _staged_raw_predict(self, X):
+    def _staged_raw_predict(self, X, check_input=True):
         """Compute raw predictions of ``X`` for each iteration.

         This method allows monitoring (i.e. determine error on testing set)
@@ -749,6 +749,9 @@ def _staged_raw_predict(self, X):
             ``dtype=np.float32`` and if a sparse matrix is provided
             to a sparse ``csr_matrix``.

+        check_input : bool, default=True
+            If False, the input array X will not be checked.
+
         Returns
         -------
         raw_predictions : generator of ndarray of shape (n_samples, k)
@@ -757,9 +760,10 @@ def _staged_raw_predict(self, X):
             Regression and binary classification are special cases with
             ``k == 1``, otherwise ``k==n_classes``.
         """
-        X = self._validate_data(
-            X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
-        )
+        if check_input:
+            X = self._validate_data(
+                X, dtype=DTYPE, order="C", accept_sparse="csr", reset=False
+            )
         raw_predictions = self._raw_predict_init(X)
         for i in range(self.estimators_.shape[0]):
             predict_stage(self.estimators_, i, X, self.learning_rate, raw_predictions)

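Gradient boosting uses a third variant of the same idea: instead of a private twin method, _staged_raw_predict gains a check_input flag, and the early-stopping loop in _fit_stages passes check_input=False because X_val was already validated when fit received it. A generic sketch of the check_input convention (the function and its behavior are invented for illustration):

import numpy as np

def staged_column_means(X, check_input=True):
    # check_input=True stays the safe default for external callers.
    # Internal callers that have already validated X pass False to skip
    # a redundant (and, in the DataFrame case, warning-raising) re-check.
    if check_input:
        X = np.asarray(X, dtype=np.float32)
        if X.ndim != 2:
            raise ValueError("X must be 2-dimensional")
    for i in range(1, X.shape[1] + 1):
        # One "stage" per prefix of columns, mimicking staged predictions.
        yield X[:, :i].mean(axis=1)

X = np.arange(6.0).reshape(2, 3)
print([s.tolist() for s in staged_column_means(X, check_input=False)])
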
sklearn/feature_selection/_base.py

Lines changed: 4 additions & 0 deletions

@@ -87,6 +87,10 @@ def transform(self, X):
             force_all_finite=not _safe_tags(self, key="allow_nan"),
             reset=False,
         )
+        return self._transform(X)
+
+    def _transform(self, X):
+        """Reduce X to the selected features."""
         mask = self.get_support()
        if not mask.any():
            warn(

sklearn/feature_selection/_rfe.py

Lines changed: 1 addition & 1 deletion

@@ -736,7 +736,7 @@ def fit(self, X, y, groups=None):
         self.n_features_ = rfe.n_features_
         self.ranking_ = rfe.ranking_
         self.estimator_ = clone(self.estimator)
-        self.estimator_.fit(self.transform(X), y)
+        self.estimator_.fit(self._transform(X), y)

         # reverse to stay consistent with before
         scores_rev = scores[:, ::-1]

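Together with the transform/_transform split in _base.py above, this one-line change lets RFECV.fit hand its already-validated X straight to the final estimator instead of round-tripping through the public transform, which would re-run _validate_data on a bare ndarray and trigger the warning. An end-to-end check that the DataFrame path is now clean (data and column names are arbitrary):

import warnings

import pandas as pd
from sklearn.datasets import make_classification
from sklearn.feature_selection import RFECV
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=60, n_features=5, random_state=0)
X_df = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])

with warnings.catch_warnings():
    # Escalate the feature-name warning so this fit fails loudly if the
    # regression ever comes back.
    warnings.filterwarnings("error", message="X does not have valid feature names")
    RFECV(LogisticRegression(max_iter=1000)).fit(X_df, y)
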
sklearn/tests/test_common.py

Lines changed: 19 additions & 0 deletions

@@ -332,6 +332,24 @@ def test_check_n_features_in_after_fitting(estimator):
     check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)


+def _estimators_that_predict_in_fit():
+    for estimator in _tested_estimators():
+        est_params = set(estimator.get_params())
+        if "oob_score" in est_params:
+            yield estimator.set_params(oob_score=True, bootstrap=True)
+        elif "early_stopping" in est_params:
+            est = estimator.set_params(early_stopping=True, n_iter_no_change=1)
+            if est.__class__.__name__ in {"MLPClassifier", "MLPRegressor"}:
+                # TODO: FIX MLP to not check validation set during fit
+                yield pytest.param(
+                    est, marks=pytest.mark.xfail(msg="MLP still validates in fit")
+                )
+            else:
+                yield est
+        elif "n_iter_no_change" in est_params:
+            yield estimator.set_params(n_iter_no_change=1)
+
+
 # NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that
 # delegates validation to a base estimator, the check is testing that the base estimator
 # is checking for column name consistency.
@@ -340,6 +358,7 @@ def test_check_n_features_in_after_fitting(estimator):
         _tested_estimators(),
         [make_pipeline(LogisticRegression(C=1))],
         list(_generate_search_cv_instances()),
+        _estimators_that_predict_in_fit(),
     )
 )

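The new generator widens check_dataframe_column_names_consistency coverage to configurations that predict or score inside fit (OOB scoring, early stopping), exactly the paths this commit fixes, while carrying the known-broken MLP case along as an expected failure. A reduced sketch of this parametrization style (the helper, configs, and test body are placeholders, not the real test):

import itertools

import pytest

def _extra_configs():
    yield {"oob_score": True, "bootstrap": True}
    # Known-broken cases travel with the suite as expected failures.
    yield pytest.param(
        {"early_stopping": True},
        marks=pytest.mark.xfail(reason="still validates inside fit"),
    )

@pytest.mark.parametrize("config", itertools.chain([{}], _extra_configs()))
def test_placeholder(config):
    assert isinstance(config, dict)
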
sklearn/utils/estimator_checks.py

Lines changed: 18 additions & 2 deletions

@@ -3791,7 +3791,16 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
     else:
         y = rng.randint(low=0, high=2, size=n_samples)
     y = _enforce_estimator_tags_y(estimator, y)
-    estimator.fit(X, y)
+
+    # Check that calling `fit` does not raise any warnings about feature names.
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "error",
+            message="X does not have valid feature names",
+            category=UserWarning,
+            module="sklearn",
+        )
+        estimator.fit(X, y)

     if not hasattr(estimator, "feature_names_in_"):
         raise ValueError(
@@ -3853,6 +3862,12 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
             f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n",
         ),
     ]
+    params = {
+        key: value
+        for key, value in estimator.get_params().items()
+        if "early_stopping" in key
+    }
+    early_stopping_enabled = any(value is True for value in params.values())

     for invalid_name, additional_message in invalid_names:
         X_bad = pd.DataFrame(X, columns=invalid_name)
@@ -3876,7 +3891,8 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
             method(X_bad)

         # partial_fit checks on second call
-        if not hasattr(estimator, "partial_fit"):
+        # Do not call partial fit if early_stopping is on
+        if not hasattr(estimator, "partial_fit") or early_stopping_enabled:
             continue

         estimator = clone(estimator_orig)

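The heart of the strengthened check is turning one specific warning into a hard error, scoped by message prefix, category, and emitting module so that unrelated warnings keep flowing. The same pattern in isolation; fit_strict is a hypothetical helper, not a scikit-learn API:

import warnings

def fit_strict(estimator, X, y):
    # Only sklearn's feature-name warning becomes an exception; every
    # other warning raised during fit behaves as usual.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "error",
            message="X does not have valid feature names",
            category=UserWarning,
            module="sklearn",
        )
        return estimator.fit(X, y)

Note that message and module are interpreted as regular expressions matched against the start of the warning text and the warning's module name, respectively.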