8000 [WIP] GradientBoostingClassifierCV without early stopping by raghavrv · Pull Request #8226 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
8000

[WIP] GradientBoostingClassifierCV without early stopping #8226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter
8000

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions sklearn/ensemble/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from .weight_boosting import AdaBoostClassifier
from .weight_boosting import AdaBoostRegressor
from .gradient_boosting import GradientBoostingClassifier
from .gradient_boosting import GradientBoostingClassifierCV
from .gradient_boosting import GradientBoostingRegressor
# from .gradient_boosting import GradientBoostingRegressorCV
from .voting_classifier import VotingClassifier

from . import bagging
Expand All @@ -29,7 +31,7 @@
"RandomTreesEmbedding", "ExtraTreesClassifier",
"ExtraTreesRegressor", "BaggingClassifier",
"BaggingRegressor", "IsolationForest", "GradientBoostingClassifier",
"GradientBoostingRegressor", "AdaBoostClassifier",
"AdaBoostRegressor", "VotingClassifier",
"bagging", "forest", "gradient_boosting",
"partial_dependence", "weight_boosting"]
"GradientBoostingClassifierCV", "GradientBoostingRegressor",
"GradientBoostingRegressorCV", "AdaBoostClassifier",
"AdaBoostRegressor", "VotingClassifier", "bagging", "forest",
"gradient_boosting", "partial_dependence", "weight_boosting"]
316 changes: 315 additions & 1 deletion sklearn/ensemble/gradient_boosting.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

# Authors: Peter Prettenhofer, Scott White, Gilles Louppe, Emanuele Olivetti,
# Arnaud Joly, Jacob Schreiber
# Arnaud Joly, Jacob Schreiber, Vighnesh Birodkar, Raghav RV
# License: BSD 3 clause

from __future__ import print_function
Expand All @@ -26,11 +26,14 @@
from abc import ABCMeta
from abc import abstractmethod

from .base import clone
from .base import BaseEnsemble
from ..base import BaseEstimator
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..externals import six
from ..externals.joblib import Parallel
from ..externals.joblib import delayed

from ._gradient_boosting import predict_stages
from ._gradient_boosting import predict_stage
Expand Down Expand Up @@ -63,6 +66,9 @@
from ..utils.multiclass import check_classification_targets
from ..exceptions import NotFittedError

from ..metrics.scorer import check_scoring
from ..model_selection import check_cv


class QuantileEstimator(BaseEstimator):
"""An estimator predicting the alpha-quantile of the training targets."""
Expand Down Expand Up @@ -1866,3 +1872,311 @@ def apply(self, X):
leaves = super(GradientBoostingRegressor, self).apply(X)
leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
return leaves


