8000 Move oob_score out of forest, add get_score w/ refactoring. · scikit-learn/scikit-learn@221532a · GitHub
[go: up one dir, main page]

Skip to content

Commit 221532a

Browse files
committed
Move oob_score out of forest, add get_score w/ refactoring.
1 parent 9f17373 commit 221532a

File tree

4 files changed

+163
-288
lines changed

4 files changed

+163
-288
lines changed

sklearn/ensemble/bagging.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ..externals.six import with_metaclass
1818
from ..externals.six.moves import zip
1919
from ..metrics import r2_score, accuracy_score
20+
from ..metrics.scorer import check_scoring, get_score
2021
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
2122
from ..utils import check_random_state, check_X_y, check_array, column_or_1d
2223
from ..utils.random import sample_without_replacement
@@ -29,6 +30,42 @@
2930
MAX_INT = np.iinfo(np.int32).max
3031

3132

33+
def oob_score(estimator, X, y, scoring, fit_params=None):
34+
"""Compute an estimator's out of bag score.
35+
36+
Parameters
37+
----------
38+
estimator : estimator object implementing 'fit'
39+
The object to use to fit the data.
40+
41+
X : array-like
42+
The data to fit. Can be, for example a list, or an array at least 2d.
43+
44+
y : array-like
45+
The target variable to try to predict.
46+
47+
scoring : string or callable
48+
A string (see model evaluation documentation) or
49+
a scorer callable object / function with signature
50+
``scorer(estimator, X, y)``.
51+
52+
fit_params : dict, optional
53+
Parameters to pass to the fit method of the estimator.
54+
55+
Returns
56+
-------
57+
score : float
58+
Out of bag score of the estimator.
59+
"""
60+
61+
estimator.oob_predict = True
62+
scorer = check_scoring(estimator, scoring=scoring)
63+
fit_params = fit_params if fit_params is not None else {}
64+
estimator.fit(X, y, **fit_params)
65+
return get_score(scorer, y, estimator.oob_prediction_,
66+
estimator.oob_prediction_proba_, estimator.oob_decision_function_)
67+
68+
3269
def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
3370
seeds, verbose):
3471
"""Private function used to build a batch of estimators within a job."""

0 commit comments

Comments
 (0)
0