ENH add subsample to HGBT by lorentzenchr · Pull Request #28063 · scikit-learn/scikit-learn · GitHub

ENH add subsample to HGBT #28063


Open
wants to merge 10 commits into base: main
106 changes: 58 additions & 48 deletions examples/ensemble/plot_gradient_boosting_regularization.py
@@ -4,9 +4,10 @@
================================

Illustration of the effect of different regularization strategies
for Gradient Boosting. The example is taken from Hastie et al 2009 [1]_.
for Gradient Boosting. The example is taken from Chapter 10.12 of
Hastie et al 2009 [1]_.

The loss function used is binomial deviance. Regularization via
The loss function used is log loss, aka binomial deviance. Regularization via
shrinkage (``learning_rate < 1.0``) improves performance considerably.
In combination with shrinkage, stochastic gradient boosting
(``subsample < 1.0``) can produce more accurate models by reducing the
@@ -21,14 +22,13 @@

"""

# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
#
# License: BSD 3 clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets, ensemble
from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

@@ -44,48 +44,58 @@
"max_leaf_nodes": 4,
"max_depth": None,
"random_state": 2,
"min_samples_split": 5,
"min_samples_leaf": 2,
}

plt.figure()

for label, color, setting in [
("No shrinkage", "orange", {"learning_rate": 1.0, "subsample": 1.0}),
("learning_rate=0.2", "turquoise", {"learning_rate": 0.2, "subsample": 1.0}),
("subsample=0.5", "blue", {"learning_rate": 1.0, "subsample": 0.5}),
(
"learning_rate=0.2, subsample=0.5",
"gray",
{"learning_rate": 0.2, "subsample": 0.5},
),
(
"learning_rate=0.2, max_features=2",
"magenta",
{"learning_rate": 0.2, "max_features": 2},
),
]:
params = dict(original_params)
params.update(setting)

clf = ensemble.GradientBoostingClassifier(**params)
clf.fit(X_train, y_train)

# compute test set deviance
test_deviance = np.zeros((params["n_estimators"],), dtype=np.float64)

for i, y_proba in enumerate(clf.staged_predict_proba(X_test)):
test_deviance[i] = 2 * log_loss(y_test, y_proba[:, 1])

plt.plot(
(np.arange(test_deviance.shape[0]) + 1)[::5],
test_deviance[::5],
"-",
color=color,
label=label,
)

plt.legend(loc="upper right")
plt.xlabel("Boosting Iterations")
plt.ylabel("Test Set Deviance")

plt.show()
fig, axes = plt.subplots(ncols=2, figsize=(10, 5), sharex=True, sharey=True)

for j, model_class in enumerate(
[GradientBoostingClassifier, HistGradientBoostingClassifier]
):
for label, color, setting in [
("No shrinkage", "orange", {"learning_rate": 1.0, "subsample": 1.0}),
("learning_rate=0.2", "turquoise", {"learning_rate": 0.2, "subsample": 1.0}),
("subsample=0.5", "blue", {"learning_rate": 1.0, "subsample": 0.5}),
(
"learning_rate=0.2, subsample=0.5",
"gray",
{"learning_rate": 0.2, "subsample": 0.5},
),
(
"learning_rate=0.2, max_features=2",
"magenta",
{"learning_rate": 0.2, "max_features": 2},
),
]:
params = dict(original_params)
params.update(setting)
n_iter = params["n_estimators"]
if model_class == HistGradientBoostingClassifier:
params["max_iter"] = params.pop("n_estimators")
if "max_features" in params:
params["max_features"] = float(
params["max_features"] / X_train.shape[1]
)

clf = model_class(**params)
clf.fit(X_train, y_train)

# compute test set log loss
test_loss = np.zeros((n_iter,), dtype=np.float64)

for i, y_proba in enumerate(clf.staged_predict_proba(X_test)):
test_loss[i] = 2 * log_loss(y_test, y_proba[:, 1])

axes[j].plot(
(np.arange(test_loss.shape[0]) + 1)[::5],
test_loss[::5],
"-",
color=color,
label=label,
)

axes[j].set_ylim(None, 2)
axes[j].legend(loc="upper right")
axes[j].set_xlabel("Boosting Iterations")
axes[j].set_ylabel("Test Set Log Loss")
axes[j].set_title(model_class.__name__)
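The updated example body above relies on X_train, X_test, y_train and y_test defined outside the changed hunks. For readers who want to run the two-panel comparison standalone, here is a minimal sketch of the elided data setup, assuming the Hastie et al. 2009 synthetic binary classification data that the example's imports point to; the sample count and split fraction are illustrative assumptions, not values read from the diff:

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

# 10-feature synthetic problem from Hastie et al. 2009 with labels in {-1, +1};
# n_samples=4000 is an assumed size, not taken from the diff.
X, y = datasets.make_hastie_10_2(n_samples=4000, random_state=1)
X = X.astype(np.float32)

# Keep a large test set so the staged log-loss curves are stable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.8, random_state=0
)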
21 changes: 16 additions & 5 deletions sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx
@@ -5,12 +5,14 @@ import numpy as np

from .common import Y_DTYPE
from .common cimport Y_DTYPE_C
from ...utils._typedefs cimport int64_t


def _update_raw_predictions(
Y_DTYPE_C [::1] raw_predictions, # OUT
grower,
n_threads,
sample_idx,
):
"""Update raw_predictions with the predictions of the newest tree.

@@ -35,7 +37,7 @@
values = np.array([leaf.value for leaf in leaves], dtype=Y_DTYPE)

_update_raw_predictions_helper(raw_predictions, starts, stops, partition,
values, n_threads)
values, n_threads, sample_idx)


cdef inline void _update_raw_predictions_helper(
@@ -45,14 +47,23 @@ cdef inline void _update_raw_predictions_helper(
const unsigned int [::1] partition,
const Y_DTYPE_C [::1] values,
int n_threads,
const int64_t [::1] sample_idx=None,
):

cdef:
unsigned int position
int leaf_idx
int n_leaves = starts.shape[0]

for leaf_idx in prange(n_leaves, schedule='static', nogil=True,
num_threads=n_threads):
for position in range(starts[leaf_idx], stops[leaf_idx]):
raw_predictions[partition[position]] += values[leaf_idx]
if sample_idx is None:
for leaf_idx in prange(
n_leaves, schedule='static', nogil=True, num_threads=n_threads
):
for position in range(starts[leaf_idx], stops[leaf_idx]):
raw_predictions[partition[position]] += values[leaf_idx]
else:
for leaf_idx in prange(
n_leaves, schedule='static', nogil=True, num_threads=n_threads
):
for position in range(starts[leaf_idx], stops[leaf_idx]):
raw_predictions[sample_idx[partition[position]]] += values[leaf_idx]
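The behavioural change in this helper is the optional sample_idx indirection: when the tree was grown on a row subsample, partition holds positions within the in-bag rows, so each position must be mapped back to an original row index before raw_predictions is updated. A single-threaded NumPy sketch of that indexing with toy numbers (no Cython types or prange):

import numpy as np

n_samples = 8
raw_predictions = np.zeros(n_samples)

# Suppose rows 0, 2, 3, 5 and 7 of the training set were drawn into the bag.
sample_idx = np.array([0, 2, 3, 5, 7], dtype=np.int64)

# partition stores positions *within the subsample*, grouped per leaf:
# leaf 0 owns partition[0:3], leaf 1 owns partition[3:5].
partition = np.array([0, 1, 4, 2, 3], dtype=np.uint32)
starts = np.array([0, 3], dtype=np.uint32)
stops = np.array([3, 5], dtype=np.uint32)
values = np.array([-0.5, 1.5])  # leaf values of the newest tree

for leaf_idx in range(len(starts)):
    for position in range(starts[leaf_idx], stops[leaf_idx]):
        # Map the in-bag position back to the original row before updating.
        raw_predictions[sample_idx[partition[position]]] += values[leaf_idx]

print(raw_predictions)  # rows 0, 2, 7 get -0.5; rows 3, 5 get +1.5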
90 changes: 86 additions & 4 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -36,6 +36,7 @@
from ...utils import check_random_state, compute_sample_weight, is_scalar_nan, resample
from ...utils._openmp_helpers import _openmp_effective_n_threads
from ...utils._param_validation import Hidden, Interval, RealNotInt, StrOptions
from ...utils.fixes import parse_version
from ...utils.multiclass import check_classification_targets
from ...utils.validation import (
_check_monotonic_cst,
@@ -148,6 +149,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
"min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
"l2_regularization": [Interval(Real, 0, None, closed="left")],
"max_features": [Interval(RealNotInt, 0, 1, closed="right")],
"subsample": [Interval(Real, 0.0, 1.0, closed="right")],
"monotonic_cst": ["array-like", dict, None],
"interaction_cst": [
list,
Expand Down Expand Up @@ -188,6 +190,7 @@ def __init__(
min_samples_leaf,
l2_regularization,
max_features,
subsample,
max_bins,
categorical_features,
monotonic_cst,
@@ -209,6 +212,7 @@
self.min_samples_leaf = min_samples_leaf
self.l2_regularization = l2_regularization
self.max_features = max_features
self.subsample = subsample
self.max_bins = max_bins
self.monotonic_cst = monotonic_cst
self.interaction_cst = interaction_cst
@@ -579,6 +583,30 @@ def fit(self, X, y, sample_weight=None):
self._random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
feature_subsample_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
self._feature_subsample_rng = np.random.default_rng(feature_subsample_seed)
# TODO: Remove this condition, once numpy 1.25 is the minimum version.
if parse_version(np.__version__) >= parse_version("1.25"):
self._bagging_subsample_rng = self._feature_subsample_rng.spawn(1)[0]
else:
# See numpy Generator.spawn(self, int n_children) and
# numpy BitGenerator.spawn

def spawn_generator(generator, n_children):
return [
type(generator)(g)
for g in spawn_bit_generator(
generator._bit_generator, n_children
)
]

def spawn_bit_generator(_bit_generator, n_children):
return [
type(_bit_generator)(seed=s)
for s in _bit_generator._seed_seq.spawn(n_children)
]

self._bagging_subsample_rng = spawn_generator(
self._feature_subsample_rng, 1
)[0]

self._validate_parameters()
monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst)
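The version branch above only works around Generator.spawn not existing before NumPy 1.25; both paths are meant to derive a child generator that is statistically independent of self._feature_subsample_rng. A rough standalone sketch of the two paths, assuming a default_rng-backed generator (the _seed_seq attribute is private, which is exactly why the fallback has to reach into it):

import numpy as np

parent = np.random.default_rng(42)

# NumPy >= 1.25: spawn() hands back independent child generators.
child = parent.spawn(1)[0]

# Pre-1.25 equivalent: spawn the parent's SeedSequence by hand and rebuild
# the same bit generator type from the child seed.
child_seed = parent.bit_generator._seed_seq.spawn(1)[0]
child_compat = np.random.Generator(type(parent.bit_generator)(seed=child_seed))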
@@ -838,6 +866,19 @@ def fit(self, X, y, sample_weight=None):

begin_at_stage = self.n_iter_

# Out of bag settings
do_oob = self.subsample < 1.0
if do_oob:
# Note that setting sample_weight to zero for the corresponding samples
# would result in false "count" statistics of the histograms. Therefore,
# we take copies (fancy-indexed numpy arrays) for the subsampling.
n_inbag = max(1, int(self.subsample * n_samples))
sample_mask = np.zeros((n_samples,), dtype=bool)
sample_mask[:n_inbag] = True
else:
sample_mask = slice(None)
sample_mask_idx = None

# initialize gradients and hessians (empty arrays).
# shape = (n_samples, n_trees_per_iteration).
gradient, hessian = self._loss.init_gradient_and_hessian(
@@ -884,12 +925,28 @@
g_view = gradient
h_view = hessian

# Do out of bag if required
if do_oob:
self._bagging_subsample_rng.shuffle(sample_mask)
sample_mask_idx = np.flatnonzero(sample_mask)
X_binned_grow = np.asfortranarray(X_binned_train[sample_mask])
g_grow = np.asfortranarray(g_view[sample_mask])
if self._loss.constant_hessian:
h_grow = h_view
else:
h_grow = np.asfortranarray(h_view[sample_mask])

else:
X_binned_grow = X_binned_train
g_grow = g_view
h_grow = h_view
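Together with the mask allocation earlier in fit, these branches implement the row subsampling: a boolean mask with exactly n_inbag True entries is created once, reshuffled at every boosting iteration, and used to take fancy-indexed copies of the binned data and gradients (rather than zeroing sample weights, which would distort the histogram count statistics). A minimal standalone sketch of that selection logic, with made-up shapes:

import numpy as np

rng = np.random.default_rng(0)
n_samples, subsample = 10, 0.5
X_binned_train = np.arange(2 * n_samples, dtype=np.uint8).reshape(n_samples, 2)
gradients = np.linspace(-1.0, 1.0, n_samples)

# Fixed-size mask: exactly n_inbag rows end up in the bag each iteration.
n_inbag = max(1, int(subsample * n_samples))
sample_mask = np.zeros(n_samples, dtype=bool)
sample_mask[:n_inbag] = True

for iteration in range(3):
    rng.shuffle(sample_mask)                    # fresh draw without replacement
    sample_mask_idx = np.flatnonzero(sample_mask)
    X_binned_grow = np.asfortranarray(X_binned_train[sample_mask])
    g_grow = gradients[sample_mask]             # fancy-indexed copy
    # ... grow one tree on (X_binned_grow, g_grow), then scatter its
    # predictions back to the full training set via sample_mask_idx ...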

# Build `n_trees_per_iteration` trees.
for k in range(self.n_trees_per_iteration_):
grower = TreeGrower(
X_binned=X_binned_train,
gradients=g_view[:, k],
hessians=h_view[:, k],
X_binned=X_binned_grow,
gradients=g_grow[:, k],
hessians=h_grow[:, k],
n_bins=n_bins,
n_bins_non_missing=self._bin_mapper.n_bins_non_missing_,
has_missing_values=has_missing_values,
@@ -928,7 +985,12 @@ def fit(self, X, y, sample_weight=None):
# Update raw_predictions with the predictions of the newly
# created tree.
tic_pred = time()
_update_raw_predictions(raw_predictions[:, k], grower, n_threads)
_update_raw_predictions(
raw_predictions[:, k],
grower,
n_threads,
sample_idx=sample_mask_idx,
)
toc_pred = time()
acc_prediction_time += toc_pred - tic_pred

@@ -1490,6 +1552,14 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

.. versionadded:: 1.4

subsample : float, default=1.0
The fraction of randomly chosen samples to be used for fitting the individual
tree(s) in each boosting iteration. If smaller than 1.0, this results in
stochastic gradient boosting or bagging.
Values must be in the range `(0.0, 1.0]`.

.. versionadded:: 1.5

max_bins : int, default=255
The maximum number of bins to use for non-missing values. Before
training, each feature of the input array `X` is binned into
@@ -1698,6 +1768,7 @@ def __init__(
min_samples_leaf=20,
l2_regularization=0.0,
max_features=1.0,
subsample=1.0,
max_bins=255,
categorical_features="warn",
monotonic_cst=None,
@@ -1720,6 +1791,7 @@
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization,
max_features=max_features,
subsample=subsample,
max_bins=max_bins,
monotonic_cst=monotonic_cst,
interaction_cst=interaction_cst,
@@ -1866,6 +1938,14 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):

.. versionadded:: 1.4

subsample : float, default=1.0
The fraction of randomly chosen samples to be used for fitting the individual
tree(s) in each boosting iteration. If smaller than 1.0, this results in
stochastic gradient boosting or bagging.
Values must be in the range `(0.0, 1.0]`.

.. versionadded:: 1.5

max_bins : int, default=255
The maximum number of bins to use for non-missing values. Before
training, each feature of the input array `X` is binned into
@@ -2076,6 +2156,7 @@ def __init__(
min_samples_leaf=20,
l2_regularization=0.0,
max_features=1.0,
subsample=1.0,
max_bins=255,
categorical_features="warn",
monotonic_cst=None,
@@ -2099,6 +2180,7 @@
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization,
max_features=max_features,
subsample=subsample,
max_bins=max_bins,
categorical_features=categorical_features,
monotonic_cst=monotonic_cst,
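Finally, a hedged usage sketch of the parameter this PR proposes. Note that subsample is not part of any released scikit-learn; the snippet below assumes this branch is installed and simply exercises the documented (0.0, 1.0] range:

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

X, y = make_regression(n_samples=1_000, n_features=10, random_state=0)

# subsample=0.5 grows each tree on a random half of the rows
# (stochastic gradient boosting); subsample=1.0 keeps today's behaviour.
reg = HistGradientBoostingRegressor(
    subsample=0.5, learning_rate=0.1, random_state=0
).fit(X, y)
print(reg.score(X, y))

As the modified example above suggests, subsampling is typically combined with a learning_rate below 1.0, since the two regularizers work best together.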