[NOMRG] new warmstart API for GBDTs by NicolasHug · Pull Request #15105 · scikit-learn/scikit-learn · GitHub


[NOMRG] new warmstart API for GBDTs #15105


Closed
50 changes: 50 additions & 0 deletions sklearn/base.py
@@ -323,6 +323,56 @@ def _get_tags(self):
collected_tags.update(more_tags)
return collected_tags

def _check_warm_start_with(self, warm_start_with):
# Return True if the estimator should warm-start.
# Also make sure that the parameters are warm-startable, and that
# their new value moves in the allowed direction (increasing or
# decreasing).
# If all goes well, use set_params to set the new parameter values.

if not hasattr(self, '_warmstartable_parameters'):
raise ValueError(
"None of the {} parameters can be warm-started."
.format(self.__class__.__name__)
)

if not warm_start_with: # None or empty dict
return False # no warm start needed

def param_increases(param_name):
# util to make sure the param is warm-startable, and to know
# whether the param must increase or decrease when
# warm-started.
for warmstartable_param in self._warmstartable_parameters:
if param_name == warmstartable_param[1:]:
must_increase = (warmstartable_param[0] == '+')
return must_increase

# No match found, raise error
raise ValueError(
"The {} parameter cannot be warm-started."
.format(param_name)
)

err_msg = (
"The {} class can only be warm-started with {} "
"values of {}. Current value is {}, requesting new value "
"of {}."
)
for param_name, new_value in warm_start_with.items():
current_value = self.get_params()[param_name]
must_increase = param_increases(param_name)
direction = 'increasing' if must_increase else 'decreasing'
if ((must_increase and new_value < current_value) or
(not must_increase and new_value > current_value)):
raise ValueError(
err_msg.format(self.__class__.__name__, direction,
param_name, current_value, new_value)
)

# All went well, setting new parameter value
self.set_params(**{param_name: new_value})
return True # must warm_start


class ClassifierMixin:
"""Mixin class for all classifiers in scikit-learn."""
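To make the intended contract concrete, here is a minimal sketch (not part of this diff) of how an estimator could plug into the base-class hook above on this branch; the ToyBoostingEstimator name and its fit body are hypothetical:

from sklearn.base import BaseEstimator
from sklearn.datasets import make_classification

class ToyBoostingEstimator(BaseEstimator):
    # Declare which parameters may be warm-started. A leading '+' means the
    # new value must not be smaller than the current one (this diff only
    # uses '+'; any other prefix is treated as "must not increase").
    _warmstartable_parameters = ['+n_estimators']

    def __init__(self, n_estimators=100):
        self.n_estimators = n_estimators

    def fit(self, X, y, warm_start_with=None):
        # _check_warm_start_with validates warm_start_with against
        # _warmstartable_parameters, updates n_estimators via set_params,
        # and returns True when fitting should resume from previous state.
        warm_start = self._check_warm_start_with(warm_start_with)
        if not (warm_start and hasattr(self, 'estimators_')):
            self.estimators_ = []  # fresh fit: drop any previous state
        while len(self.estimators_) < self.n_estimators:
            self.estimators_.append(object())  # placeholder for a real tree
        return self

X, y = make_classification(random_state=0)
est = ToyBoostingEstimator(n_estimators=100).fit(X, y)
est.fit(X, y, warm_start_with={'n_estimators': 150})  # keeps the first 100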
24 changes: 13 additions & 11 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -3,6 +3,7 @@

from abc import ABC, abstractmethod
from functools import partial
import warnings

import numpy as np
from timeit import default_timer as time
@@ -25,6 +26,8 @@
class BaseHistGradientBoosting(BaseEstimator, ABC):
"""Base class for histogram-based gradient boosting estimators."""

_warmstartable_parameters = ['+max_iter']

@abstractmethod
def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
max_depth, min_samples_leaf, l2_regularization, max_bins,
@@ -80,7 +83,7 @@ def _validate_parameters(self):
raise ValueError('max_bins={} should be no smaller than 2 '
'and no larger than 255.'.format(self.max_bins))

def fit(self, X, y):
def fit(self, X, y, warm_start_with=None):
"""Fit the gradient boosting model.

Parameters
@@ -106,10 +109,17 @@ def fit(self, X, y):

rng = check_random_state(self.random_state)

# For backward compat (for now)
if self.warm_start:
warnings.warn("warm_start parameter is deprecated",
DeprecationWarning)
warm_start_with = {'max_iter': self.max_iter}

# When warm starting, we want to re-use the same seed that was used
# the first time fit was called (e.g. for subsampling or for the
# train/val split).
if not (self.warm_start and self._is_fitted()):
warm_start = self._check_warm_start_with(warm_start_with)
if not (warm_start and self._is_fitted()):
self._random_seed = rng.randint(np.iinfo(np.uint32).max,
dtype='u8')

@@ -173,7 +183,7 @@ def fit(self, X, y):
n_samples = X_binned_train.shape[0]

# First time calling fit, or no warm start
if not (self._is_fitted() and self.warm_start):
if not (self._is_fitted() and warm_start):
# Clear random state and score attributes
self._clear_state()

@@ -255,14 +265,6 @@ def fit(self, X, y):

# warm start: this is not the first time fit was called
else:
# Check that the maximum number of iterations is not smaller
# than the number of iterations from the previous fit
if self.max_iter < self.n_iter_:
raise ValueError(
'max_iter=%d must be larger than or equal to '
'n_iter_=%d when warm_start==True'
% (self.max_iter, self.n_iter_)
)

# Convert array attributes to lists
self.train_score_ = self.train_score_.tolist()
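For reference, the call pattern this hunk enables would look roughly like the sketch below; it mirrors the updated tests further down and only works on this branch, not in released scikit-learn:

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification

X, y = make_classification(random_state=0)

clf = HistGradientBoostingClassifier(max_iter=50, random_state=0)
clf.fit(X, y)                                    # first fit, up to 50 iterations
clf.fit(X, y, warm_start_with={'max_iter': 75})  # re-uses those trees, adds up to 25 more

# Requesting fewer iterations than currently configured raises a ValueError,
# because '+max_iter' is declared as an increase-only parameter:
# clf.fit(X, y, warm_start_with={'max_iter': 25})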
36 changes: 15 additions & 21 deletions sklearn/ensemble/_hist_gradient_boosting/tests/test_warm_start.py
@@ -37,13 +37,11 @@ def test_max_iter_with_warm_start_validation(GradientBoosting, X, y):
# is smaller than the number of iterations from the previous fit when warm
# start is True.

estimator = GradientBoosting(max_iter=50, warm_start=True)
estimator = GradientBoosting(max_iter=50)
estimator.fit(X, y)
estimator.set_params(max_iter=25)
err_msg = ('max_iter=25 must be larger than or equal to n_iter_=50 '
'when warm_start==True')
err_msg = ('can only be warm-started with increasing values of max_iter')
with pytest.raises(ValueError, match=err_msg):
estimator.fit(X, y)
estimator.fit(X, y, warm_start_with={'max_iter': 25})


@pytest.mark.parametrize('GradientBoosting, X, y', [
@@ -56,9 +54,9 @@ def test_warm_start_yields_identical_results(GradientBoosting, X, y):

rng = 42
gb_warm_start = GradientBoosting(
n_iter_no_change=100, max_iter=50, random_state=rng, warm_start=True
n_iter_no_change=100, max_iter=50, random_state=rng
)
gb_warm_start.fit(X, y).set_params(max_iter=75).fit(X, y)
gb_warm_start.fit(X, y).fit(X, y, warm_start_with={'max_iter': 75})

gb_no_warm_start = GradientBoosting(
n_iter_no_change=100, max_iter=75, random_state=rng, warm_start=False
@@ -75,11 +73,10 @@
])
def test_warm_start_max_depth(GradientBoosting, X, y):
# Test if possible to fit trees of different depth in ensemble.
gb = GradientBoosting(max_iter=100, min_samples_leaf=1,
warm_start=True, max_depth=2)
gb.fit(X, y)
gb.set_params(max_iter=110, max_depth=3)
gb = GradientBoosting(max_iter=100, min_samples_leaf=1, max_depth=2)
gb.fit(X, y)
gb.set_params(max_depth=3)
gb.fit(X, y, warm_start_with={'max_iter': 110})

# First 100 trees have max_depth == 2
for i in range(100):
@@ -100,11 +97,11 @@ def test_warm_start_early_stopping(GradientBoosting, X, y):
n_iter_no_change = 5
gb = GradientBoosting(
n_iter_no_change=n_iter_no_change, max_iter=10000,
random_state=42, warm_start=True, tol=1e-3
random_state=42, tol=1e-3
)
gb.fit(X, y)
n_iter_first_fit = gb.n_iter_
gb.fit(X, y)
gb.fit(X, y, warm_start_with={'max_iter': 10000})
n_iter_second_fit = gb.n_iter_
assert n_iter_second_fit - n_iter_first_fit < n_iter_no_change

@@ -119,8 +116,7 @@ def test_warm_start_equal_n_estimators(GradientBoosting, X, y):
gb_1.fit(X, y)

gb_2 = clone(gb_1)
gb_2.set_params(max_iter=gb_1.max_iter, warm_start=True)
gb_2.fit(X, y)
gb_2.fit(X, y, warm_start_with={'max_iter': gb_1.max_iter})

# Check that both predictors are equal
_assert_predictor_equal(gb_1, gb_2, X)
@@ -135,11 +131,9 @@ def test_warm_start_clear(GradientBoosting, X, y):
gb_1 = GradientBoosting(n_iter_no_change=5, random_state=42)
gb_1.fit(X, y)

gb_2 = GradientBoosting(n_iter_no_change=5, random_state=42,
warm_start=True)
gb_2 = GradientBoosting(n_iter_no_change=5, random_state=42)
gb_2.fit(X, y) # inits state
gb_2.set_params(warm_start=False)
gb_2.fit(X, y) # clears old state and equals est
gb_2.fit(X, y, warm_start_with=None) # clears old state and equals est

# Check that both predictors have the same train_score_ and
# validation_score_ attributes
@@ -178,10 +172,10 @@ def _get_rng(rng_type):

random_state = _get_rng(rng_type)
gb_2 = GradientBoosting(n_iter_no_change=5, max_iter=2,
random_state=random_state, warm_start=True)
random_state=random_state)
gb_2.fit(X, y) # inits state
random_seed_2_1 = gb_2._random_seed
gb_2.fit(X, y) # clears old state and equals est
gb_2.fit(X, y, warm_start_with={'max_iter': 2})
random_seed_2_2 = gb_2._random_seed

# Without warm starting, the seeds should be
33 changes: 31 additions & 2 deletions sklearn/pipeline.py
@@ -132,6 +132,19 @@ def __init__(self, steps, memory=None, verbose=False):
self.verbose = verbose
self._validate_steps()

@property
def _warmstartable_parameters(self):
# This property exposes the _warmstartable_parameters attribute, e.g.
# ['+last_step_name__param']
# We consider that only the last step can be warm-started. The first
# steps are transformers that cannot be warm-started.
out = []
step_name, est = self.steps[-1]
for param in getattr(est, '_warmstartable_parameters', []):
sign, param_name = param[0], param[1:]
out.append(sign + step_name + '__' + param_name)
return out

Review discussion on this hunk:

Member: Do we need to make this design choice? I may have asked this question before, but I don't remember your logic behind it.

Member (Author): A pipeline is either:

  • a sequence of transformers
  • a sequence of transformers + a predictor as the last step

It's safe to assume that transformers are not warm-startable. Maybe in the future one of them will be? We can worry about that when it happens.

Member: You could have a word embedding transformer as a step, which is very often warm-started. We may not have that inside sklearn, but the pipeline's API should support it, I think.

Member (Author): I agree we should leave the possibility open and make the code future-proof. This is one of the reasons why _warmstartable_parameters is a list.

But I don't think we should implement support for that now; we have no use case at the moment.

Concretely, supporting warm start for transformers right now means writing code that isn't used (it would require updating the _fit method that fits all the transformers).

Member: I see. Yeah, fair.

def get_params(self, deep=True):
"""Get parameters for this estimator.

@@ -318,7 +331,7 @@ def _fit(self, X, y=None, **fit_params):
return X, {}
return X, fit_params_steps[self.steps[-1][0]]

def fit(self, X, y=None, **fit_params):
def fit(self, X, y=None, warm_start_with=None, **fit_params):
"""Fit the model

Fit all the transforms one after the other and transform the
@@ -339,6 +352,9 @@ def fit(self, X, y=None, **fit_params):
each parameter name is prefixed such that parameter ``p`` for step
``s`` has key ``s__p``.

warm_start_with : dict, default=None
Indicates which parameters to warm-start, mapped to their new values,
e.g. ``{'laststep__max_iter': 100}``.

Returns
-------
self : Pipeline
@@ -348,7 +364,20 @@ def fit(self, X, y=None, **fit_params):
with _print_elapsed_time('Pipeline',
self._log_message(len(self.steps) - 1)):
if self._final_estimator != 'passthrough':
self._final_estimator.fit(Xt, y, **fit_params)
if self._check_warm_start_with(warm_start_with):
# convert {laststep__param: val} into just {param: val}
# Could be a util in the base class
warm_start_with_final = {
name.split('__')[1]: val
for name, val in warm_start_with.items()
}

self._final_estimator.fit(
Xt, y, warm_start_with=warm_start_with_final,
**fit_params)
else:
self._final_estimator.fit(Xt, y, **fit_params)

return self

def fit_transform(self, X, y=None, **fit_params):
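The inline comment in this hunk notes that the key conversion "could be a util in the base class". A possible shape for such a helper (hypothetical, not part of this diff) is sketched below; stripping the known step prefix, rather than using name.split('__')[1], would also keep nested parameter names of the final step intact:

def _strip_step_prefix(warm_start_with, step_name):
    # Turn {'gbdt__max_iter': 20} into {'max_iter': 20} for step 'gbdt',
    # leaving anything after the step prefix untouched.
    prefix = step_name + '__'
    return {
        name[len(prefix):]: value
        for name, value in warm_start_with.items()
        if name.startswith(prefix)
    }

_strip_step_prefix({'gbdt__max_iter': 20}, 'gbdt')  # -> {'max_iter': 20}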
21 changes: 21 additions & 0 deletions sklearn/tests/test_pipeline.py
@@ -35,6 +35,11 @@
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import CountVectorizer

from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification


iris = load_iris()

JUNK_FOOD_DOCS = (
@@ -1195,3 +1200,19 @@ def test_feature_union_warns_with_none():

with pytest.warns(DeprecationWarning, match=msg):
union.fit_transform(X)


def test_warm_start_new_api():
# simple test to illustrate warm-starting on a pipeline

pipe = Pipeline([
('preprocessor', StandardScaler()),
('gbdt', HistGradientBoostingClassifier(max_iter=10))
])

assert pipe._warmstartable_parameters == ['+gbdt__max_iter']

X, y = make_classification()
pipe.fit(X, y)
pipe.fit(X, y, warm_start_with={'gbdt__max_iter': 20})
assert pipe.named_steps['gbdt'].n_iter_ == 20