Make check_sample_weights_invariance cv-aware (#29796) · scikit-learn/scikit-learn@0b0b90b · GitHub
[go: up one dir, main page]

Skip to content

Commit 0b0b90b

Browse files
authored
Make check_sample_weights_invariance cv-aware (#29796)
1 parent 7baa11e commit 0b0b90b

File tree

5 files changed

+30
-75
lines changed

5 files changed

+30
-75
lines changed

sklearn/calibration.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -540,17 +540,6 @@ def get_metadata_routing(self):
540540
)
541541
return router
542542

543-
def __sklearn_tags__(self):
544-
tags = super().__sklearn_tags__()
545-
tags._xfail_checks = {
546-
"check_sample_weights_invariance": (
547-
"Due to the cross-validation and sample ordering, removing a sample"
548-
" is not strictly equal to putting is weight to zero. Specific unit"
549-
" tests are added for CalibratedClassifierCV specifically."
550-
),
551-
}
552-
return tags
553-
554543

555544
def _fit_classifier_calibrator_pair(
556545
estimator,

sklearn/linear_model/_logistic.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2269,15 +2269,6 @@ def get_metadata_routing(self):
22692269
)
22702270
return router
22712271

2272-
def __sklearn_tags__(self):
2273-
tags = super().__sklearn_tags__()
2274-
tags._xfail_checks = {
2275-
"check_sample_weights_invariance": (
2276-
"zero sample_weight is not equivalent to removing samples"
2277-
),
2278-
}
2279-
return tags
2280-
22812272
def _get_scorer(self):
22822273
"""Get the scorer based on the scoring method specified.
22832274
The default scoring method is `accuracy`.

sklearn/linear_model/_ridge.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,6 +2682,15 @@ def fit(self, X, y, sample_weight=None, **params):
26822682
super().fit(X, y, sample_weight=sample_weight, **params)
26832683
return self
26842684

2685+
def __sklearn_tags__(self):
2686+
tags = super().__sklearn_tags__()
2687+
tags._xfail_checks = {
2688+
"check_sample_weights_invariance": (
2689+
"GridSearchCV does not forward the weights to the scorer by default."
2690+
),
2691+
}
2692+
return tags
2693+
26852694

26862695
class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
26872696
"""Ridge classifier with built-in cross-validation.
@@ -2891,13 +2900,3 @@ def fit(self, X, y, sample_weight=None, **params):
28912900
target = Y if self.cv is None else y
28922901
super().fit(X, target, sample_weight=sample_weight, **params)
28932902
return self
2894-
2895-
def __sklearn_tags__(self):
2896-
tags = super().__sklearn_tags__()
2897-
tags.classifier_tags.multi_label = True
2898-
tags._xfail_checks = {
2899-
"check_sample_weights_invariance": (
2900-
"zero sample_weight is not equivalent to removing samples"
2901-
),
2902-
}
2903-
return tags

sklearn/tests/test_calibration.py

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -941,50 +941,6 @@ def fit(self, X, y, **fit_params):
941941
pc_clf.fit(X, y, sample_weight=sample_weight)
942942

943943

944-
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
945-
@pytest.mark.parametrize("ensemble", [True, False])
946-
def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble):
947-
"""Check that passing removing some sample from the dataset `X` is
948-
equivalent to passing a `sample_weight` with a factor 0."""
949-
X, y = load_iris(return_X_y=True)
950-
# Scale the data to avoid any convergence issue
951-
X = StandardScaler().fit_transform(X)
952-
# Only use 2 classes and select samples such that 2-fold cross-validation
953-
# split will lead to an equivalence with a `sample_weight` of 0
954-
X = np.vstack((X[:40], X[50:90]))
955-
y = np.hstack((y[:40], y[50:90]))
956-
sample_weight = np.zeros_like(y)
957-
sample_weight[::2] = 1
958-
959-
estimator = LogisticRegression()
960-
calibrated_clf_without_weights = CalibratedClassifierCV(
961-
estimator,
962-
method=method,
963-
ensemble=ensemble,
964-
cv=2,
965-
)
966-
calibrated_clf_with_weights = clone(calibrated_clf_without_weights)
967-
968-
calibrated_clf_with_weights.fit(X, y, sample_weight=sample_weight)
969-
calibrated_clf_without_weights.fit(X[::2], y[::2])
970-
971-
# Check that the underlying fitted estimators have the same coefficients
972-
for est_with_weights, est_without_weights in zip(
973-
calibrated_clf_with_weights.calibrated_classifiers_,
974-
calibrated_clf_without_weights.calibrated_classifiers_,
975-
):
976-
assert_allclose(
977-
est_with_weights.estimator.coef_,
978-
est_without_weights.estimator.coef_,
979-
)
980-
981-
# Check that the predictions are the same
982-
y_pred_with_weights = calibrated_clf_with_weights.predict_proba(X)
983-
y_pred_without_weights = calibrated_clf_without_weights.predict_proba(X)
984-
985-
assert_allclose(y_pred_with_weights, y_pred_without_weights)
986-
987-
988944
def test_calibration_with_non_sample_aligned_fit_param(data):
989945
"""Check that CalibratedClassifierCV does not enforce sample alignment
990946
for fit parameters."""

sklearn/utils/estimator_checks.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from ..linear_model._base import LinearClassifierMixin
3737
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
3838
from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel
39-
from ..model_selection import ShuffleSplit, train_test_split
39+
from ..model_selection import LeaveOneGroupOut, ShuffleSplit, train_test_split
4040
from ..model_selection._validation import _safe_split
4141
from ..pipeline import make_pipeline
4242
from ..preprocessing import StandardScaler, scale
@@ -1108,6 +1108,26 @@ def check_sample_weights_invariance(name, estimator_orig, kind="ones"):
11081108
else: # pragma: no cover
11091109
raise ValueError
11101110

1111+
# when the estimator has an internal CV scheme
1112+
# we only use weights / repetitions in a specific CV group (here group=0)
1113+
if "cv" in estimator_orig.get_params():
1114+
groups2 = np.hstack(
1115+
[np.full_like(y2, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
1116+
)
1117+
sw2 = np.hstack([sw2, np.ones_like(y1), np.ones_like(y1)])
1118+
X2 = np.vstack([X2, X1, X1])
1119+
y2 = np.hstack([y2, y1, y1])
1120+
splits2 = list(LeaveOneGroupOut().split(X2, groups=groups2))
1121+
estimator2.set_params(cv=splits2)
1122+
1123+
groups1 = np.hstack(
1124+
[np.full_like(y1, 0), np.full_like(y1, 1), np.full_like(y1, 2)]
1125+
)
1126+
X1 = np.vstack([X1, X1, X1])
1127+
y1 = np.hstack([y1, y1, y1])
1128+
splits1 = list(LeaveOneGroupOut().split(X1, groups=groups1))
1129+
estimator1.set_params(cv=splits1)
1130+
11111131
y1 = _enforce_estimator_tags_y(estimator1, y1)
11121132
y2 = _enforce_estimator_tags_y(estimator2, y2)
11131133

0 commit comments

Comments
 (0)
0