8000 MAINT compatibility sklearn 1.4 by glemaitre · Pull Request #1058 · scikit-learn-contrib/imbalanced-learn · GitHub
[go: up one dir, main page]

Skip to content
< 8000 ul class="pagehead-actions flex-shrink-0 d-none d-md-inline" style="padding: 2px 0;">
  • Notifications You must be signed in to change notification settings
  • Fork 1.3k
  • MAINT compatibility sklearn 1.4 #1058

    New issue

    Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

    By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

    Already on GitHub? Sign in to your account

    Merged
    merged 7 commits into from
    Jan 19, 2024
    Merged
    Show file tree
    Hide file tree
    Changes from all commits
    Commits
    File filter

    Filter by extension

    Filter by extension

    Conversations
    Failed to load comments.
    Loading
    Jump to
    Jump to file
    Failed to load files.
    Loading
    Diff view
    Diff view
    10 changes: 5 additions & 5 deletions azure-pipelines.yml
    Original file line number Diff line number Diff line change
    Expand Up @@ -115,10 +115,10 @@ jobs:
    ne(variables['Build.Reason'], 'Schedule')
    )
    matrix:
    py38_conda_forge_openblas_ubuntu_1804:
    py39_conda_forge_openblas_ubuntu_1804:
    DISTRIB: 'conda'
    CONDA_CHANNEL: 'conda-forge'
    PYTHON_VERSION: '3.8'
    PYTHON_VERSION: '3.9'
    BLAS: 'openblas'
    COVERAGE: 'false'

    Expand Down Expand Up @@ -188,7 +188,7 @@ jobs:
    pylatest_conda_tensorflow:
    DISTRIB: 'conda-latest-tensorflow'
    CONDA_CHANNEL: 'conda-forge'
    PYTHON_VERSION: '3.8'
    PYTHON_VERSION: '3.9'
    TEST_DOCS: 'true'
    TEST_DOCSTRINGS: 'true'
    CHECK_WARNINGS: 'true'
    Expand All @@ -214,7 +214,7 @@ jobs:
    pylatest_conda_keras:
    DISTRIB: 'conda-latest-keras'
    CONDA_CHANNEL: 'conda-forge'
    PYTHON_VERSION: '3.8'
    PYTHON_VERSION: '3.9'
    TEST_DOCS: 'true'
    TEST_DOCSTRINGS: 'true'
    CHECK_WARNINGS: 'true'
    Expand Down Expand Up @@ -301,7 +301,7 @@ jobs:
    py38_conda_forge_mkl:
    DISTRIB: 'conda'
    CONDA_CHANNEL: 'conda-forge'
    PYTHON_VERSION: '3.8'
    PYTHON_VERSION: '3.10'
    CHECK_WARNINGS: 'true'
    PYTHON_ARCH: '64'
    PYTEST_VERSION: '*'
    Expand Down
    5 changes: 2 additions & 3 deletions doc/ensemble.rst
    Original file line number Diff line number Diff line change
    Expand Up @@ -33,8 +33,7 @@ data set, this classifier will favor the majority classes::
    >>> from sklearn.ensemble import BaggingClassifier
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    >>> bc = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
    ... random_state=0)
    >>> bc = BaggingClassifier(DecisionTreeClassifier(), random_state=0)
    >>> bc.fit(X_train, y_train) #doctest:
    BaggingClassifier(...)
    >>> y_pred = bc.predict(X_test)
    Expand All @@ -50,7 +49,7 @@ sampling is controlled by the parameter `sampler` or the two parameters
    :class:`~imblearn.under_sampling.RandomUnderSampler`::

    >>> from imblearn.ensemble import BalancedBaggingClassifier
    >>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
    >>> bbc = BalancedBaggingClassifier(DecisionTreeClassifier(),
    ... sampling_strategy='auto',
    ... replacement=False,
    ... random_state=0)
    Expand Down
    4 changes: 4 additions & 0 deletions doc/whats_new/v0.12.rst
    Original file line number Diff line number Diff line change
    Expand Up @@ -23,9 +23,13 @@ Compatibility

    - :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now support missing values
    and monotonic constraints if scikit-learn >= 1.4 is installed.

    - :class:`~imblearn.pipeline.Pipeline` support metadata routing if scikit-learn >= 1.4
    is installed.

    - Compatibility with scikit-learn 1.4.
    :pr:`1058` by :user:`Guillaume Lemaitre <glemaitre>`.

    Deprecations
    ............

    Expand Down
    84 changes: 26 additions & 58 deletions imblearn/ensemble/_bagging.py
    Original file line number Diff line number Diff line change
    Expand Up @@ -5,7 +5,6 @@
    # License: MIT

    import copy
    import inspect
    import numbers
    import warnings

    Expand All @@ -15,6 +14,7 @@
    from sklearn.ensemble import BaggingClassifier
    from sklearn.ensemble._bagging import _parallel_decision_function
    from sklearn.ensemble._base import _partition_estimators
    from sklearn.exceptions import NotFittedError
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.utils import parse_version
    from sklearn.utils.validation import check_is_fitted
    Expand Down Expand Up @@ -121,30 +121,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):

    .. versionadded:: 0.8

    base_estimator : estimator object, default=None
    The base estimator to fit on random subsets of the dataset.
    If None, then the base estimator is a decision tree.

    .. deprecated:: 0.10
    `base_estimator` was renamed to `estimator` in version 0.10 and
    will be removed in 0.12.

    Attributes
    ----------
    estimator_ : estimator
    The base estimator from which the ensemble is grown.

    .. versionadded:: 0.10

    base_estimator_ : estimator
    The base estimator from which the ensemble is grown.

    .. deprecated:: 1.2
    `base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
    removed in 1.4. Use `estimator_` instead. When the minimum version
    of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
    this attribute will be removed.

    n_features_ : int
    The number of features when `fit` is performed.

    Expand Down Expand Up @@ -266,7 +249,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
    """

    # make a deepcopy to not modify the original dictionary
    if sklearn_version >= parse_version("1.3"):
    if sklearn_version >= parse_version("1.4"):
    _parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
    else:
    _parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
    Expand All @@ -283,6 +266,9 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
    "sampler": [HasMethods(["fit_resample"]), None],
    }
    )
    # TODO: remove when minimum supported version of scikit-learn is 1.4
    if "base_estimator" in _parameter_constraints:
    del _parameter_constraints["base_estimator"]

    def __init__(
    self,
    Expand All @@ -301,18 +287,8 @@ def __init__(
    random_state=None,
    verbose=0,
    sampler=None,
    base_estimator="deprecated",
    ):
    # TODO: remove when supporting scikit-learn>=1.2
    bagging_classifier_signature = inspect.signature(super().__init__)
    estimator_params = {"base_estimator": base_estimator}
    if "estimator" in bagging_classifier_signature.parameters:
    estimator_params["estimator"] = estimator
    else:
    self.estimator = estimator

    super().__init__(
    **estimator_params,
    n_estimators=n_estimators,
    max_samples=max_samples,
    max_features=max_features,
    Expand All @@ -324,6 +300,7 @@ def __init__(
    random_state=random_state,
    verbose=verbose,
    )
    self.estimator = estimator
    self.sampling_strategy = sampling_strategy
    self.replacement = replacement
    self.sampler = sampler
    Expand All @@ -349,42 +326,17 @@ def _validate_y(self, y):
    def _validate_estimator(self, default=DecisionTreeClassifier()):
    """Check the estimator and the n_estimator attribute, set the
    `estimator_` attribute."""
    if self.estimator is not None and (
    self.base_estimator not in [None, "deprecated"]
    ):
    raise ValueError(
    "Both `estimator` and `base_estimator` were set. Only set `estimator`."
    )

    if self.estimator is not None:
    base_estimator = clone(self.estimator)
    elif self.base_estimator not in [None, "deprecated"]:
    warnings.warn(
    "`base_estimator` was renamed to `estimator` in version 0.10 and "
    "will be removed in 0.12.",
    FutureWarning,
    )
    base_estimator = clone(self.base_estimator)
    estimator = clone(self.estimator)
    else:
    base_estimator = clone(default)
    estimator = clone(default)

    if self.sampler_._sampling_type != "bypass":
    self.sampler_.set_params(sampling_strategy=self._sampling_strategy)

    self._estimator = Pipeline(
    [("sampler", self.sampler_), ("classifier", base_estimator)]
    self.estimator_ = Pipeline(
    [("sampler", self.sampler_), ("classifier", estimator)]
    )
    try:
    # scikit-learn < 1.2
    self.base_estimator_ = self._estimator
    except AttributeError:
    pass

    # TODO: remove when supporting scikit-learn>=1.4
    @property
    def estimator_(self):
    """Estimator used to grow the ensemble."""
    return self._estimator

    # TODO: remove when supporting scikit-learn>=1.2
    @property
    Expand Down Expand Up @@ -483,6 +435,22 @@ def decision_function(self, X):

    return decisions

    @property
    def base_estimator_(self):
    """Attribute for older sklearn version compatibility."""
    error = AttributeError(
    f"{self.__class__.__name__} object has no attribute 'base_estimator_'."
    )
    if sklearn_version < parse_version("1.2"):
    # The base class require to have the attribute defined. For scikit-learn
    # > 1.2, we are going to raise an error.
    try:
    check_is_fitted(self)
    return self.estimator_
    except NotFittedError:
    raise error
    raise error

    def _more_tags(self):
    tags = super()._more_tags()
    tags_key = "_xfail_checks"
    Expand Down
    Loading
    0