82
82
83
83
84
84
def _yield_api_checks(estimator):
    """Yield the API-level checks to run against ``estimator``.

    Parameters
    ----------
    estimator : object
        Estimator instance whose tags are read through ``get_tags``.

    Yields
    ------
    check : callable
        One check function at a time, in the order listed below.
    """
    tags = get_tags(estimator)
    yield check_estimator_cloneable
    yield check_estimator_repr
    yield check_no_attributes_set_in_init
    yield check_fit_score_takes_y
    yield check_estimators_overwrite_params
    yield check_dont_overwrite_parameters
    yield check_estimators_fit_returns_self
    yield check_readonly_memmap_input
    # The "unfitted" check is only yielded for estimators whose tags
    # report that fitting is required.
    if tags.requires_fit:
        yield check_estimators_unfitted
    yield check_do_not_raise_errors_in_init_or_set_params
90
97
91
98
@@ -104,8 +111,6 @@ def _yield_checks(estimator):
104
111
yield check_sample_weights_not_overwritten
105
112
yield partial (check_sample_weights_invariance , kind = "ones" )
106
113
yield partial (check_sample_weights_invariance , kind = "zeros" )
107
- yield check_estimators_fit_returns_self
108
- yield partial (check_estimators_fit_returns_self , readonly_memmap = True )
109
114
110
115
# Check that all estimators yield informative messages when
111
116
# trained on empty datasets
@@ -173,8 +178,6 @@ def _yield_classifier_checks(classifier):
173
178
yield check_supervised_y_no_nan
174
179
if tags .target_tags .single_output :
175
180
yield check_supervised_y_2d
176
- if tags .requires_fit :
177
- yield check_estimators_unfitted
178
181
if "class_weight" in classifier .get_params ().keys ():
179
182
yield check_class_weight_classifiers
180
183
@@ -247,8 +250,6 @@ def _yield_regressor_checks(regressor):
247
250
if name != "CCA" :
248
251
# check that the regressor handles int input
249
252
yield check_regressors_int
250
- if tags .requires_fit :
251
- yield check_estimators_unfitted
252
253
yield check_non_transformer_estimators_n_iter
253
254
254
255
@@ -311,9 +312,6 @@ def _yield_outliers_checks(estimator):
311
312
yield partial (check_outliers_train , readonly_memmap = True )
312
313
# test outlier detectors can handle non-array data
313
314
yield check_classifier_data_not_an_array
314
- # test if NotFittedError is raised
315
- if get_tags (estimator ).requires_fit :
316
- yield check_estimators_unfitted
317
315
yield check_non_transformer_estimators_n_iter
318
316
319
317
@@ -381,7 +379,6 @@ def _yield_all_checks(estimator, legacy: bool):
381
379
yield check_get_params_invariance
382
380
yield check_set_params
383
381
yield check_dict_unchanged
384
- yield check_dont_overwrite_parameters
385
382
yield check_fit_idempotent
386
383
yield check_fit_check_is_fitted
387
384
if not tags .no_validation :
@@ -2724,18 +2721,34 @@ def check_get_feature_names_out_error(name, estimator_orig):
2724
2721
2725
2722
2726
2723
@ignore_warnings(category=FutureWarning)
def check_estimators_fit_returns_self(name, estimator_orig):
    """Check if self is returned when calling fit.

    Parameters
    ----------
    name : str
        Not used in the body; presumably kept so that every check shares the
        same ``(name, estimator)`` signature -- confirm against the caller.
    estimator_orig : estimator instance
        Estimator to check. It is cloned before fitting, so the instance
        passed in is not the one that gets fitted.

    Raises
    ------
    AssertionError
        If ``fit`` returns anything other than the fitted estimator itself.
    """
    X, y = make_blobs(random_state=0, n_samples=21)
    X = _enforce_estimator_tags_X(estimator_orig, X)

    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    set_random_state(estimator)
    assert estimator.fit(X, y) is estimator
2734
+
2735
+
2736
@ignore_warnings(category=FutureWarning)
def check_readonly_memmap_input(name, estimator_orig):
    """Check that the estimator can handle readonly memmap backed data.

    This is particularly needed to support joblib parallelisation.

    Parameters
    ----------
    name : str
        Not used in the body; presumably kept so that every check shares the
        same ``(name, estimator)`` signature -- confirm against the caller.
    estimator_orig : estimator instance
        Estimator to check. It is cloned before fitting, so the instance
        passed in is not the one that gets fitted.

    Raises
    ------
    AssertionError
        If ``fit`` returns anything other than the fitted estimator itself.
    """
    X, y = make_blobs(random_state=0, n_samples=21)
    X = _enforce_estimator_tags_X(estimator_orig, X)

    estimator = clone(estimator_orig)
    y = _enforce_estimator_tags_y(estimator, y)

    # Swap the in-memory arrays for memmap-backed ones only after the
    # tag-based adjustments have been applied to X and y.
    X, y = create_memmap_backed_data([X, y])

    set_random_state(estimator)
    # This should not raise an error and should return self
    assert estimator.fit(X, y) is estimator
2740
2753
2741
2754
@@ -2745,6 +2758,15 @@ def check_estimators_unfitted(name, estimator_orig):
2745
2758
2746
2759
Unfitted estimators should raise a NotFittedError.
2747
2760
"""
2761
+ err_msg = (
2762
+ "Estimator should raise a NotFittedError when calling `{method}` before fit. "
2763
+ "Either call `check_is_fitted(self)` at the beginning of `{method}` or "
2764
+ "set `tags.requires_fit=False` on estimator tags to disable this check.\n "
2765
+ "- `check_is_fitted`: https://scikit-learn.org/dev/modules/generated/sklearn."
2766
+ "utils.validation.check_is_fitted.html\n "
2767
+ "- Estimator Tags: https://scikit-learn.org/dev/developers/develop."
2768
+ "html#estimator-tags"
2769
+ )
2748
2770
# Common test for Regressors, Classifiers and Outlier detection estimators
2749
2771
X , y = _regression_dataset ()
2750
2772
@@ -2756,7 +2778,7 @@ def check_estimators_unfitted(name, estimator_orig):
2756
2778
"predict_log_proba" ,
2757
2779
):
2758
2780
if hasattr (estimator , method ):
2759
- with raises (NotFittedError ):
2781
+ with raises (NotFittedError , err_msg = err_msg . format ( method = method ) ):
2760
2782
getattr (estimator , method )(X )
2761
2783
2762
2784
0 commit comments