8000 TST be more specific in test_estimator_checks (#29834) · scikit-learn/scikit-learn@c24a3f9 · GitHub
[go: up one dir, main page]

Skip to content

Commit c24a3f9

Browse files
authored
TST be more specific in test_estimator_checks (#29834)
1 parent 7bcae6c commit c24a3f9

File tree

1 file changed

+107
-44
lines changed

1 file changed

+107
-44
lines changed

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 107 additions & 44 deletions
+
"OneClassSampleErrorClassifier", OneClassSampleErrorClassifier()
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,19 @@
4848
check_classifiers_multilabel_output_format_decision_function,
4949
check_classifiers_multilabel_output_format_predict,
5050
check_classifiers_multilabel_output_format_predict_proba,
51+
check_classifiers_one_label_sample_weights,
5152
check_dataframe_column_names_consistency,
5253
check_decision_proba_consistency,
54+
check_dict_unchanged,
55+
check_dont_overwrite_parameters,
5356
check_estimator,
5457
check_estimator_cloneable,
5558
check_estimator_repr,
59+
check_estimator_sparse_array,
60+
check_estimator_sparse_matrix,
5661
check_estimator_tags_renamed,
62+
check_estimators_nan_inf,
63+
check_estimators_overwrite_params,
5764
check_estimators_unfitted,
5865
check_fit_check_is_fitted,
5966
check_fit_score_takes_y,
@@ -62,8 +69,10 @@
6269
check_no_attributes_set_in_init,
6370
check_outlier_contamination,
6471
check_outlier_corruption,
72+
check_parameters_default_constructible,
6573
check_regressor_data_not_an_array,
6674
check_requires_y_none,
75+
check_sample_weights_pandas_series,
6776
check_set_params,
6877
set_random_state,
6978
)
@@ -573,40 +582,58 @@ def fit(self, X, y):
573582
check_fit_score_takes_y("test", TestEstimatorWithDeprecatedFitMethod())
574583

575584

576-
def test_check_estimator():
577-
# tests that the estimator actually fails on "bad" estimators.
578-
# not a complete test of all checks, which are very extensive.
579-
580-
# check that we have a set_params and can clone
585+
def test_check_estimator_with_class_removed():
586+
"""Test that passing a class instead of an instance fails."""
581587
msg = "Passing a class was deprecated"
582588
with raises(TypeError, match=msg):
583-
check_estimator(object)
589+
check_estimator(LogisticRegression)
590+
591+
592+
def test_mutable_default_params():
593+
"""Test that constructor cannot have mutable default parameters."""
584594
msg = (
585595
"Parameter 'p' of estimator 'HasMutableParameters' is of type "
586596
"object which is not allowed"
587597
)
588598
# check that the "default_constructible" test checks for mutable parameters
589-
check_estimator(HasImmutableParameters()) # should pass
599+
check_parameters_default_constructible(
600+
"Immutable", HasImmutableParameters()
601+
) # should pass
590602
with raises(AssertionError, match=msg):
591-
check_estimator(HasMutableParameters())
603+
check_parameters_default_constructible("Mutable", HasMutableParameters())
604+
605+
606+
def test_check_set_params():
607+
"""Check set_params doesn't fail and sets the right values."""
592608
# check that values returned by get_params match set_params
593609
msg = "get_params result does not match what was passed to set_params"
594610
with raises(AssertionError, match=msg):
595611
check_set_params("test", ModifiesValueInsteadOfRaisingError())
612+
596613
with warnings.catch_warnings(record=True) as records:
597614
check_set_params("test", RaisesErrorInSetParams())
598615
assert UserWarning in [rec.category for rec in records]
599616

600617
with raises(AssertionError, match=msg):
601-
check_estimator(ModifiesAnotherValue())
602-
# check that we have a fit method
603-
msg = "object has no attribute 'fit'"
604-
with raises(AttributeError, match=msg):
605-
check_estimator(BaseEstimator())
606-
# check that fit does input validation
607-
msg = "Did not raise"
618+
check_set_params("test", ModifiesAnotherValue())
619+
620+
621+
def test_check_estimators_nan_inf():
622+
# check that predict does input validation (doesn't accept dicts in input)
623+
msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
608624
with raises(AssertionError, match=msg):
609-
check_estimator(BaseBadClassifier())
625+
check_estimators_nan_inf("NoCheckinPredict", NoCheckinPredict())
626+
627+
628+
def test_check_dict_unchanged():
629+
# check that estimator state does not change
630+
# at transform/predict/predict_proba time
631+
msg = "Estimator changes __dict__ during predict"
632+
with raises(AssertionError, match=msg):
633+
check_dict_unchanged("test", ChangesDict())
634+
635+
636+
def test_check_sample_weights_pandas_series():
610637
# check that sample_weights in fit accepts pandas.Series type
611638
try:
612639
from pandas import Series # noqa
@@ -616,27 +643,28 @@ def test_check_estimator():
616643
"'sample_weight' parameter is of type pandas.Series"
617644
)
618645
with raises(ValueError, match=msg):
619-
check_estimator(NoSampleWeightPandasSeriesType())
646+
check_sample_weights_pandas_series(
647+
"NoSampleWeightPandasSeriesType", NoSampleWeightPandasSeriesType()
648+
)
620649
except ImportError:
621650
pass
622-
# check that predict does input validation (doesn't accept dicts in input)
623-
msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
624-
with raises(AssertionError, match=msg):
625-
check_estimator(NoCheckinPredict())
626-
# check that estimator state does not change
627-
# at transform/predict/predict_proba time
628-
msg = "Estimator changes __dict__ during predict"
629-
with raises(AssertionError, match=msg):
630-
check_estimator(ChangesDict())
651+
652+
653+
def test_check_estimators_overwrite_params():
631654
# check that `fit` only changes attributes that
632655
# are private (start with an _ or end with a _).
633656
msg = (
634657
"Estimator ChangesWrongAttribute should not change or mutate "
635658
"the parameter wrong_attribute from 0 to 1 during fit."
636659
)
637660
with raises(AssertionError, match=msg):
638-
check_estimator(ChangesWrongAttribute())
639-
check_estimator(ChangesUnderscoreAttribute())
661+
check_estimators_overwrite_params(
662+
"ChangesWrongAttribute", ChangesWrongAttribute()
663+
)
664+
check_estimators_overwrite_params("test", ChangesUnderscoreAttribute())
665+
666+
667+
def test_check_dont_overwrite_parameters():
640668
# check that `fit` doesn't add any public attribute
641669
msg = (
642670
r"Estimator adds public attribute\(s\) during the fit method."
@@ -645,7 +673,10 @@ def test_check_estimator():
645673
" with _ but wrong_attribute added"
646674
)
647675
with raises(AssertionError, match=msg):
648-
check_estimator(SetsWrongAttribute())
676+
check_dont_overwrite_parameters("test", SetsWrongAttribute())
677+
678+
679+
def test_check_methods_sample_order_invariance():
649680
# check for sample order invariance
650681
name = NotInvariantSampleOrder.__name__
651682
method = "predict"
@@ -654,25 +685,53 @@ def test_check_estimator():
654685
"with different sample order."
655686
).format(method=method, name=name)
656687
with raises(AssertionError, match=msg):
657-
check_estimator(NotInvariantSampleOrder())
688+
check_methods_sample_order_invariance(
689+
"NotInvariantSampleOrder", NotInvariantSampleOrder()
690+
)
691+
692+
693+
def test_check_methods_subset_invariance():
658694
# check for invariant method
659695
name = NotInvariantPredict.__name__
660696
method = "predict"
661697
msg = ("{method} of {name} is not invariant when applied to a subset.").format(
662698
method=method, name=name
663699
)
664700
with raises(AssertionError, match=msg):
665-
check_estimator(NotInvariantPredict())
701+
check_methods_subset_invariance("NotInvariantPredict", NotInvariantPredict())
702+
703+
704+
def test_check_estimator_sparse_data():
666705
# check for sparse data input handling
667706
name = NoSparseClassifier.__name__
668707
msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
669708
with raises(AssertionError, match=msg):
670-
check_estimator(NoSparseClassifier("sparse_matrix"))
709+
check_estimator_sparse_matrix(name, NoSparseClassifier("sparse_matrix"))
671710

672711
if SPARRAY_PRESENT:
673712
with raises(AssertionError, match=msg):
674-
check_estimator(NoSparseClassifier("sparse_array"))
713+
check_estimator_sparse_array(name, NoSparseClassifier("sparse_array"))
675714

715+
# Large indices test on bad estimator
716+
msg = (
717+
"Estimator LargeSparseNotSupportedClassifier doesn't seem to "
718+
r"support \S{3}_64 matrix, and is not failing gracefully.*"
719+
)
720+
with raises(AssertionError, match=msg):
721+
check_estimator_sparse_matrix(
722+
"LargeSparseNotSupportedClassifier",
723+
LargeSparseNotSupportedClassifier("sparse_matrix"),
724+
)
725+
726+
if SPARRAY_PRESENT:
727+
with raises(AssertionError, match=msg):
728+
check_estimator_sparse_array(
729+
"LargeSparseNotSupportedClassifier",
730+
LargeSparseNotSupportedClassifier("sparse_array"),
731+
)
732+
733+
734+
def test_check_classifiers_one_label_sample_weights():
676735
# check for classifiers reducing to less than two classes via sample weights
677736
name = OneClassSampleErrorClassifier.__name__
678737
msg = (
@@ -681,19 +740,23 @@ def test_check_estimator():
681740
"'class'."
682741
)
683742
with raises(AssertionError, match=msg):
684-
check_estimator(OneClassSampleErrorClassifier())
743+
check_classifiers_one_label_sample_weights(
744
745+
)
685746

686-
# Large indices test on bad estimator
687-
msg = (
688-
"Estimator LargeSparseNotSupportedClassifier doesn't seem to "
689-
r"support \S{3}_64 matrix, and is not failing gracefully.*"
690-
)
691-
with raises(AssertionError, match=msg):
692-
check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix"))
693747

694-
if SPARRAY_PRESENT:
695-
with raises(AssertionError, match=msg):
696-
check_estimator(LargeSparseNotSupportedClassifier("sparse_array"))
748+
def test_check_estimator():
749+
# tests that the estimator actually fails on "bad" estimators.
750+
# not a complete test of all checks, which are very extensive.
751+
752+
# check that we have a fit method
753+
msg = "object has no attribute 'fit'"
754+
with raises(AttributeError, match=msg):
755+
check_estimator(BaseEstimator())
756+
# check that fit does input validation
757+
msg = "Did not raise"
758+
with raises(AssertionError, match=msg):
759+
check_estimator(BaseBadClassifier())
697760

698761
# does error on binary_only untagged estimator
699762
msg = "Only 2 classes are supported"

0 commit comments

Comments
 (0)
0