From 64a365127020580718e2a54436cec6b8d7fe265d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 Sep 2021 17:15:29 +0200 Subject: [PATCH 1/6] TST add common tests for meta-estimators --- sklearn/tests/test_metaestimators.py | 61 +++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index e743741f6fa43..b99d78f39a237 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from sklearn.base import BaseEstimator +from sklearn.base import clone, BaseEstimator from sklearn.base import is_regressor from sklearn.datasets import make_classification from sklearn.utils import all_estimators @@ -21,7 +21,7 @@ from sklearn.exceptions import NotFittedError from sklearn.semi_supervised import SelfTrainingClassifier from sklearn.linear_model import Ridge, LogisticRegression -from sklearn.preprocessing import StandardScaler, MaxAbsScaler +from sklearn.preprocessing import StandardScaler, MaxAbsScaler, FunctionTransformer class DelegatorData: @@ -185,21 +185,35 @@ def score(self, X, y, *args, **kwargs): ) -def _generate_meta_estimator_instances_with_pipeline(): - """Generate instances of meta-estimators fed with a pipeline +def _generate_meta_estimator_instances_with_pipeline(first_step=None): + """Generate instances of meta-estimators fed with a pipeline. Are considered meta-estimators all estimators accepting one of "estimator", "base_estimator" or "estimators". + + Parameters + ---------- + first_step : None or estimator instance, default=None + The first step of the pipeline to pass to the meta-estimator. + If `None`, a `TfidfVectorizer` is used. + + Yields + ------ + estimator : estimator instance + A meta-estimator instance. 
+ """ + if first_step is None: + first_step = TfidfVectorizer() for _, Estimator in sorted(all_estimators()): sig = set(signature(Estimator).parameters) if "estimator" in sig or "base_estimator" in sig or "regressor" in sig: if is_regressor(Estimator): - estimator = make_pipeline(TfidfVectorizer(), Ridge()) + estimator = make_pipeline(first_step, Ridge()) param_grid = {"ridge__alpha": [0.1, 1.0]} else: - estimator = make_pipeline(TfidfVectorizer(), LogisticRegression()) + estimator = make_pipeline(first_step, LogisticRegression()) param_grid = {"logisticregression__C": [0.1, 1.0]} if "param_grid" in sig or "param_distributions" in sig: @@ -212,10 +226,10 @@ def _generate_meta_estimator_instances_with_pipeline(): elif "transformer_list" in sig: # FeatureUnion transformer_list = [ - ("trans1", make_pipeline(TfidfVectorizer(), MaxAbsScaler())), + ("trans1", make_pipeline(first_step, MaxAbsScaler())), ( "trans2", - make_pipeline(TfidfVectorizer(), StandardScaler(with_mean=False)), + make_pipeline(first_step, StandardScaler(with_mean=False)), ), ] yield Estimator(transformer_list) @@ -224,8 +238,8 @@ def _generate_meta_estimator_instances_with_pipeline(): # stacking, voting if is_regressor(Estimator): estimator = [ - ("est1", make_pipeline(TfidfVectorizer(), Ridge(alpha=0.1))), - ("est2", make_pipeline(TfidfVectorizer(), Ridge(alpha=1))), + ("est1", make_pipeline(first_step, Ridge(alpha=0.1))), + ("est2", make_pipeline(first_step, Ridge(alpha=1))), ] else: estimator = [ @@ -299,3 +313,30 @@ def test_meta_estimators_delegate_data_validation(estimator): # n_features_in_ should not be defined since data is not tabular data. 
assert not hasattr(estimator, "n_features_in_") + + +@pytest.mark.parametrize( + "estimator_orig", + [ + est + for est in _generate_meta_estimator_instances_with_pipeline( + first_step=FunctionTransformer() + ) + ], + ids=_get_meta_estimator_id, +) +def test_meta_estimators_sample_weight_with_pipeline(estimator_orig): + estimator = clone(estimator_orig) + rng = np.random.RandomState(0) + set_random_state(estimator) + + n_samples = 100 + X = rng.randn(n_samples, 10) + y = rng.randint(3, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + + sample_weight = np.ones(y.shape[0]) + sample_weight[0] = 10.0 + + with pytest.raises((ValueError, TypeError), match="sample.*weight"): + estimator.fit(X, y, sample_weight=sample_weight) From 80cd0aa8cd854fbe1d042a5e5b3c5f15d3adec15 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 Sep 2021 17:36:43 +0200 Subject: [PATCH 2/6] iter --- doc/whats_new/v1.0.rst | 18 ++++++++++++++++++ sklearn/calibration.py | 32 +++++++++++++------------------- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d8776653cd9e8..4053eedb7e8d0 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -2,6 +2,24 @@ .. currentmodule:: sklearn +.. _changes_1_0_1: + +Version 1.0.1 +============= + +**In Development** + +Changelog +--------- + +:mod:`sklearn.calibration` +.......................... + +- |Fix| Raise an error in :class:`calibration.CalibratedClassifierCV` when + `sample_weight` are used together with a :class:`pipeline.Pipeline` instead + of silently ignoring `sample_weight`. + :pr:`21143` by :user:`Guillaume Lemaitre `. + .. 
_changes_1_0: Version 1.0.0 diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 9ede41a775c3e..5666c4a691c68 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -7,7 +7,6 @@ # # License: BSD 3 clause -import warnings from inspect import signature from functools import partial @@ -37,7 +36,10 @@ from .utils.multiclass import check_classification_targets from .utils.fixes import delayed -from .utils.validation import check_is_fitted, check_consistent_length +from .utils.validation import ( + check_is_fitted, + check_consistent_length, +) from .utils.validation import _check_sample_weight, _num_samples from .utils import _safe_indexing from .isotonic import IsotonicRegression @@ -302,16 +304,12 @@ def fit(self, X, y, sample_weight=None): # sample_weight checks fit_parameters = signature(base_estimator.fit).parameters - supports_sw = "sample_weight" in fit_parameters + if "sample_weight" not in fit_parameters and sample_weight is not None: + raise ValueError( + f"The estimator {base_estimator} does not support sample_weight" + ) if sample_weight is not None: - sample_weight = _check_sample_weight(sample_weight, X) - if not supports_sw: - estimator_name = type(base_estimator).__name__ - warnings.warn( - f"Since {estimator_name} does not support " - "sample_weights, sample weights will only be" - " used for the calibration itself." 
- ) + _check_sample_weight(sample_weight, X) # Check that each cross-validation fold can have at least one # example per class @@ -343,7 +341,6 @@ def fit(self, X, y, sample_weight=None): test=test, method=self.method, classes=self.classes_, - supports_sw=supports_sw, sample_weight=sample_weight, ) for train, test in cv.split(X, y) @@ -364,7 +361,7 @@ def fit(self, X, y, sample_weight=None): pred_method, method_name, X, n_classes ) - if sample_weight is not None and supports_sw: + if sample_weight is not None: this_estimator.fit(X, y, sample_weight) else: this_estimator.fit(X, y) @@ -443,7 +440,7 @@ def _more_tags(self): def _fit_classifier_calibrator_pair( - estimator, X, y, train, test, supports_sw, method, classes, sample_weight=None + estimator, X, y, train, test, method, classes, sample_weight=None ): """Fit a classifier/calibration pair on a given train/test split. @@ -468,9 +465,6 @@ def _fit_classifier_calibrator_pair( test : ndarray, shape (n_test_indicies,) Indices of the testing subset. - supports_sw : bool - Whether or not the `estimator` supports sample weights. - method : {'sigmoid', 'isotonic'} Method to use for calibration. 
@@ -486,14 +480,14 @@ def _fit_classifier_calibrator_pair( """ X_train, y_train = _safe_indexing(X, train), _safe_indexing(y, train) X_test, y_test = _safe_indexing(X, test), _safe_indexing(y, test) - if supports_sw and sample_weight is not None: + if sample_weight is not None: sw_train = _safe_indexing(sample_weight, train) sw_test = _safe_indexing(sample_weight, test) else: sw_train = None sw_test = None - if supports_sw: + if sample_weight is not None: estimator.fit(X_train, y_train, sample_weight=sw_train) else: estimator.fit(X_train, y_train) From cce6f9a32b6cb34d0c09e23533fac8a0c2c59b32 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 Sep 2021 17:42:15 +0200 Subject: [PATCH 3/6] iter --- sklearn/tests/test_metaestimators.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index b99d78f39a237..21f05e02729c1 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -315,6 +315,13 @@ def test_meta_estimators_delegate_data_validation(estimator): assert not hasattr(estimator, "n_features_in_") +METAESTIMATORS_ACCEPTING_SAMPLE_WEIGHTS_IN_PIPELINE = [ + # AdaBoostRegressor can safely accepts `sample_weight` even with a `Pipeline` + # because the weights are not used when calling `pipeline.fit`. 
+ "AdaBoostRegressor", +] + + @pytest.mark.parametrize( "estimator_orig", [ @@ -322,10 +329,13 @@ def test_meta_estimators_delegate_data_validation(estimator): for est in _generate_meta_estimator_instances_with_pipeline( first_step=FunctionTransformer() ) + if est.__class__.__name__ + not in METAESTIMATORS_ACCEPTING_SAMPLE_WEIGHTS_IN_PIPELINE ], ids=_get_meta_estimator_id, ) def test_meta_estimators_sample_weight_with_pipeline(estimator_orig): + """Check that passing a `Pipeline` with `sample_weight` raises an error.""" estimator = clone(estimator_orig) rng = np.random.RandomState(0) set_random_state(estimator) From 4fb184c492324336ff613182baaca7a4fce43d58 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 27 Sep 2021 14:52:15 +0200 Subject: [PATCH 4/6] iter --- sklearn/calibration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 5666c4a691c68..0c1efddc66095 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -309,7 +309,7 @@ def fit(self, X, y, sample_weight=None): f"The estimator {base_estimator} does not support sample_weight" ) if sample_weight is not None: - _check_sample_weight(sample_weight, X) + sample_weight = _check_sample_weight(sample_weight, X) # Check that each cross-validation fold can have at least one # example per class From 000c3f85ebe64ce6b8e493bdd027ed88c140dadd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 27 Sep 2021 14:54:19 +0200 Subject: [PATCH 5/6] DOC add more details for future behaviour --- sklearn/tests/test_metaestimators.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 21f05e02729c1..684870957832b 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -316,8 +316,8 @@ def test_meta_estimators_delegate_data_validation(estimator): 
METAESTIMATORS_ACCEPTING_SAMPLE_WEIGHTS_IN_PIPELINE = [ - # AdaBoostRegressor can safely accepts `sample_weight` even with a `Pipeline` - # because the weights are not used when calling `pipeline.fit`. + # AdaBoostRegressor can safely accept `sample_weight` even with a + # `Pipeline` because the weights are not used when calling `pipeline.fit`. "AdaBoostRegressor", ] @@ -335,7 +335,11 @@ def test_meta_estimators_delegate_data_validation(estimator): ids=_get_meta_estimator_id, ) def test_meta_estimators_sample_weight_with_pipeline(estimator_orig): - """Check that passing a `Pipeline` with `sample_weight` raises an error.""" + """Check that passing a `Pipeline` with `sample_weight` raises an error. + + FIXME: in the future, `Pipeline` should be able to delegate `sample_weight` + to the inner estimator(s). An error should not be raised then. + """ estimator = clone(estimator_orig) rng = np.random.RandomState(0) set_random_state(estimator) From 54ae36cfde1a20dae7b79c50521d1559a3bd9ea2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 27 Sep 2021 15:07:14 +0200 Subject: [PATCH 6/6] iter --- doc/whats_new/v1.0.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 4053eedb7e8d0..884a493b1361f 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -15,9 +15,9 @@ Changelog :mod:`sklearn.calibration` .......................... -- |Fix| Raise an error in :class:`calibration.CalibratedClassifierCV` when - `sample_weight` are used together with a :class:`pipeline.Pipeline` instead - of silently ignoring `sample_weight`. +- |Fix| Raise an error instead of silently ignoring `sample_weight` in + :class:`calibration.CalibratedClassifierCV` when the underlying estimator + does not support it. :pr:`21143` by :user:`Guillaume Lemaitre `. .. _changes_1_0: