8000 Merge pull request #4165 from amueller/gbrt_staged_defensive_copies · ogrisel/scikit-learn@0cf9314 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0cf9314

Browse files
committed
Merge pull request scikit-learn#4165 from amueller/gbrt_staged_defensive_copies
[MRG+1] make defensive copies in GradientBoosting*.staged_decision function.
2 parents f6af488 + 0036dd0 commit 0cf9314

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

sklearn/ensemble/gradient_boosting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1120,7 +1120,7 @@ def staged_decision_function(self, X):
11201120
score = self._init_decision_function(X)
11211121
for i in range(self.estimators_.shape[0]):
11221122
predict_stage(self.estimators_, i, X, self.learning_rate, score)
1123-
yield score
1123+
yield score.copy()
11241124

11251125
@property
11261126
def feature_importances_(self):

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44

55
import numpy as np
6-
import warnings
76

87
from sklearn import datasets
98
from sklearn.base import clone
@@ -12,7 +11,7 @@
1211
from sklearn.ensemble.gradient_boosting import ZeroEstimator
1312
from sklearn.metrics import mean_squared_error
1413
from sklearn.utils import check_random_state, tosequence
15-
from sklearn.utils.testing import assert_almost_equal, clean_warning_registry
14+
from sklearn.utils.testing import assert_almost_equal
1615
from sklearn.utils.testing import assert_array_almost_equal
1716
from sklearn.utils.testing import assert_array_equal
1817
from sklearn.utils.testing import assert_equal
@@ -445,6 +444,24 @@ def test_staged_predict_proba():
445444
assert_array_equal(clf.predict_proba(X_test), staged_proba)
446445

447446

447+
def test_staged_functions_defensive():
448+
# test that staged_functions make defensive copies
449+
rng = np.random.RandomState(0)
450+
X = rng.uniform(size=(10, 3))
451+
y = (4 * X[:, 0]).astype(np.int) + 1 # don't predict zeros
452+
for estimator in [GradientBoostingRegressor(),
453+
GradientBoostingClassifier()]:
454+
estimator.fit(X, y)
455+
for func in ['predict', 'decision_function', 'predict_proba']:
456+
staged_func = getattr(estimator, "staged_" + func, None)
457+
if staged_func is None:
458+
# regressor has no staged_predict_proba
459+
continue
460+
staged_result = list(staged_func(X))
461+
staged_result[1][:] = 0
462+
assert_true(np.all(staged_result[0] != 0))
463+
464+
448465
def test_serialization():
449466
"""Check model serialization."""
450467
clf = GradientBoostingClassifier(n_estimators=100, random_state=1)

0 commit comments

Comments
 (0)
0