[MRG] Support arbitrary init estimators for Gradient Boosting by jmschrei · Pull Request #5221 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] Support arbitrary init estimators for Gradient Boosting #5221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions sklearn/ensemble/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..base import BaseEstimator
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..base import is_classifier
from ..utils import check_random_state, check_array, check_X_y, column_or_1d
from ..utils import check_consistent_length, deprecated
from ..utils.extmath import logsumexp
Expand Down Expand Up @@ -952,9 +953,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
sample_weight = column_or_1d(sample_weight, warn=True)

check_consistent_length(X, y, sample_weight)

y = self._validate_y(y)

random_state = check_random_state(self.random_state)
self._check_params()

Expand All @@ -965,8 +964,41 @@ def fit(self, X, y, sample_weight=None, monitor=None):
# fit initial model - FIXME make sample_weight optional
self.init_.fit(X, y, sample_weight)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is weird: the tests checks with pre-trained models and this does not assume that init_ has been trained. It seems inconsistent.

Also if we assume self.init_ to be pre-trained we should set n_targets = self.init_.classes_.shape[0] for a classifier.


# init predictions
y_pred = self.init_.predict(X)
if is_classifier(self.init_):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trailing space

n_classes = np.unique(y).shape[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the local variable n_classes should be renamed n_targets to make is explicit that this can also work for regression.

else:
n_classes = 1

# If the initialization estimator has a predict_proba method,
# either use those, or collapse to a single vector of the
# predicted log odds in the binary classification case. The
# binomial loss used in binary classification problems expects
# the log odds rather than predicted positive class probability.
if hasattr(self.init_, 'predict_proba'):
eps = np.finfo(X.dtype).eps
y_pred = self.init_.predict_proba(X) + eps
if n_classes == 2:
y_pred = np.log(y_pred[:, 1] / y_pred[:, 0])
y_pred = y_pred.reshape(n_samples, 1)

# Otherwise, it can be a naive estimator defined above, in which
# case don't do anything, or a classifier whose estimates will be
# a vector that should be one hot encoded, or a regressor whose
# estimates still need to be reshaped from (n_samples,) to
# (n_samples, 1)
else:
pred = self.init_.predict(X)

if len(pred.shape) < 2:
if is_classifier(self.init_):
raise ValueError("init model must have a "
"'predict_proba' method if a "
"classifier")
else:
y_pred = pred.reshape(n_samples, 1)
else:
y_pred = pred

begin_at_stage = 0
else:
# add more estimators to fitted model
Expand All @@ -981,6 +1013,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
y_pred = self._decision_function(X)
self._resize_state()

if is_classifier(self.init_):
n_classes = np.unique(y).shape[0]
else:
n_classes = 1

self.n_classes = n_classes

# fit the boosting stages
n_stages = self._fit_stages(X, y, y_pred, sample_weight, random_state,
begin_at_stage, monitor)
Expand Down Expand Up @@ -1077,7 +1116,29 @@ def _init_decision_function(self, X):
if X.shape[1] != self.n_features:
raise ValueError("X.shape[1] should be {0:d}, not {1:d}.".format(
self.n_features, X.shape[1]))
score = self.init_.predict(X).astype(np.float64)
# init predictions

if hasattr(self.init_, 'predict_proba'):
eps = np.finfo(X.dtype).eps
score = self.init_.predict_proba(X) + eps
if self.n_classes == 2:
score = np.log(score[:, 1] / score[:, 0])
score = score.reshape(X.shape[0], 1)
else:
pred = self.init_.predict(X)

if len(pred.shape) < 2:
if is_classifier(self.init_):
raise ValueError("init model must have a "
"'predict_proba' method if a "
"classifier")
else:
score = pred.reshape(X.shape[0], 1)
else:
score = pred

score = score.astype(np.float64)

return score

def _decision_function(self, X):
Expand Down
89 changes: 87 additions & 2 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@

from sklearn import datasets
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.ensemble import ExtraTreesClassifier, ExtraTreesRegressor
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.ensemble.gradient_boosting import ZeroEstimator
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils import check_random_state, tosequence
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_greater_equal
from sklearn.utils.testing import assert_less
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_true
Expand Down Expand Up @@ -1022,3 +1028,82 @@ def test_non_uniform_weights_toy_edge_case_clf():
gb.fit(X, y, sample_weight=sample_weight)
assert_array_equal(gb.predict([[1, 0]]), [1])

def test_classification_w_init():
    # Gradient boosting on top of a pre-trained classifier passed as
    # ``init`` should score at least as well as that classifier alone.
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
                                                        random_state=0)

    for init_clf in [ExtraTreesClassifier(random_state=0, n_estimators=3),
                     RandomForestClassifier(random_state=0, n_estimators=3)]:
        # Baseline: accuracy of the init estimator on its own.
        init_clf.fit(X_train, y_train)
        init_acc = init_clf.score(X_test, y_test)

        # Boost a single stage on top of the pre-trained estimator.
        gb_clf = GradientBoostingClassifier(random_state=0,
                                            n_estimators=1,
                                            init=init_clf)
        gb_clf.fit(X_train, y_train)
        gb_acc = gb_clf.score(X_test, y_test)
        assert_greater_equal(gb_acc, init_acc)

def test_binary_classification_w_init():
    # Same as test_classification_w_init but restricted to two classes,
    # exercising the binomial (log-odds) code path for ``init``.
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    # Keep only classes 0 and 1 to obtain a binary problem.
    X = X[y < 2]
    y = y[y < 2]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
                                                        random_state=0)

    for init_clf in [DecisionTreeClassifier(random_state=0),
                     ExtraTreesClassifier(random_state=0, n_estimators=3),
                     RandomForestClassifier(random_state=0, n_estimators=3)]:
        # Baseline: accuracy of the init estimator on its own.
        init_clf.fit(X_train, y_train)
        init_acc = init_clf.score(X_test, y_test)

        # Boost a single stage on top of the pre-trained estimator.
        gb_clf = GradientBoostingClassifier(random_state=0,
                                            n_estimators=1,
                                            init=init_clf)
        gb_clf.fit(X_train, y_train)
        gb_acc = gb_clf.score(X_test, y_test)
        assert_greater_equal(gb_acc, init_acc)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a similar test for a binary classification problem.


def test_regression_w_init():
    # Gradient boosting on top of a pre-trained regressor passed as
    # ``init`` should score (R^2) at least as well as that regressor alone.
    boston = datasets.load_boston()
    X, y = boston.data, boston.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
                                                        random_state=0)

    for init_reg in [DecisionTreeRegressor(random_state=0),
                     RandomForestRegressor(random_state=0, n_estimators=3),
                     ExtraTreesRegressor(random_state=0, n_estimators=3),
                     SVR(), Ridge()]:
        # Baseline: R^2 of the init estimator on its own.
        init_reg.fit(X_train, y_train)
        init_score = init_reg.score(X_test, y_test)

        # Boost a single stage on top of the pre-trained estimator.
        gb_reg = GradientBoostingRegressor(random_state=0,
                                           n_estimators=1,
                                           init=init_reg)
        gb_reg.fit(X_train, y_train)
        gb_score = gb_reg.score(X_test, y_test)
        assert_greater_equal(gb_score, init_score)

def test_error_on_bad_init():
    # Fitting with an unsupported ``init`` estimator must raise ValueError.
    boston = datasets.load_boston()
    X, y = boston.data, boston.target

    gb = GradientBoostingClassifier(random_state=0, n_estimators=2,
                                    init=SVC())

    # The error surfaces at fit time, not at construction time.
    assert_raises(ValueError, gb.fit, X, y)
0