8000 FIX make CalibratedClassifierCV not enforce sample alignment for fit_… · scikit-learn/scikit-learn@30bf6f3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 30bf6f3

Browse files
authored
FIX make CalibratedClassifierCV not enforce sample alignment for fit_params (#25805)
1 parent be978fb commit 30bf6f3

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

doc/whats_new/v1.3.rst

+6
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,12 @@ Changelog
143143
- |Feature| A `__sklearn_clone__` protocol is now available to override the
144144
default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_.
145145

146+
:mod:`sklearn.calibration`
147+
..........................
148+
149+
- |Fix| :class:`calibration.CalibratedClassifierCV` now does not enforce sample
150+
alignment on `fit_params`. :pr:`25805` by `Adrin Jalali`_.
151+
146152
:mod:`sklearn.cluster`
147153
......................
148154

sklearn/calibration.py

-3
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@ def fit(self, X, y, sample_weight=None, **fit_params):
308308
if sample_weight is not None:
309309
sample_weight = _check_sample_weight(sample_weight, X)
310310

311-
for sample_aligned_params in fit_params.values():
312-
check_consistent_length(y, sample_aligned_params)
313-
314311
# TODO(1.4): Remove when base_estimator is removed
315312
if self.base_estimator != "deprecated":
316313
if self.estimator is not None:

sklearn/tests/test_calibration.py

+14-17
Original file line numberDiff line numberDiff line change
@@ -974,23 +974,6 @@ def fit(self, X, y, **fit_params):
974974
pc_clf.fit(X, y, sample_weight=sample_weight)
975975

976976

977-
def test_calibration_with_fit_params_inconsistent_length(data):
978-
"""fit_params having different length than data should raise the
979-
correct error message.
980-
"""
981-
X, y = data
982-
fit_params = {"a": y[:5]}
983-
clf = CheckingClassifier(expected_fit_params=fit_params)
984-
pc_clf = CalibratedClassifierCV(clf)
985-
986-
msg = (
987-
r"Found input variables with inconsistent numbers of "
988-
r"samples: \[" + str(N_SAMPLES) + r", 5\]"
989-
)
990-
with pytest.raises(ValueError, match=msg):
991-
pc_clf.fit(X, y, **fit_params)
992-
993-
994977
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
995978
@pytest.mark.parametrize("ensemble", [True, False])
996979
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
@@ -1054,3 +1037,17 @@ def test_calibrated_classifier_deprecation_base_estimator(data):
10541037
warn_msg = "`base_estimator` was renamed to `estimator`"
10551038
with pytest.warns(FutureWarning, match=warn_msg):
10561039
calibrated_classifier.fit(*data)
1040+
1041+
1042+
def test_calibration_with_non_sample_aligned_fit_param(data):
1043+
"""Check that CalibratedClassifierCV does not enforce sample alignment
1044+
for fit parameters."""
1045+
1046+
class TestClassifier(LogisticRegression):
1047+
def fit(self, X, y, sample_weight=None, fit_param=None):
1048+
assert fit_param is not None
1049+
return super().fit(X, y, sample_weight=sample_weight)
1050+
1051+
CalibratedClassifierCV(estimator=TestClassifier()).fit(
1052+
*data, fit_param=np.ones(len(data[1]) + 1)
1053+
)

0 commit comments

Comments
 (0)
0