48
48
check_classifiers_multilabel_output_format_decision_function ,
49
49
check_classifiers_multilabel_output_format_predict ,
50
50
check_classifiers_multilabel_output_format_predict_proba ,
51
+ check_classifiers_one_label_sample_weights ,
51
52
check_dataframe_column_names_consistency ,
52
53
check_decision_proba_consistency ,
54
+ check_dict_unchanged ,
55
+ check_dont_overwrite_parameters ,
53
56
check_estimator ,
54
57
check_estimator_cloneable ,
55
58
check_estimator_repr ,
59
+ check_estimator_sparse_array ,
60
+ check_estimator_sparse_matrix ,
56
61
check_estimator_tags_renamed ,
62
+ check_estimators_nan_inf ,
63
+ check_estimators_overwrite_params ,
57
64
check_estimators_unfitted ,
58
65
check_fit_check_is_fitted ,
59
66
check_fit_score_takes_y ,
62
69
check_no_attributes_set_in_init ,
63
70
check_outlier_contamination ,
64
71
check_outlier_corruption ,
72
+ check_parameters_default_constructible ,
65
73
check_regressor_data_not_an_array ,
66
74
check_requires_y_none ,
75
+ check_sample_weights_pandas_series ,
67
76
check_set_params ,
68
77
set_random_state ,
69
78
)
@@ -573,40 +582,58 @@ def fit(self, X, y):
573
582
check_fit_score_takes_y ("test" , TestEstimatorWithDeprecatedFitMethod ())
574
583
575
584
576
- def test_check_estimator ():
577
- # tests that the estimator actually fails on "bad" estimators.
578
- # not a complete test of all checks, which are very extensive.
579
-
580
- # check that we have a set_params and can clone
585
+ def test_check_estimator_with_class_removed ():
586
+ """Test that passing a class instead of an instance fails."""
581
587
msg = "Passing a class was deprecated"
582
588
with raises (TypeError , match = msg ):
583
- check_estimator (object )
589
+ check_estimator (LogisticRegression )
590
+
591
+
592
+ def test_mutable_default_params ():
593
+ """Test that constructor cannot have mutable default parameters."""
584
594
msg = (
585
595
"Parameter 'p' of estimator 'HasMutableParameters' is of type "
586
596
"object which is not allowed"
587
597
)
588
598
# check that the "default_constructible" test checks for mutable parameters
589
- check_estimator (HasImmutableParameters ()) # should pass
599
+ check_parameters_default_constructible (
600
+ "Immutable" , HasImmutableParameters ()
601
+ ) # should pass
590
602
with raises (AssertionError , match = msg ):
591
- check_estimator (HasMutableParameters ())
603
+ check_parameters_default_constructible ("Mutable" , HasMutableParameters ())
604
+
605
+
606
+ def test_check_set_params ():
607
+ """Check set_params doesn't fail and sets the right values."""
592
608
# check that values returned by get_params match set_params
593
609
msg = "get_params result does not match what was passed to set_params"
594
610
with raises (AssertionError , match = msg ):
595
611
check_set_params ("test" , ModifiesValueInsteadOfRaisingError ())
612
+
596
613
with warnings .catch_warnings (record = True ) as records :
597
614
check_set_params ("test" , RaisesErrorInSetParams ())
598
615
assert UserWarning in [rec .category for rec in records ]
599
616
600
617
with raises (AssertionError , match = msg ):
601
- check_estimator (ModifiesAnotherValue ())
602
- # check that we have a fit method
603
- msg = "object has no attribute 'fit'"
604
- with raises (AttributeError , match = msg ):
605
- check_estimator (BaseEstimator ())
606
- # check that fit does input validation
607
- msg = "Did not raise"
618
+ check_set_params ("test" , ModifiesAnotherValue ())
619
+
620
+
621
+ def test_check_estimators_nan_inf ():
622
+ # check that predict does input validation (doesn't accept dicts in input)
623
+ msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
608
624
with raises (AssertionError , match = msg ):
609
- check_estimator (BaseBadClassifier ())
625
+ check_estimators_nan_inf ("NoCheckinPredict" , NoCheckinPredict ())
626
+
627
+
628
+ def test_check_dict_unchanged ():
629
+ # check that estimator state does not change
630
+ # at transform/predict/predict_proba time
631
+ msg = "Estimator changes __dict__ during predict"
632
+ with raises (AssertionError , match = msg ):
633
+ check_dict_unchanged ("test" , ChangesDict ())
634
+
635
+
636
+ def test_check_sample_weights_pandas_series ():
610
637
# check that sample_weights in fit accepts pandas.Series type
611
638
try :
612
639
from pandas import Series # noqa
@@ -616,27 +643,28 @@ def test_check_estimator():
616
643
"'sample_weight' parameter is of type pandas.Series"
617
644
)
618
645
with raises (ValueError , match = msg ):
619
- check_estimator (NoSampleWeightPandasSeriesType ())
646
+ check_sample_weights_pandas_series (
647
+ "NoSampleWeightPandasSeriesType" , NoSampleWeightPandasSeriesType ()
648
+ )
620
649
except ImportError :
621
650
pass
622
- # check that predict does input validation (doesn't accept dicts in input)
623
- msg = "Estimator NoCheckinPredict doesn't check for NaN and inf in predict"
624
- with raises (AssertionError , match = msg ):
625
- check_estimator (NoCheckinPredict ())
626
- # check that estimator state does not change
627
- # at transform/predict/predict_proba time
628
- msg = "Estimator changes __dict__ during predict"
629
- with raises (AssertionError , match = msg ):
630
- check_estimator (ChangesDict ())
651
+
652
+
653
+ def test_check_estimators_overwrite_params ():
631
654
# check that `fit` only changes attributes that
632
655
# are private (start with an _ or end with a _).
633
656
msg = (
634
657
"Estimator ChangesWrongAttribute should not change or mutate "
635
658
"the parameter wrong_attribute from 0 to 1 during fit."
636
659
)
637
660
with raises (AssertionError , match = msg ):
638
- check_estimator (ChangesWrongAttribute ())
639
- check_estimator (ChangesUnderscoreAttribute ())
661
+ check_estimators_overwrite_params (
662
+ "ChangesWrongAttribute" , ChangesWrongAttribute ()
663
+ )
664
+ check_estimators_overwrite_params ("test" , ChangesUnderscoreAttribute ())
665
+
666
+
667
+ def test_check_dont_overwrite_parameters ():
640
668
# check that `fit` doesn't add any public attribute
641
669
msg = (
642
670
r"Estimator adds public attribute\(s\) during the fit method."
@@ -645,7 +673,10 @@ def test_check_estimator():
645
673
" with _ but wrong_attribute added"
646
674
)
647
675
with raises (AssertionError , match = msg ):
648
- check_estimator (SetsWrongAttribute ())
676
+ check_dont_overwrite_parameters ("test" , SetsWrongAttribute ())
677
+
678
+
679
+ def test_check_methods_sample_order_invariance ():
649
680
# check for sample order invariance
650
681
name = NotInvariantSampleOrder .__name__
651
682
method = "predict"
@@ -654,25 +685,53 @@ def test_check_estimator():
654
685
"with different sample order."
655
686
).format (method = method , name = name )
656
687
with raises (AssertionError , match = msg ):
657
- check_estimator (NotInvariantSampleOrder ())
688
+ check_methods_sample_order_invariance (
689
+ "NotInvariantSampleOrder" , NotInvariantSampleOrder ()
690
+ )
691
+
692
+
693
+ def test_check_methods_subset_invariance ():
658
694
# check for invariant method
659
695
name = NotInvariantPredict .__name__
660
696
method = "predict"
661
697
msg = ("{method} of {name} is not invariant when applied to a subset." ).format (
662
698
method = method , name = name
663
699
)
664
700
with raises (AssertionError , match = msg ):
665
- check_estimator (NotInvariantPredict ())
701
+ check_methods_subset_invariance ("NotInvariantPredict" , NotInvariantPredict ())
702
+
703
+
704
+ def test_check_estimator_sparse_data ():
666
705
# check for sparse data input handling
667
706
name = NoSparseClassifier .__name__
668
707
msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
669
708
with raises (AssertionError , match = msg ):
670
- check_estimator ( NoSparseClassifier ("sparse_matrix" ))
709
+ check_estimator_sparse_matrix ( name , NoSparseClassifier ("sparse_matrix" ))
671
710
672
711
if SPARRAY_PRESENT :
673
712
with raises (AssertionError , match = msg ):
674
- check_estimator ( NoSparseClassifier ("sparse_array" ))
713
+ check_estimator_sparse_array ( name , NoSparseClassifier ("sparse_array" ))
675
714
715
+ # Large indices test on bad estimator
716
+ msg = (
717
+ "Estimator LargeSparseNotSupportedClassifier doesn't seem to "
718
+ r"support \S{3}_64 matrix, and is not failing gracefully.*"
719
+ )
720
+ with raises (AssertionError , match = msg ):
721
+ check_estimator_sparse_matrix (
722
+ "LargeSparseNotSupportedClassifier" ,
723
+ LargeSparseNotSupportedClassifier ("sparse_matrix" ),
724
+ )
725
+
726
+ if SPARRAY_PRESENT :
727
+ with raises (AssertionError , match = msg ):
728
+ check_estimator_sparse_array (
729
+ "LargeSparseNotSupportedClassifier" ,
730
+ LargeSparseNotSupportedClassifier ("sparse_array" ),
731
+ )
732
+
733
+
734
+ def test_check_classifiers_one_label_sample_weights ():
676
735
# check for classifiers reducing to less than two classes via sample weights
677
736
name = OneClassSampleErrorClassifier .__name__
678
737
msg = (
@@ -681,19 +740,23 @@ def test_check_estimator():
681
740
"'class'."
682
741
)
683
742
with raises (AssertionError , match = msg ):
684
- check_estimator (OneClassSampleErrorClassifier ())
743
+ check_classifiers_one_label_sample_weights (
744
+ "OneClassSampleErrorClassifier" , OneClassSampleErrorClassifier ()
745
+ )
685
746
686
- # Large indices test on bad estimator
687
- msg = (
688
- "Estimator LargeSparseNotSupportedClassifier doesn't seem to "
689
- r"support \S{3}_64 matrix, and is not failing gracefully.*"
690
- )
691
- with raises (AssertionError , match = msg ):
692
- check_estimator (LargeSparseNotSupportedClassifier ("sparse_matrix" ))
693
747
694
- if SPARRAY_PRESENT :
695
- with raises (AssertionError , match = msg ):
696
- check_estimator (LargeSparseNotSupportedClassifier ("sparse_array" ))
748
+ def test_check_estimator ():
749
+ # tests that the estimator actually fails on "bad" estimators.
750
+ # not a complete test of all checks, which are very extensive.
751
+
752
+ # check that we have a fit method
753
+ msg = "object has no attribute 'fit'"
754
+ with raises (AttributeError , match = msg ):
755
+ check_estimator (BaseEstimator ())
756
+ # check that fit does input validation
757
+ msg = "Did not raise"
758
+ with raises (AssertionError , match = msg ):
759
+ check_estimator (BaseBadClassifier ())
697
760
698
761
# does error on binary_only untagged estimator
699
762
msg = "Only 2 classes are supported"
0 commit comments