@@ -974,23 +974,6 @@ def fit(self, X, y, **fit_params):
974
974
pc_clf .fit (X , y , sample_weight = sample_weight )
975
975
976
976
977
- def test_calibration_with_fit_params_inconsistent_length (data ):
978
- """fit_params having different length than data should raise the
979
- correct error message.
980
- """
981
- X , y = data
982
- fit_params = {"a" : y [:5 ]}
983
- clf = CheckingClassifier (expected_fit_params = fit_params )
984
- pc_clf = CalibratedClassifierCV (clf )
985
-
986
- msg = (
987
- r"Found input variables with inconsistent numbers of "
988
- r"samples: \[" + str (N_SAMPLES ) + r", 5\]"
989
- )
990
- with pytest .raises (ValueError , match = msg ):
991
- pc_clf .fit (X , y , ** fit_params )
992
-
993
-
994
977
@pytest .mark .parametrize ("method" , ["sigmoid" , "isotonic" ])
995
978
@pytest .mark .parametrize ("ensemble" , [True , False ])
996
979
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence (method , ensemble ):
@@ -1054,3 +1037,17 @@ def test_calibrated_classifier_deprecation_base_estimator(data):
1054
1037
warn_msg = "`base_estimator` was renamed to `estimator`"
1055
1038
with pytest .warns (FutureWarning , match = warn_msg ):
1056
1039
calibrated_classifier .fit (* data )
1040
+
1041
+
1042
+ def test_calibration_with_non_sample_aligned_fit_param (data ):
1043
+ """Check that CalibratedClassifierCV does not enforce sample alignment
1044
+ for fit parameters."""
1045
+
1046
+ class TestClassifier (LogisticRegression ):
1047
+ def fit (self , X , y , sample_weight = None , fit_param = None ):
1048
+ assert fit_param is not None
1049
+ return super ().fit (X , y , sample_weight = sample_weight )
1050
+
1051
+ CalibratedClassifierCV (estimator = TestClassifier ()).fit (
1052
+ * data , fit_param = np .ones (len (data [1 ]) + 1 )
1053
+ )
0 commit comments