4848 check_classifiers_multilabel_output_format_decision_function ,
4949 check_classifiers_multilabel_output_format_predict ,
5050 check_classifiers_multilabel_output_format_predict_proba ,
51+ check_classifiers_one_label_sample_weights ,
5152 check_dataframe_column_names_consistency ,
5253 check_decision_proba_consistency ,
54+ check_dict_unchanged ,
55+ check_dont_overwrite_parameters ,
5356 check_estimator ,
5457 check_estimator_cloneable ,
5558 check_estimator_repr ,
59+ check_estimator_sparse_array ,
60+ check_estimator_sparse_matrix ,
5661 check_estimator_tags_renamed ,
62+ check_estimators_nan_inf ,
63+ check_estimators_overwrite_params ,
5764 check_estimators_unfitted ,
5865 check_fit_check_is_fitted ,
5966 check_fit_score_takes_y ,
6269 check_no_attributes_set_in_init ,
6370 check_outlier_contamination ,
6471 check_outlier_corruption ,
72+ check_parameters_default_constructible ,
6573 check_regressor_data_not_an_array ,
6674 check_requires_y_none ,
75+ check_sample_weights_pandas_series ,
6776 check_set_params ,
6877 set_random_state ,
6978)
@@ -573,40 +582,58 @@ def fit(self, X, y):
573582 check_fit_score_takes_y ("test" , TestEstimatorWithDeprecatedFitMethod ())
574583
575584
576- def test_check_estimator ():
577- # tests that the estimator actually fails on "bad" estimators.
578- # not a complete test of all checks, which are very extensive.
579-
580- # check that we have a set_params and can clone
585+ def test_check_estimator_with_class_removed ():
586+ """Test that passing a class instead of an instance fails."""
581587 msg = "Passing a class was deprecated"
582588 with raises (TypeError , match = msg ):
583- check_estimator (object )
589+ check_estimator (LogisticRegression )
590+
591+
592+ def test_mutable_default_params ():
593+ """Test that constructor cannot have mutable default parameters."""
584594 msg = (
585595 "Parameter 'p' of estimator 'HasMutableParameters' is of type "
586596 "object which is not allowed"
587597 )
588598 # check that the "default_constructible" test checks for mutable parameters
589- check_estimator (HasImmutableParameters ()) # should pass
599+ check_parameters_default_constructible (
600+ "Immutable" , HasImmutableParameters ()
601+ ) # should pass
590602 with raises (AssertionError , match = msg ):
591- check_estimator (HasMutableParameters ())
603+ check_parameters_default_constructible ("Mutable" , HasMutableParameters ())
604+
605+
606+ def test_check_set_params ():
607+ """Check set_params doesn't fail and sets the right values."""
592608 # check that values returned by get_params match set_params
593609 msg = "get_params result does not match what was passed to set_params"
594610 with raises (AssertionError , match = msg ):
595611 check_set_params ("test" , ModifiesValueInsteadOfRaisingError ())
612+
596613 with warnings .catch_warnings (record = True ) as records :
597614 check_set_params ("test" , RaisesErrorInSetParams ())
598615 assert UserWarning in [rec .category for rec in records ]
599616
600617 with raises (AssertionError , match = msg ):
601- check_estimator (ModifiesAnotherValue ())
602- # check that we have a fit method
603- msg = "object has no attribute 'fit'"
604- with raises (AttributeError , match = msg ):
605- check_estimator (BaseEstimator ())
606- # check that fit does input validation
607- msg = "Did not raise"
618+ check_set_params ("test" , ModifiesAnotherValue ())
619+
620+
621+ def test_check_estimators_nan_inf ():
622+ # check that predict does input validation (doesn't accept dicts in input)
623+ msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
608624 with raises (AssertionError , match = msg ):
609- check_estimator (BaseBadClassifier ())
625+ check_estimators_nan_inf ("NoCheckinPredict" , NoCheckinPredict ())
626+
627+
628+ def test_check_dict_unchanged ():
629+ # check that estimator state does not change
630+ # at transform/predict/predict_proba time
631+ msg = "Estimator changes __dict__ during predict"
632+ with raises (AssertionError , match = msg ):
633+ check_dict_unchanged ("test" , ChangesDict ())
634+
635+
636+ def test_check_sample_weights_pandas_series ():
610637 # check that sample_weights in fit accepts pandas.Series type
611638 try :
612639 from pandas import Series # noqa
@@ -616,27 +643,28 @@ def test_check_estimator():
616643 "'sample_weight' parameter is of type pandas.Series"
617644 )
618645 with raises (ValueError , match = msg ):
619- check_estimator (NoSampleWeightPandasSeriesType ())
646+ check_sample_weights_pandas_series (
647+ "NoSampleWeightPandasSeriesType" , NoSampleWeightPandasSeriesType ()
648+ )
620649 except ImportError :
621650 pass
622- # check that predict does input validation (doesn't accept dicts in input)
623- msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
624- with raises (AssertionError , match = msg ):
625- check_estimator (NoCheckinPredict ())
626- # check that estimator state does not change
627- # at transform/predict/predict_proba time
628- msg = "Estimator changes __dict__ during predict"
629- with raises (AssertionError , match = msg ):
630- check_estimator (ChangesDict ())
651+
652+
653+ def test_check_estimators_overwrite_params ():
631654 # check that `fit` only changes attributes that
632655 # are private (start with an _ or end with a _).
633656 msg = (
634657 "Estimator ChangesWrongAttribute should not change or mutate "
635658 "the parameter wrong_attribute from 0 to 1 during fit."
636659 )
637660 with raises (AssertionError , match = msg ):
638- check_estimator (ChangesWrongAttribute ())
639- check_estimator (ChangesUnderscoreAttribute ())
661+ check_estimators_overwrite_params (
662+ "ChangesWrongAttribute" , ChangesWrongAttribute ()
663+ )
664+ check_estimators_overwrite_params ("test" , ChangesUnderscoreAttribute ())
665+
666+
667+ def test_check_dont_overwrite_parameters ():
640668 # check that `fit` doesn't add any public attribute
641669 msg = (
642670 r"Estimator adds public attribute\(s\) during the fit method."
@@ -645,7 +673,10 @@ def test_check_estimator():
645673 " with _ but wrong_attribute added"
646674 )
647675 with raises (AssertionError , match = msg ):
648- check_estimator (SetsWrongAttribute ())
676+ check_dont_overwrite_parameters ("test" , SetsWrongAttribute ())
677+
678+
679+ def test_check_methods_sample_order_invariance ():
649680 # check for sample order invariance
650681 name = NotInvariantSampleOrder .__name__
651682 method = "predict"
@@ -654,25 +685,53 @@ def test_check_estimator():
654685 "with different sample order."
655686 ).format (method = method , name = name )
656687 with raises (AssertionError , match = msg ):
657- check_estimator (NotInvariantSampleOrder ())
688+ check_methods_sample_order_invariance (
689+ "NotInvariantSampleOrder" , NotInvariantSampleOrder ()
690+ )
691+
692+
693+ def test_check_methods_subset_invariance ():
658694 # check for invariant method
659695 name = NotInvariantPredict .__name__
660696 method = "predict"
661697 msg = ("{method} of {name} is not invariant when applied to a subset." ).format (
662698 method = method , name = name
663699 )
664700 with raises (AssertionError , match = msg ):
665- check_estimator (NotInvariantPredict ())
701+ check_methods_subset_invariance ("NotInvariantPredict" , NotInvariantPredict ())
702+
703+
704+ def test_check_estimator_sparse_data ():
666705 # check for sparse data input handling
667706 name = NoSparseClassifier .__name__
668707 msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
669708 with raises (AssertionError , match = msg ):
670- check_estimator ( NoSparseClassifier ("sparse_matrix" ))
709+ check_estimator_sparse_matrix ( name , NoSparseClassifier ("sparse_matrix" ))
671710
672711 if SPARRAY_PRESENT :
673712 with raises (AssertionError , match = msg ):
674- check_estimator ( NoSparseClassifier ("sparse_array" ))
713+ check_estimator_sparse_array ( name , NoSparseClassifier ("sparse_array" ))
675714
715+ # Large indices test on bad estimator
716+ msg = (
717+ "Estimator LargeSparseNotSupportedClassifier doesn't seem to "
718+ r"support \S{3}_64 matrix, and is not failing gracefully.*"
719+ )
720+ with raises (AssertionError , match = msg ):
721+ check_estimator_sparse_matrix (
722+ "LargeSparseNotSupportedClassifier" ,
723+ LargeSparseNotSupportedClassifier ("sparse_matrix" ),
724+ )
725+
726+ if SPARRAY_PRESENT :
727+ with raises (AssertionError , match = msg ):
728+ check_estimator_sparse_array (
729+ "LargeSparseNotSupportedClassifier" ,
730+ LargeSparseNotSupportedClassifier ("sparse_array" ),
731+ )
732+
733+
734+ def test_check_classifiers_one_label_sample_weights ():
676735 # check for classifiers reducing to less than two classes via sample weights
677736 name = OneClassSampleErrorClassifier .__name__
678737 msg = (
@@ -681,19 +740,23 @@ def test_check_estimator():
681740 "'class'."
682741 )
683742 with raises (AssertionError , match = msg):
684- check_estimator (OneClassSampleErrorClassifier ())
743+ check_classifiers_one_label_sample_weights (
744+ "OneClassSampleErrorClassifier" , OneClassSampleErrorClassifier ()
745+ )
685746
686- # Large indices test on bad estimator
687- msg = (
688- "Estimator LargeSparseNotSupportedClassifier doesn't seem to "
689- r"support \S{3}_64 matrix, and is not failing gracefully.*"
690- )
691- with raises (AssertionError , match = msg ):
692- check_estimator (LargeSparseNotSupportedClassifier ("sparse_matrix" ))
693747
694- if SPARRAY_PRESENT :
695- with raises (AssertionError , match = msg ):
696- check_estimator (LargeSparseNotSupportedClassifier ("sparse_array" ))
748+ def test_check_estimator ():
749+ # tests that the estimator actually fails on "bad" estimators.
750+ # not a complete test of all checks, which are very extensive.
751+
752+ # check that we have a fit method
753+ msg = "object has no attribute 'fit'"
754+ with raises (AttributeError , match = msg ):
755+ check_estimator (BaseEstimator ())
756+ # check that fit does input validation
757+ msg = "Did not raise"
758+ with raises (AssertionError , match = msg ):
759+ check_estimator (BaseBadClassifier ())
697760
698761 # does error on binary_only untagged estimator
699762 msg = "Only 2 classes are supported"
0 commit comments