FIX GBDT init parameter when it's a pipeline (#13472) · scikit-learn/scikit-learn@d6b368e · GitHub
[go: up one dir, main page]

Skip to content

Commit d6b368e

Browse files
NicolasHug and jnothman
authored and committed
FIX GBDT init parameter when it's a pipeline (#13472)
1 parent 49cdee6 commit d6b368e

File tree

2 files changed

+45
-10
lines changed

2 files changed

+45
-10
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,20 +1478,25 @@ def fit(self, X, y, sample_weight=None, monitor=None):
14781478
raw_predictions = np.zeros(shape=(X.shape[0], self.loss_.K),
14791479
dtype=np.float64)
14801480
else:
1481-
try:
1482-
self.init_.fit(X, y, sample_weight=sample_weight)
1483-
except TypeError:
1484-
if sample_weight_is_none:
1485-
self.init_.fit(X, y)
1486-
else:
1487-
raise ValueError(
1488-
"The initial estimator {} does not support sample "
1489-
"weights.".format(self.init_.__class__.__name__))
1481+
# XXX clean this once we have a support_sample_weight tag
1482+
if sample_weight_is_none:
1483+
self.init_.fit(X, y)
1484+
else:
1485+
msg = ("The initial estimator {} does not support sample "
1486+
"weights.".format(self.init_.__class__.__name__))
1487+
try:
1488+
self.init_.fit(X, y, sample_weight=sample_weight)
1489+
except TypeError: # regular estimator without SW support
1490+
raise ValueError(msg)
1491+
except ValueError as e:
1492+
if 'not enough values to unpack' in str(e): # pipeline
1493+
raise ValueError(msg) from e
1494+
else: # regular estimator whose input checking failed
1495+
raise
14901496

14911497
raw_predictions = \
14921498
self.loss_.get_init_raw_predictions(X, self.init_)
14931499

1494-
14951500
begin_at_stage = 0
14961501

14971502
# The rng state must be preserved if warm_start is True

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
from sklearn.exceptions import DataConversionWarning
4040
from sklearn.exceptions import NotFittedError
4141
from sklearn.dummy import DummyClassifier, DummyRegressor
42+
from sklearn.pipeline import make_pipeline
43+
from sklearn.linear_model import LinearRegression
44+
from sklearn.svm import NuSVR
4245

4346

4447
GRADIENT_BOOSTING_ESTIMATORS = [GradientBoostingClassifier,
@@ -1378,6 +1381,33 @@ def test_gradient_boosting_with_init(gb, dataset_maker, init_estimator):
13781381
gb(init=init_est).fit(X, y, sample_weight=sample_weight)
13791382

13801383

1384+
def test_gradient_boosting_with_init_pipeline():
1385+
# Check that the init estimator can be a pipeline (see issue #13466)
1386+
1387+
X, y = make_regression(random_state=0)
1388+
init = make_pipeline(LinearRegression())
1389+
gb = GradientBoostingRegressor(init=init)
1390+
gb.fit(X, y) # pipeline without sample_weight works fine
1391+
1392+
with pytest.raises(
1393+
ValueError,
1394+
match='The initial estimator Pipeline does not support sample '
1395+
'weights'):
1396+
gb.fit(X, y, sample_weight=np.ones(X.shape[0]))
1397+
1398+
# Passing sample_weight to a pipeline raises a ValueError. This test makes
1399+
# sure we make the distinction between ValueError raised by a pipeline that
1400+
# was passed sample_weight, and a ValueError raised by a regular estimator
1401+
# whose input checking failed.
1402+
with pytest.raises(
1403+
ValueError,
1404+
match='nu <= 0 or nu > 1'):
1405+
# Note that NuSVR properly supports sample_weight
1406+
init = NuSVR(gamma='auto', nu=1.5)
1407+
gb = GradientBoostingRegressor(init=init)
1408+
gb.fit(X, y, sample_weight=np.ones(X.shape[0]))
1409+
1410+
13811411
@pytest.mark.parametrize('estimator, missing_method', [
13821412
(GradientBoostingClassifier(init=LinearSVC()), 'predict_proba'),
13831413
(GradientBoostingRegressor(init=OneHotEncoder()), 'predict')

0 commit comments

Comments
 (0)
0