8000 FIX bagging with metadata routing and estimator implement __len__ (#2… · patrickkwang/scikit-learn@cfd8091 · GitHub
[go: up one dir, main page]

Skip to content

Commit cfd8091

Browse files
glemaitreadam2392
andauthored
FIX bagging with metadata routing and estimator implement __len__ (scikit-learn#28734)
Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 0e19a48 commit cfd8091

File tree

2 files changed

+55
-22
lines changed

2 files changed

+55
-22
lines changed

sklearn/ensemble/_bagging.py

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,6 @@ def _parallel_build_estimators(
113113
estimators = []
114114
estimators_features = []
115115

116-
request_or_router = get_routing_for_object(ensemble.estimator_)
117-
118116
# TODO: (slep6) remove if condition for unrouted sample_weight when metadata
119117
# routing can't be disabled.
120118
support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
@@ -164,9 +162,14 @@ def _parallel_build_estimators(
164162
# Note: Row sampling can be achieved either through setting sample_weight or
165163
# by indexing. The former is more efficient. Therefore, use this method
166164
# if possible, otherwise use indexing.
167-
if (
168-
_routing_enabled() and request_or_router.consumes("fit", ("sample_weight",))
169-
) or (not _routing_enabled() and support_sample_weight):
165+
if _routing_enabled():
166+
request_or_router = get_routing_for_object(ensemble.estimator_)
167+
consumes_sample_weight = request_or_router.consumes(
168+
"fit", ("sample_weight",)
169+
)
170+
else:
171+
consumes_sample_weight = support_sample_weight
172+
if consumes_sample_weight:
170173
# Draw sub samples, using sample weights, and then fit
171174
curr_sample_weight = _check_sample_weight(
172175
fit_params_.pop("sample_weight", None), X
@@ -635,6 +638,9 @@ def get_metadata_routing(self):
635638
def _get_estimator(self):
636639
"""Resolve which estimator to return."""
637640

641+
def _more_tags(self):
642+
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
643+
638644

639645
class BaggingClassifier(ClassifierMixin, BaseBagging):
640646
"""A Bagging classifier.
@@ -835,7 +841,9 @@ def __init__(
835841

836842
def _get_estimator(self):
837843
"""Resolve which estimator to return (default is DecisionTreeClassifier)"""
838-
return self.estimator or DecisionTreeClassifier()
844+
if self.estimator is None:
845+
return DecisionTreeClassifier()
846+
return self.estimator
839847

840848
def _set_oob_score(self, X, y):
841849
n_samples = y.shape[0]
@@ -1059,14 +1067,6 @@ def decision_function(self, X):
10591067

10601068
return decisions
10611069

1062-
def _more_tags(self):
1063-
if self.estimator is None:
1064-
estimator = DecisionTreeClassifier()
1065-
else:
1066-
estimator = self.estimator
1067-
1068-
return {"allow_nan": _safe_tags(estimator, "allow_nan")}
1069-
10701070

10711071
class BaggingRegressor(RegressorMixin, BaseBagging):
10721072
"""A Bagging regressor.
@@ -1328,13 +1328,8 @@ def _set_oob_score(self, X, y):
13281328
self.oob_prediction_ = predictions
13291329
self.oob_score_ = r2_score(y, predictions)
13301330

1331-
def _more_tags(self):
1332-
if self.estimator is None:
1333-
estimator = DecisionTreeRegressor()
1334-
else:
1335-
estimator = self.estimator
1336-
return {"allow_nan": _safe_tags(estimator, "allow_nan")}
1337-
13381331
def _get_estimator(self):
13391332
"""Resolve which estimator to return (default is DecisionTreeClassifier)"""
1340-
return self.estimator or DecisionTreeRegressor()
1333+
if self.estimator is None:
1334+
return DecisionTreeRegressor()
1335+
return self.estimator

sklearn/ensemble/tests/test_bagging.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,19 @@
1010
import numpy as np
1111
import pytest
1212

13+
import sklearn
1314
from sklearn.base import BaseEstimator
1415
from sklearn.datasets import load_diabetes, load_iris, make_hastie_10_2
1516
from sklearn.dummy import DummyClassifier, DummyRegressor
1617
from sklearn.ensemble import (
18+
AdaBoostClassifier,
19+
AdaBoostRegressor,
1720
BaggingClassifier,
1821
BaggingRegressor,
1922
HistGradientBoostingClassifier,
2023
HistGradientBoostingRegressor,
24+
RandomForestClassifier,
25+
RandomForestRegressor,
2126
)
2227
from sklearn.feature_selection import SelectKBest
2328
from sklearn.linear_model import LogisticRegression, Perceptron
@@ -936,3 +941,36 @@ def fit(self, X, y):
936941
def test_bagging_allow_nan_tag(bagging, expected_allow_nan):
937942
"""Check that bagging inherits allow_nan tag."""
938943
assert bagging._get_tags()["allow_nan"] == expected_allow_nan
944+
945+
946+
@pytest.mark.parametrize(
947+
"model",
948+
[
949+
BaggingClassifier(
950+
estimator=RandomForestClassifier(n_estimators=1), n_estimators=1
951+
),
952+
BaggingRegressor(
953+
estimator=RandomForestRegressor(n_estimators=1), n_estimators=1
954+
),
955+
],
956+
)
957+
def test_bagging_with_metadata_routing(model):
958+
"""Make sure that metadata routing works with non-default estimator."""
959+
with sklearn.config_context(enable_metadata_routing=True):
960+
model.fit(iris.data, iris.target)
961+
962+
963+
@pytest.mark.parametrize(
964+
"model",
965+
[
966+
BaggingClassifier(
967+
estimator=AdaBoostClassifier(n_estimators=1, algorithm="SAMME"),
968+
n_estimators=1,
969+
),
970+
BaggingRegressor(estimator=AdaBoostRegressor(n_estimators=1), n_estimators=1),
971+
],
972+
)
973+
def test_bagging_without_support_metadata_routing(model):
974+
"""Make sure that we still can use an estimator that does not implement the
975+
metadata routing."""
976+
model.fit(iris.data, iris.target)

0 commit comments

Comments
 (0)
0