@@ -974,23 +974,6 @@ def fit(self, X, y, **fit_params):
         pc_clf.fit(X, y, sample_weight=sample_weight)


-def test_calibration_with_fit_params_inconsistent_length(data):
-    """fit_params having different length than data should raise the
-    correct error message.
-    """
-    X, y = data
-    fit_params = {"a": y[:5]}
-    clf = CheckingClassifier(expected_fit_params=fit_params)
-    pc_clf = CalibratedClassifierCV(clf)
-
-    msg = (
-        r"Found input variables with inconsistent numbers of "
-        r"samples: \[" + str(N_SAMPLES) + r", 5\]"
-    )
-    with pytest.raises(ValueError, match=msg):
-        pc_clf.fit(X, y, **fit_params)
-
-
 @pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
 @pytest.mark.parametrize("ensemble", [True, False])
 def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
@@ -1054,3 +1037,17 @@ def test_calibrated_classifier_deprecation_base_estimator(data):
     warn_msg = "`base_estimator` was renamed to `estimator`"
     with pytest.warns(FutureWarning, match=warn_msg):
         calibrated_classifier.fit(*data)
+
+
+def test_calibration_with_non_sample_aligned_fit_param(data):
+    """Check that CalibratedClassifierCV does not enforce sample alignment
+    for fit parameters."""
+
+    class TestClassifier(LogisticRegression):
+        def fit(self, X, y, sample_weight=None, fit_param=None):
+            assert fit_param is not None
+            return super().fit(X, y, sample_weight=sample_weight)
+
+    CalibratedClassifierCV(estimator=TestClassifier()).fit(
+        *data, fit_param=np.ones(len(data[1]) + 1)
+    )
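For context, here is a minimal standalone sketch of the behavior the new test covers: extra fit parameters are forwarded to the wrapped estimator without a sample-alignment check, so an array whose length differs from the data no longer raises. The PassThroughClassifier subclass and the make_classification data are illustrative assumptions, not part of this commit.

# Standalone sketch mirroring the new test; PassThroughClassifier and the
# synthetic data are assumptions for illustration only.
import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


class PassThroughClassifier(LogisticRegression):
    """Hypothetical estimator accepting an extra, non-sample-aligned fit parameter."""

    def fit(self, X, y, sample_weight=None, fit_param=None):
        # CalibratedClassifierCV forwards fit_param to this method unchanged.
        assert fit_param is not None
        return super().fit(X, y, sample_weight=sample_weight)


X, y = make_classification(n_samples=100, random_state=0)
calibrated = CalibratedClassifierCV(estimator=PassThroughClassifier())
# fit_param deliberately has one more entry than there are samples; after
# this change, fitting no longer raises a length-consistency error.
calibrated.fit(X, y, fit_param=np.ones(len(y) + 1))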