10000 TST Fixes test (Will fail now on CI) · thomasjpfan/scikit-learn@3c1d57f · GitHub
[go: up one dir, main page]

Skip to content

Commit 3c1d57f

Browse files
committed
TST Fixes test (Will fail now on CI)
1 parent 4e3e8e8 commit 3c1d57f

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

sklearn/tests/test_common.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -337,10 +337,8 @@ def _estimators_that_predict_in_fit():
337337
est_params = set(estimator.get_params())
338338
if "oob_score" in est_params:
339339
yield estimator.set_params(oob_score=True, bootstrap=True)
340-
elif "n_iter_no_change" in est_params:
341-
yield estimator.set_params(n_iter_no_change=1)
342340
elif "early_stopping" in est_params:
343-
yield estimator.set_params(early_stopping=True)
341+
yield estimator.set_params(early_stopping=True, n_iter_no_change=1)
344342

345343

346344
# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that

sklearn/utils/estimator_checks.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3862,6 +3862,12 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
38623862
f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n",
38633863
),
38643864
]
3865+
params = {
3866+
key: value
3867+
for key, value in estimator.get_params().items()
3868+
if "early_stopping" in key
3869+
}
3870+
early_stopping_enabled = any(value is True for value in params.values())
38653871

38663872
for invalid_name, additional_message in invalid_names:
38673873
X_bad = pd.DataFrame(X, columns=invalid_name)
@@ -3885,7 +3891,8 @@ def check_dataframe_column_names_consistency(name, estimator_orig):
38853891
method(X_bad)
38863892

38873893
# partial_fit checks on second call
3888-
if not hasattr(estimator, "partial_fit"):
3894+
# Do not call partial fit if early_stopping is on
3895+
if not hasattr(estimator, "partial_fit") or early_stopping_enabled:
38893896
continue
38903897

38913898
estimator = clone(estimator_orig)

0 commit comments

Comments
 (0)
0