class GradientBoostingClassifierCV(GradientBoostingClassifier):
"""Gradient Boosting for classification with CV to find best n_estimators

Gradient Boosting builds an additive model in a forward stage-wise fashion;
It allows for the optimization of arbitrary differentiable loss functions.
In each stage ``n_classes_`` regression trees are fit on the negative
gradient of the binomial or multinomial deviance loss function. Binary
classification is a special case where only a single regression tree is
induced.

This class allows automatically selecting the best number of boosting
stages (``n_estimators``) based on the cross-validated score computed
for the model after each additional boosting stage. In contrast to grid
search, which will retrain all the previous stages for each possible value
of ``n_estimators``, this class continues upon the previous model (with a
lower number of boosting stage) to save computational time.

Read more in the :ref:`User Guide <gradient_boosting>`.

.. versionadded:: 0.19

Parameters
----------
loss : {'deviance', 'exponential'}, optional (default='deviance')
loss function to be optimized. 'deviance' refers to
deviance (= logistic regression) for classification
with probabilistic outputs. For loss 'exponential' gradient
boosting recovers the AdaBoost algorithm.

learning_rate : float, optional (default=0.1)
learning rate shrinks the contribution of each tree by `learning_rate`.
There is a trade-off between learning_rate and n_estimators.

n_estimators_range : int or array-like of shape (n_cv_stages), (default=100)
The range of boosting stages to search through.

If given as an int, the range of values ``[1, n_estimators_range]`` is
searched for the best number of boosting stages.

This parameter can be a list, in which case the different values are
sorted and the stages are incrementally chosen and tested by
cross-validation. The one giving the best prediction score
is used.

max_depth : integer, optional (default=3)
maximum depth of the individual regression estimators. The maximum
depth limits the number of nodes in the tree. Tune this parameter
for best performance; the best value depends on the interaction
of the input variables.

criterion : string, optional (default="friedman_mse")
The function to measure the quality of a split. Supported criteria
are "friedman_mse" for the mean squared error with improvement
score by Friedman, "mse" for mean squared error, and "mae" for
the mean absolute error. The default value of "friedman_mse" is
generally the best as it can provide a better approximation in
some cases.

min_samples_split : int, float, optional (default=2)
The minimum number of samples required to split an internal node:

- If int, then consider `min_samples_split` as the minimum number.
- If float, then `min_samples_split` is a percentage and
`ceil(min_samples_split * n_samples)` are the minimum
number of samples for each split.

min_samples_leaf : int, float, optional (default=1)
The minimum number of samples required to be at a leaf node:

- If int, then consider `min_samples_leaf` as the minimum number.
- If float, then `min_samples_leaf` is a percentage and
`ceil(min_samples_leaf * n_samples)` are the minimum
number of samples for each node.

min_weight_fraction_leaf : float, optional (default=0.)
The minimum weighted fraction of the sum total of weights (of all
the input samples) required to be at a leaf node. Samples have
equal weight when sample_weight is not provided.

subsample : float, optional (default=1.0)
The fraction of samples to be used for fitting the individual base
learners. If smaller than 1.0 this results in Stochastic Gradient
Boosting. `subsample` interacts with the parameter `n_estimators`.
Choosing `subsample < 1.0` leads to a reduction of variance
and an increase in bias.

max_features : int, float, string or None, optional (default=None)
The number of features to consider when looking for the best split:

- If int, then consider `max_features` features at each split.
- If float, then `max_features` is a percentage and
`int(max_features * n_features)` features are considered at each
split.
- If "auto", then `max_features=sqrt(n_features)`.
- If "sqrt", then `max_features=sqrt(n_features)`.
- If "log2", then `max_features=log2(n_features)`.
- If None, then `max_features=n_features`.

Choosing `max_features < n_features` leads to a reduction of variance
and an increase in bias.

Note: the search for a split does not stop until at least one
valid partition of the node samples is found, even if it requires to
effectively inspect more than ``max_features`` features.

max_leaf_nodes : int or None, optional (default=None)
Grow trees with ``max_leaf_nodes`` in best-first fashion.
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

min_impurity_split : float, optional (default=1e-7)
Threshold for early stopping in tree growth. A node will split
if its impurity is above the threshold, otherwise it is a leaf.

init : BaseEstimator, None, optional (default=None)
An estimator object that is used to compute the initial
predictions. ``init`` has to provide ``fit`` and ``predict``.
If None it uses ``loss.init_estimator``.

verbose : int, default: 0
Enable verbose output. If 1 then it prints progress and performance
once in a while (the more trees the lower the frequency). If greater
than 1 then it prints progress and performance for every tree.

random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
If RandomState instance, random_state is the random number generator;
If None, the random number generator is the RandomState instance used
by `np.random`.

presort : bool or 'auto', optional (default='auto')
Whether to presort the data to speed up the finding of best splits in
fitting. Auto mode by default will use presorting on dense data and
default to normal sorting on sparse data. Setting presort to true on
sparse data will raise an error.

n_jobs : integer, optional (default=1)
The number of jobs to run in parallel for the `fit` alone.
If -1, then the number of jobs is set to the number of cores.

NOTE that :class:`GradientBoostingClassifier` does not support
``n_jobs``. This class uses parallelization while cross-validating
across the multiple iterations of train-test splits.

pre_dispatch : int, or string, optional, default "2*n_jobs"
Controls the number of jobs that get dispatched during parallel
execution. Reducing this number can be useful to avoid an
explosion of memory consumption when more jobs get dispatched
than CPUs can process. This parameter can be:

- None, in which case all the jobs are immediately
created and spawned. Use this for lightweight and
fast-running jobs, to avoid delays due to on-demand
spawning of the jobs

- An int, giving the exact number of total jobs that are
spawned

- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'

Attributes
----------
n_estimators_ : int
The number of boosting stages chosen by cross-validation.

n_estimators_range_ : ndarray
Sorted version of the ``n_estimators_range`` parameter.

feature_importances_ : array, shape = [n_features]
The feature importances (the higher, the more important the feature).

oob_improvement_ : array, shape = [n_estimators]
The improvement in loss (= deviance) on the out-of-bag samples
relative to the previous iteration.
``oob_improvement_[0]`` is the improvement in
loss of the first stage over the ``init`` estimator.

train_score_ : array, shape = [n_estimators]
The i-th score ``train_score_[i]`` is the deviance (= loss) of the
model at iteration ``i`` on the in-bag sample.
If ``subsample == 1`` this is the deviance on the training data.

loss_ : LossFunction
The concrete ``LossFunction`` object.

init : BaseEstimator
The estimator that provides the initial predictions.
Set via the ``init`` argument or ``loss.init_estimator``.

estimators_ : ndarray of DecisionTreeRegressor, shape = [n_estimators,
``loss_.K``]
The collection of fitted sub-estimators. ``loss_.K`` is 1 for binary
classification, otherwise n_classes.

See also
--------
GradientBoostingRegressorCV, GradientBoostingClassifier
"""

