8000 AdaBoost: allow base_estimator=None (#26242) · scikit-learn/scikit-learn@72a6049 · GitHub
[go: up one dir, main page]

Skip to content

Commit 72a6049

Browse files
authored
AdaBoost: allow base_estimator=None (#26242)
1 parent 28d65c5 commit 72a6049

File tree

4 files changed

+35
-3
lines changed

4 files changed

+35
-3
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,10 @@ Changelog
281281
pandas' conventions.
282282
:pr:`25629` by `Thomas Fan`_.
283283

284+
- |Fix| Fix deprecation of `base_estimator` in :class:`ensemble.AdaBoostClassifier`
285+
and :class:`ensemble.AdaBoostRegressor` that was introduced in :pr:`23819`.
286+
:pr:`26242` by :user:`Marko Toplak <markotoplak>`.
287+
284288
:mod:`sklearn.exception`
285289
........................
286290
- |Feature| Added :class:`exception.InconsistentVersionWarning` which is raised

sklearn/ensemble/_base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,18 @@ def _validate_estimator(self, default=None):
157157

158158
if self.estimator is not None:
159159
self.estimator_ = self.estimator
160-
elif self.base_estimator not in [None, "deprecated"]:
160+
elif self.base_estimator != "deprecated":
161161
warnings.warn(
162162
(
163163
"`base_estimator` was renamed to `estimator` in version 1.2 and "
164164
"will be removed in 1.4."
165165
),
166166
FutureWarning,
167167
)
168-
self.estimator_ = self.base_estimator
168+
if self.base_estimator is not None:
169+
self.estimator_ = self.base_estimator
170+
else:
171+
self.estimator_ = default
169172
else:
170173
self.estimator_ = default
171174

sklearn/ensemble/_weight_boosting.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ class BaseWeightBoosting(BaseEnsemble, metaclass=ABCMeta):
6464
"n_estimators": [Interval(Integral, 1, None, closed="left")],
6565
"learning_rate": [Interval(Real, 0, None, closed="neither")],
6666
"random_state": ["random_state"],
67-
"base_estimator": [HasMethods(["fit", "predict"]), StrOptions({"deprecated"})],
67+
"base_estimator": [
68+
HasMethods(["fit", "predict"]),
69+
StrOptions({"deprecated"}),
70+
None,
71+
],
6872
}
6973

7074
@abstractmethod

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,27 @@ def test_base_estimator_argument_deprecated(AdaBoost, Estimator):
613613
model.fit(X, y)
614614

615615

616+
# TODO(1.4): remove in 1.4
617+
@pytest.mark.parametrize(
618+
"AdaBoost",
619+
[
620+
AdaBoostClassifier,
621+
AdaBoostRegressor,
622+
],
623+
)
624+
def test_base_estimator_argument_deprecated_none(AdaBoost):
625+
X = np.array([[1, 2], [3, 4]])
626+
y = np.array([1, 0])
627+
model = AdaBoost(base_estimator=None)
628+
629+
warn_msg = (
630+
"`base_estimator` was renamed to `estimator` in version 1.2 and "
631+
"will be removed in 1.4."
632+
)
633+
with pytest.warns(FutureWarning, match=warn_msg):
634+
model.fit(X, y)
635+
636+
616637
# TODO(1.4): remove in 1.4
617638
@pytest.mark.parametrize(
618639
"AdaBoost",

0 commit comments

Comments
 (0)
0