8000 TST add a few more tests to API checks (#29832) · scikit-learn/scikit-learn@cc5372c · GitHub
[go: up one dir, main page]

Skip to content

Commit cc5372c

Browse files
authored
TST add a few more tests to API checks (#29832)
1 parent c71340f commit cc5372c

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,17 @@
8282

8383

8484
def _yield_api_checks(estimator):
85+
tags = get_tags(estimator)
86+
yield check_estimator_cloneable
8587
yield check_estimator_repr
8688
yield check_no_attributes_set_in_init
8789
yield check_fit_score_takes_y
8890
yield check_estimators_overwrite_params
91+
yield check_dont_overwrite_parameters
92+
yield check_estimators_fit_returns_self
93+
yield check_readonly_memmap_input
94+
if tags.requires_fit:
95+
yield check_estimators_unfitted
8996
yield check_do_not_raise_errors_in_init_or_set_params
9097

9198

@@ -104,8 +111,6 @@ def _yield_checks(estimator):
104111
yield check_sample_weights_not_overwritten
105112
yield partial(check_sample_weights_invariance, kind="ones")
106113
yield partial(check_sample_weights_invariance, kind="zeros")
107-
yield check_estimators_fit_returns_self
108-
yield partial(check_estimators_fit_returns_self, readonly_memmap=True)
109114

110115
# Check that all estimator yield informative messages when
111116
# trained on empty datasets
@@ -173,8 +178,6 @@ def _yield_classifier_checks(classifier):
173178
yield check_supervised_y_no_nan
174179
if tags.target_tags.single_output:
175180
yield check_supervised_y_2d
176-
if tags.requires_fit:
177-
yield check_estimators_unfitted
178181
if "class_weight" in classifier.get_params().keys():
179182
yield check_class_weight_classifiers
180183

@@ -247,8 +250,6 @@ def _yield_regressor_checks(regressor):
247250
if name != "CCA":
248251
# check that the regressor handles int input
249252
yield check_regressors_int
250-
if tags.requires_fit:
251-
yield check_estimators_unfitted
252253
yield check_non_transformer_estimators_n_iter
253254

254255

@@ -311,9 +312,6 @@ def _yield_outliers_checks(estimator):
311312
yield partial(check_outliers_train, readonly_memmap=True)
312313
# test outlier detectors can handle non-array data
313314
yield check_classifier_data_not_an_array
314-
# test if NotFittedError is raised
315-
if get_tags(estimator).requires_fit:
316-
yield check_estimators_unfitted
317315
yield check_non_transformer_estimators_n_iter
318316

319317

@@ -381,7 +379,6 @@ def _yield_all_checks(estimator, legacy: bool):
381379
yield check_get_params_invariance
382380
yield check_set_params
383381
yield check_dict_unchanged
384-
yield check_dont_overwrite_parameters
385382
yield check_fit_idempotent
386383
yield check_fit_check_is_fitted
387384
if not tags.no_validation:
@@ -2724,18 +2721,34 @@ def check_get_feature_names_out_error(name, estimator_orig):
27242721

27252722

27262723
@ignore_warnings(category=FutureWarning)
2727-
def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False):
2724+
def check_estimators_fit_returns_self(name, estimator_orig):
27282725
"""Check if self is returned when calling fit."""
27292726
X, y = make_blobs(random_state=0, n_samples=21)
27302727
X = _enforce_estimator_tags_X(estimator_orig, X)
27312728

27322729
estimator = clone(estimator_orig)
27332730
y = _enforce_estimator_tags_y(estimator, y)
27342731

2735-
if readonly_memmap:
2736-
X, y = create_memmap_backed_data([X, y])
2732+
set_random_state(estimator)
2733+
assert estimator.fit(X, y) is estimator
2734+
2735+
2736+
@ignore_warnings(category=FutureWarning)
2737+
def check_readonly_memmap_input(name, estimator_orig):
2738+
"""Check that the estimator can handle readonly memmap backed data.
2739+
2740+
This is particularly needed to support joblib parallelisation.
2741+
"""
2742+
X, y = make_blobs(random_state=0, n_samples=21)
2743+
X = _enforce_estimator_tags_X(estimator_orig, X)
2744+
2745+
estimator = clone(estimator_orig)
2746+
y = _enforce_estimator_tags_y(estimator, y)
2747+
2748+
X, y = create_memmap_backed_data([X, y])
27372749

27382750
set_random_state(estimator)
2751+
# This should not raise an error and should return self
27392752
assert estimator.fit(X, y) is estimator
27402753

27412754

@@ -2745,6 +2758,15 @@ def check_estimators_unfitted(name, estimator_orig):
27452758
27462759
Unfitted estimators should raise a NotFittedError.
27472760
"""
2761+
err_msg = (
2762+
"Estimator should raise a NotFittedError when calling `{method}` before fit. "
2763+
"Either call `check_is_fitted(self)` at the beginning of `{method}` or "
2764+
"set `tags.requires_fit=False` on estimator tags to disable this check.\n"
2765+
"- `check_is_fitted`: https://scikit-learn.org/dev/modules/generated/sklearn."
2766+
"utils.validation.check_is_fitted.html\n"
2767+
"- Estimator Tags: https://scikit-learn.org/dev/developers/develop."
2768+
"html#estimator-tags"
2769+
)
27482770
# Common test for Regressors, Classifiers and Outlier detection estimators
27492771
X, y = _regression_dataset()
27502772

@@ -2756,7 +2778,7 @@ def check_estimators_unfitted(name, estimator_orig):
27562778
"predict_log_proba",
27572779
):
27582780
if hasattr(estimator, method):
2759-
with raises(NotFittedError):
2781+
with raises(NotFittedError, err_msg=err_msg.format(method=method)):
27602782
getattr(estimator, method)(X)
27612783

27622784

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -753,10 +753,6 @@ def test_check_estimator():
753753
msg = "object has no attribute 'fit'"
754754
with raises(AttributeError, match=msg):
755755
check_estimator(BaseEstimator())
756-
# check that fit does input validation
757-
msg = "Did not raise"
758-
with raises(AssertionError, match=msg):
759-
check_estimator(BaseBadClassifier())
760756

761757
# does error on binary_only untagged estimator
762758
msg = "Only 2 classes are supported"
@@ -836,7 +832,7 @@ def test_check_estimator_clones():
836832
def test_check_estimators_unfitted():
837833
# check that a ValueError/AttributeError is raised when calling predict
838834
# on an unfitted estimator
839-
msg = "Did not raise"
835+
msg = "Estimator should raise a NotFittedError when calling"
840836
with raises(AssertionError, match=msg):
841837
check_estimators_unfitted("estimator", NoSparseClassifier())
842838

0 commit comments

Comments
 (0)
0