|
3 | 3 | """
|
4 | 4 |
|
5 | 5 | import numpy as np
|
6 |
| -import warnings |
7 | 6 |
|
8 | 7 | from sklearn import datasets
|
9 | 8 | from sklearn.base import clone
|
|
12 | 11 | from sklearn.ensemble.gradient_boosting import ZeroEstimator
|
13 | 12 | from sklearn.metrics import mean_squared_error
|
14 | 13 | from sklearn.utils import check_random_state, tosequence
|
15 |
| -from sklearn.utils.testing import assert_almost_equal, clean_warning_registry |
| 14 | +from sklearn.utils.testing import assert_almost_equal |
16 | 15 | from sklearn.utils.testing import assert_array_almost_equal
|
17 | 16 | from sklearn.utils.testing import assert_array_equal
|
18 | 17 | from sklearn.utils.testing import assert_equal
|
@@ -445,6 +444,24 @@ def test_staged_predict_proba():
|
445 | 444 | assert_array_equal(clf.predict_proba(X_test), staged_proba)
|
446 | 445 |
|
447 | 446 |
|
| 447 | +def test_staged_functions_defensive(): |
| 448 | + # test that staged_functions make defensive copies |
| 449 | + rng = np.random.RandomState(0) |
| 450 | + X = rng.uniform(size=(10, 3)) |
| 451 | + y = (4 * X[:, 0]).astype(np.int) + 1 # don't predict zeros |
| 452 | + for estimator in [GradientBoostingRegressor(), |
| 453 | + GradientBoostingClassifier()]: |
| 454 | + estimator.fit(X, y) |
| 455 | + for func in ['predict', 'decision_function', 'predict_proba']: |
| 456 | + staged_func = getattr(estimator, "staged_" + func, None) |
| 457 | + if staged_func is None: |
| 458 | + # regressor has no staged_predict_proba |
| 459 | + continue |
| 460 | + staged_result = list(staged_func(X)) |
| 461 | + staged_result[1][:] = 0 |
| 462 | + assert_true(np.all(staged_result[0] != 0)) |
| 463 | + |
| 464 | + |
448 | 465 | def test_serialization():
|
449 | 466 | """Check model serialization."""
|
450 | 467 | clf = GradientBoostingClassifier(n_estimators=100, random_state=1)
|
|
0 commit comments