def __init__(self, n_estimators_range=100, cv=None, scoring=None,
loss='deviance', learning_rate=0.1, subsample=1.0,
criterion='friedman_mse', min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0., max_depth=3,
min_impurity_split=1e-7, init=None, random_state=None,
max_features=None, verbose=0, max_leaf_nodes=None,
presort='auto', n_jobs=None, pre_dispatch='2*n_jobs'):
self.n_estimators_range = n_estimators_range
self.cv = cv
self.scoring = scoring
self.n_jobs = n_jobs
self.pre_dispatch = pre_dispatch

super(GradientBoostingClassifierCV, self).__init__(
loss=loss, learning_rate=learning_rate, n_estimators=1,
criterion=criterion, min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_depth=max_depth, init=init, subsample=subsample,
max_features=max_features,
random_state=random_state, verbose=verbose,
max_leaf_nodes=max_leaf_nodes,
min_impurity_split=min_impurity_split, presort=presort)

def fit(self, X, y, sample_weight=None, sample_groups=None):
"""Find the best n_estimators from the given n_estimators_range and fit

Parameters
----------
X : array-like of shape at least 2D
The data to fit.

y : array-like, optional, default: None
The target variable to try to predict in the case of
supervised learning.

sample_weight : array-like, shape = [n_samples] or None
Sample weights. If None, then samples are equally weighted. Splits
that would create child nodes with net zero or negative weight are
ignored while searching for a split in each node. In the case of
classification, splits are also ignored if they would result in any
single class carrying a negative weight in either child node.

sample_groups : array-like, shape = [n_samples] or None< A3E2 /td>
Sample groups for the cross-validation splitter.
"""
if isinstance(self.n_estimators_range, (numbers.Integral, np.integer)):
n_estimators_range = np.arange(1, self.n_estimators_range + 1)
else:
n_estimators_range = np.array(self.n_estimators_range, dtype=np.int,
copy=False)
if n_estimators_range.ndim != 1:
raise ValueError("The n_estimators_range is expected to be a 1D "
"array of possible n_estimators values. %r "
"is not" % n_estimators_range)

n_estimators_range.sort()
cv = check_cv(self.cv)
scorer = check_scoring(self, scoring=self.scoring)

# We will be parallelizing across the splits
# The same base estimator must be used for each split for all
# boost iterations
base_estimator = GradientBoostingClassifier(
n_estimators=n_estimators_range[0], warm_start=True,
random_state=self.random_state)

parallel = Parallel(n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch)
out = parallel(delayed(_fit_score_all_stages)(clone(base_estimator),
X, y, sample_weight,
train, test, scorer,
n_estimators_range)
for train, test in cv.split(X, y, groups=sample_groups))

# Store the mean cross-val scores across all splits for each stage
self.scores_ = np.asarray(out).mean(axis=0)
# Store the sorted n_estimators_range
self.n_estimators_range_ = n_estimators_range
self.n_estimators_ = n_estimators_range[self.scores_.argmax()]

# Set the final n_estimators based on the best score and do a full fit
# set_params wont work as n_estimators is not a param of GBCV
self.n_estimators = self.n_estimators_
super(GradientBoostingClassifierCV, self).fit(X, y, sample_weight)
return self


def _fit_score_all_stages(estimator, X, y, sample_weight, train, test,
scorer, n_estimators_range):
"""Fit all stages and compute scores after each stage for given cv split"""
X_train, y_train = X[train], y[train]
X_test, y_test = X[test], y[test]

if sample_weight is not None:
weight_train, weight_test = sample_weight[train], sample_weight[test]
else:
weight_train = weight_test = None

all_stage_scores = np.zeros(n_estimators_range.shape, dtype=np.float64)

for i, n_estimators in enumerate(n_estimators_range):
estimator.set_params(n_estimators=n_estimators)
estimator.fit(X_train, y_train, sample_weight=weight_train)
all_stage_scores[i] = scorer(estimator, X_test, y_test)

return all_stage_scores
0