8000 Make BaggingClassifier use if_delegate_has_method in decision_function · scikit-learn/scikit-learn@8f2780f · GitHub
[go: up one dir, main page]

Skip to content

Commit 8f2780f

Browse files
amuellerogrisel
authored andcommitted
Make BaggingClassifier use if_delegate_has_method in decision_function
1 parent feceab6 commit 8f2780f

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

sklearn/ensemble/bagging.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..utils.random import sample_without_replacement
2222
from ..utils.validation import has_fit_parameter, check_is_fitted
2323
from ..utils.fixes import bincount
24+
from ..utils.metaestimators import if_delegate_has_method
2425

2526
from .base import BaseEnsemble, _partition_estimators
2627

@@ -52,7 +53,6 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
5253
support_sample_weight = has_fit_parameter(ensemble.base_estimator_,
5354
"sample_weight")
5455

55-
5656
# Build estimators
5757
estimators = []
5858
estimators_samples = []
@@ -626,6 +626,7 @@ def predict_log_proba(self, X):
626626
else:
627627
return np.log(self.predict_proba(X))
628628

629+
@if_delegate_has_method(delegate='base_estimator')
629630
def decision_function(self, X):
630631
"""Average of the decision functions of the base classifiers.
631632
@@ -645,9 +646,6 @@ def decision_function(self, X):
645646
646647
"""
647648
check_is_fitted(self, "classes_")
648-
# Trigger an exception if not supported
649-
if not hasattr(self.base_estimator_, "decision_function"):
650-
raise NotImplementedError
651649

652650
# Check data
653651
X = check_array(X)

sklearn/ensemble/tests/test_bagging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import assert_greater
1515
from sklearn.utils.testing import assert_less
1616
from sklearn.utils.testing import assert_true
17+
from sklearn.utils.testing import assert_false
1718
from sklearn.utils.testing import assert_warns
1819

1920
from sklearn.dummy import DummyClassifier, DummyRegressor
@@ -403,8 +404,7 @@ def test_error():
403404
BaggingClassifier(base, max_features="foobar").fit, X, y)
404405

405406
# Test support of decision_function
406-
assert_raises(NotImplementedError,
407-
BaggingClassifier(base).fit(X, y).decision_function, X)
407+
assert_false(hasattr(BaggingClassifier(base).fit(X, y), 'decision_function'))
408408

409409

410410
def test_parallel_classification():

sklearn/tests/test_metaestimators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sklearn.pipeline import Pipeline
1212
from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
1313
from sklearn.feature_selection import RFE, RFECV
14+
from sklearn.ensemble import BaggingClassifier
1415

1516

1617
class DelegatorData(object):
@@ -36,6 +37,9 @@ def __init__(self, name, construct, skip_methods=(),
3637
skip_methods=['transform', 'inverse_transform', 'score']),
3738
DelegatorData('RFECV', RFECV,
3839
skip_methods=['transform', 'inverse_transform', 'score']),
40+
DelegatorData('BaggingClassifier', BaggingClassifier,
41+
skip_methods=['transform', 'inverse_transform', 'score',
42+
'predict_proba', 'predict_log_proba', 'predict'])
3943
]
4044

4145

0 commit comments

Comments
 (0)
0